# Plant Pathology TFRecords Maker

Version 0: Implement all functionality except MixUp   
- Image size 224   
- Random hue, saturation, contrast, brightness and flips
- Feature dict :
    ```
    {
        'image' : _bytes_feature(img.numpy().tobytes()),  # float32 images
        'target': _float_arr_feature(labels),
        'image_name': _bytes_feature(bytes(image_name,encoding='utf8'))
    }
    ```

Version 1:
- Image size 512
- Random hue, saturation, contrast, brightness and flips
- Feature dict :
    ```
    {
        'image' : _bytes_feature(img.numpy().tobytes()), #uint8 images
        'target': _float_arr_feature(labels),
        'image_name': _bytes_feature(bytes(image_name,encoding='utf8'))
    }
    ```


### Import necessary packages

In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from sklearn.model_selection import train_test_split,StratifiedKFold
import gc
import cv2

### Get main training csv

In [None]:
# Load the competition training table. The two columns used downstream are
# 'image' (the JPEG file name) and 'labels' (a label string).
df = pd.read_csv('../input/plant-pathology-2021-fgvc8/train.csv')
df_images, df_labels = df['image'], df['labels']


### Multi-label one-hot encoding function

In [None]:
# Canonical label order: a label's position here is its index in the target vector.
ALL_LABELS = ('healthy', 'scab', 'frog_eye_leaf_spot', 'rust', 'powdery_mildew', 'complex')
_LABEL_INDEX = {name: idx for idx, name in enumerate(ALL_LABELS)}


def custom_one_hot(label_str):
    """Encode a space-separated multi-label string as a multi-hot float32 vector.

    Args:
        label_str: e.g. ``'scab frog_eye_leaf_spot'``. An empty string yields
            an all-zero vector.

    Returns:
        ``np.ndarray`` of shape ``(len(ALL_LABELS),)``, dtype float32, with 1.0
        at the index of every label present.

    Raises:
        ValueError: if ``label_str`` contains a label not in ``ALL_LABELS``.
    """
    # Size from the label table itself so the two can never drift apart.
    retarr = np.zeros((len(ALL_LABELS),), dtype=np.float32)
    for label in label_str.split():
        try:
            retarr[_LABEL_INDEX[label]] = 1.0
        except KeyError:
            raise ValueError(
                f'Unknown label {label!r}; expected one of {ALL_LABELS}'
            ) from None
    return retarr


In [None]:
# all_labels = ['healthy']
# for i in range(len(df_labels)):
#     label = df_labels[i]
#     if len(label.split())!=1:
#         labels = label.split()
#         for j in labels:
#             if j not in all_labels:
#                 all_labels.append(j)



# #all_labels = ['healthy', 'scab', 'frog_eye_leaf_spot',  'rust', 'powdery_mildew']

# for i in range(len(df_labels)):
#     df_labels[i] = custom_one_hot(df_labels[i])

### Declaring image size

In [None]:
# Side length, in pixels, of the square images produced by resize_image.
IMG_SIZE: int = 512

### Make 6 Stratified K Folds

In [None]:
# Stratify on the raw label string, so each label *combination* keeps roughly
# its overall proportion in every fold. A fixed random_state makes the split
# reproducible: rerunning the notebook regenerates identical fold CSVs/TFRecords.
skf = StratifiedKFold(n_splits=6, shuffle=True, random_state=42)

# FOLDS_LIST[k] holds the held-out ("test") rows of fold k, with columns
# 'image' and 'labels'. Each entry becomes one TFRecord file further below.
FOLDS_LIST = []
for a, (_, test_index) in enumerate(skf.split(df_images, df_labels)):
    # split() yields positional indices, so select positionally.
    df_test = pd.concat(
        [df_images.iloc[test_index], df_labels.iloc[test_index]], axis=1
    )
    df_test.to_csv('fold_' + str(a) + '.csv')
    FOLDS_LIST.append(df_test)

### Declaring TFRecords utility functions

In [None]:
def _bytes_feature(value):
  """Wrap a single bytes value in a ``tf.train.Feature`` holding a bytes_list.

  An eager tensor is first converted with ``.numpy()``, because ``BytesList``
  cannot unpack an ``EagerTensor`` directly.
  """
  eager_tensor_type = type(tf.constant(0))
  raw = value.numpy() if isinstance(value, eager_tensor_type) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
  """Wrap one float / double scalar in a ``tf.train.Feature`` float_list."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Returns an int64_list Feature from a bool / enum / int / uint.

  ``tf.train.Feature`` only has ``bytes_list``, ``float_list`` and
  ``int64_list`` fields. The previous ``int32_list`` keyword named a field
  that does not exist, so every call failed.
  """
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_arr_feature(arr):
    """Wrap a sequence of floats (e.g. a multi-hot label vector) in a float_list Feature."""
    values = tf.train.FloatList(value=arr)
    return tf.train.Feature(float_list=values)

### Augmentation functions

In [None]:
def read_img(image_name):
    """Read and decode one training JPEG by file name.

    Args:
        image_name: file name (the ``image`` column of train.csv) inside the
            competition's ``train_images`` directory.

    Returns:
        uint8 image tensor with exactly 3 (RGB) channels.
    """
    filepath = '../input/plant-pathology-2021-fgvc8/train_images/' + image_name
    # channels=3 forces RGB output even for grayscale JPEGs. Without it a
    # 1-channel image would reach random_hue/random_saturation, which require
    # a last dimension of 3, and fail.
    return tf.io.decode_jpeg(tf.io.read_file(filepath), channels=3)

def get_augment_list():
    """Draw six independent coin flips selecting which augmentations to apply.

    Returns:
        Boolean NumPy array of shape (6,). Entry i is True when
        ``np.random.randint(2)`` drew 0 for that slot (probability 1/2).
    """
    coin_flips = np.random.randint(2, size=6)
    return np.asarray(coin_flips < 1, dtype=bool)


@tf.function
def resize_image(image):
    """Resize ``image`` to a square ``IMG_SIZE`` x ``IMG_SIZE`` float32 tensor.

    Uses ``tf.image.resize``'s default method. The explicit cast pins the
    output dtype to float32. ``IMG_SIZE`` is a module-level global read when
    the function is traced.
    """
    target_hw = [IMG_SIZE, IMG_SIZE]
    resized = tf.image.resize(image, target_hw)
    return tf.cast(resized, tf.float32)

    
    
@tf.function
def augment_img_randomly(img):
    '''Resize ``img`` and apply a random subset of colour and flip augmentations.

    Each of the six augmentations below is switched on independently with
    probability 0.5 per call. An enabled op then draws its own random
    parameter:

    - random saturation (0.7, 1.3)
    - random contrast (0.8, 1.2)
    - random brightness (max delta 0.3)
    - random hue (max delta 0.2)
    - random left-right flip
    - random up-down flip

    The colour ops run on values scaled to [0, 1], where a brightness delta
    of 0.3 is meaningful. The result is clipped back into [0, 1] before
    conversion to uint8, so over/under-shoot from contrast or brightness
    cannot wrap around.

    NOTE(review): these are the stateful ``tf.image.random_*`` ops. The
    original plan mentioned stateless versions.

    Args:
        img: decoded image tensor (presumably HxWx3 uint8 from ``read_img``).

    Returns:
        uint8 tensor of spatial size ``IMG_SIZE`` x ``IMG_SIZE``.
    '''
    # Draw the on/off flags with TF ops, not NumPy. Inside a tf.function a
    # NumPy call executes only once, at trace time, which would freeze the
    # same augmentation choice into the graph for every image.
    augment_list = tf.random.uniform([6]) < 0.5

    # resize_image keeps the source pixel range (0..255 for uint8 input).
    # Rescale to [0, 1] so the colour-op parameters above act as intended.
    image = resize_image(img) / 255.0

    # AutoGraph turns these tensor-valued `if`s into tf.cond, so the choice
    # is made per call at run time.
    if augment_list[0]:
        image = tf.image.random_saturation(image, 0.7, 1.3)
    if augment_list[1]:
        image = tf.image.random_contrast(image, 0.8, 1.2)
    if augment_list[2]:
        image = tf.image.random_brightness(image, 0.3)
    if augment_list[3]:
        image = tf.image.random_hue(image, 0.2)
    if augment_list[4]:
        image = tf.image.random_flip_left_right(image)
    if augment_list[5]:
        image = tf.image.random_flip_up_down(image)

    # Clip before casting: converting out-of-range floats to uint8 would
    # otherwise corrupt pixels.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return tf.cast(tf.round(image * 255.0), tf.uint8)

### Make example

In [None]:

def example_generator(fold_num, image_name):
    """Build one serialized ``tf.train.Example`` for ``image_name``.

    Features written:
        'image': raw bytes of the uint8 tensor returned by
            ``augment_img_randomly``. No shape is stored, so readers must
            know it.
        'target': multi-hot float label vector from ``custom_one_hot``.
        'image_name': UTF-8 encoded file name.

    Args:
        fold_num: index into ``FOLDS_LIST`` of the fold containing the image.
        image_name: value of the fold DataFrame's 'image' column.

    Returns:
        The serialized Example as ``bytes``.

    Raises:
        KeyError: if ``image_name`` is not present in that fold.
    """
    fold_df = FOLDS_LIST[fold_num]
    # Look the label up by column name rather than position, so a change in
    # the fold DataFrame's column order cannot silently pick the wrong field.
    # Doing it first also fails fast, before the expensive image decode.
    matches = fold_df.loc[fold_df['image'] == image_name, 'labels']
    if matches.empty:
        raise KeyError(f'{image_name!r} not found in fold {fold_num}')
    labels = custom_one_hot(matches.iloc[0])

    img = augment_img_randomly(read_img(image_name))
    feature = {
        'image': _bytes_feature(img.numpy().tobytes()),
        'target': _float_arr_feature(labels),
        'image_name': _bytes_feature(bytes(image_name, encoding='utf8')),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

In [None]:
# fold_df = FOLDS_LIST[3]
# print(fold_df.head())
# image_names = fold_df['image']
# for image_name in image_names:
#     labels = fold_df[fold_df['image']==image_name].values
#     #print(custom_one_hot(labels[0,1]))
#     #print(example_generator(3,labels[0,0]))
#     break

### Write TFRecords

In [None]:
# Write one TFRecord file per fold ('fold_<i>.tfrecords'). Each file holds
# one serialized Example per held-out image of that fold.
for i, fold_df in enumerate(FOLDS_LIST):
    record_file = 'fold_' + str(i) + '.tfrecords'

    print('Writing ', record_file)

    image_names = list(fold_df['image'])
    num_files = len(image_names)

    # The `with` block closes (and flushes) the writer when the fold is done.
    with tf.io.TFRecordWriter(record_file) as writer:
        for a, k in enumerate(image_names, start=1):
            print('Writing image ', a, ' of ', num_files)
            writer.write(example_generator(i, k))

    # Reference counting already frees each serialized example as soon as it
    # is written. A full cycle collection once per fold is enough; running it
    # per image mainly cost time.
    gc.collect()
    