In [None]:
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm
import cv2
import matplotlib.pyplot as plt
from sklearn.preprocessing import MultiLabelBinarizer

In [None]:
# List the entries of the competition's training image directory and report
# how many there are.
TRAIN = os.listdir('../input/plant-pathology-2021-fgvc8/train_images')
print("Training Images = ",len(TRAIN))

In [None]:
# Load the training annotations. Later cells read the 'image' and 'labels'
# columns from this frame.
df = pd.read_csv('../input/plant-pathology-2021-fgvc8/train.csv')
df

In [None]:
# Split every space-separated 'labels' string into a list of label tokens and
# store the lists in a new 'new_labels' column next to the raw strings.
L = [label_string.split(' ') for label_string in df['labels']]
df['new_labels'] = L

In [None]:
# Display the frame again, now including the 'new_labels' column.
df

In [None]:
# Encode the per-image label lists as an indicator table: one column per
# distinct label (named after mlb.classes_), one row per image, sharing df's index.
s = df['new_labels'].tolist()
mlb = MultiLabelBinarizer()
label_matrix = mlb.fit_transform(s)
converted_df = pd.DataFrame(label_matrix, index=df.index, columns=mlb.classes_)

In [None]:
# Attach each image's file name to its encoded label row. The new column is
# appended after the label columns.
converted_df['image'] = df['image']

In [None]:
# Display the encoded label table together with the 'image' column.
converted_df

In [None]:
def _bytes_feature(value):
    """Wrap a single string / bytes value in a tf.train.Feature (bytes_list).

    Eager tensors are converted with ``.numpy()`` first, because BytesList
    won't unpack a string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
    """Wrap a single float / double in a tf.train.Feature (float_list)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint or a sequence of them.

    Args:
        value: either one integer-like scalar (Python ``int``/``bool`` or a
            NumPy integer/bool scalar) or an iterable of such values, e.g. a
            multi-hot label vector.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the value(s).
    """
    # Int64List needs an iterable; wrap lone scalars so both call styles work.
    # Sequences are passed through untouched, exactly as before.
    if isinstance(value, (bool, int, np.integer, np.bool_)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

In [None]:
def serialize_example(feature, name, target):
    """Pack one sample into a tf.train.Example and return it serialized.

    Args:
        feature: value stored under the 'image' key via ``_bytes_feature``.
        name: value stored under the 'name' key via ``_bytes_feature``.
        target: value stored under the 'label' key via ``_int64_feature``.

    Returns:
        The result of ``SerializeToString()`` on the assembled Example.
    """
    fields = {}
    fields['image'] = _bytes_feature(feature)
    fields['name'] = _bytes_feature(name)
    fields['label'] = _int64_feature(target)

    features = tf.train.Features(feature=fields)
    example_proto = tf.train.Example(features=features)
    return example_proto.SerializeToString()

In [None]:
# Shuffle once so that every shard receives a random mix of images. The fixed
# seed keeps the shard contents reproducible across Restart & Run All.
SHUFFLE_SEED = 42
merged = converted_df.sample(frac=1, random_state=SHUFFLE_SEED).reset_index(drop=True)
NUM_TFRECORDS =  68
IMGS_PER_FILE = len(converted_df)//NUM_TFRECORDS  # nominal shard size
TRAIN_IMAGE_DIR = '../input/plant-pathology-2021-fgvc8/train_images/'

# Pick the label columns by name rather than assuming 'image' is the last column.
LABEL_COLUMNS = [col for col in merged.columns if col != 'image']
label_rows = merged[LABEL_COLUMNS].to_numpy()

# Split the row positions into NUM_TFRECORDS nearly equal shards. Unlike plain
# integer division, this also assigns the remainder rows, so no image is
# silently left out when len(merged) is not a multiple of NUM_TFRECORDS.
shard_positions = np.array_split(np.arange(len(merged)), NUM_TFRECORDS)

for idx, positions in enumerate(tqdm(shard_positions)):
    if len(positions) == 0:
        # Fewer images than shards: do not create an empty file.
        continue
    # The file name carries the shard number and the number of records it holds.
    with tf.io.TFRecordWriter(f'File_{idx+1}_{len(positions)}.tfrec') as writer:
        for k in tqdm(positions):
            file_name = merged['image'].iloc[k]
            img = cv2.imread(TRAIN_IMAGE_DIR + file_name)
            if img is None:
                # cv2.imread reports a missing/unreadable file by returning None;
                # fail here with a clear message instead of inside cvtColor.
                raise FileNotFoundError(f'Could not read image: {TRAIN_IMAGE_DIR + file_name}')
            # NOTE(review): the array is converted to RGB *before* cv2.imencode,
            # and decode_image reverses the channel order again after decoding.
            # The two steps are kept as they were; change them only together.
            img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
            img = cv2.resize(img,(512,512))
            encoded_ok, buffer = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, 94))
            if not encoded_ok:
                raise ValueError(f'JPEG encoding failed for image: {file_name}')
            # tobytes() replaces the deprecated ndarray.tostring() alias.
            img = buffer.tobytes()
            name = file_name.split('.')[0]
            target = label_rows[k].tolist()
            example = serialize_example(
              img,
              str.encode(name),
              target
            )
            writer.write(example)

In [None]:
# Settings for reading the TFRecords back.
IMAGE_SIZE = [512, 512]   # [height, width] used when reshaping decoded images
BATCH_SIZE = 16
AUTO = tf.data.experimental.AUTOTUNE
# Every TFRecord shard found in the current working directory.
TRAINING_FILENAMES = tf.io.gfile.glob('./*.tfrec')

In [None]:
# Display the TFRecord file names matched by the './*.tfrec' glob.
TRAINING_FILENAMES

In [None]:
import math
# Shorten NumPy's printed output: arrays with more than 15 elements are
# summarized and lines wrap at 80 characters.
np.set_printoptions(threshold=15, linewidth=80)
# Lookup table indexed by title_from_label_and_target to turn a label index
# into the text shown in a plot title.
CLASSES = [0,1]

def batch_to_numpy_images_and_labels(data):
    """Convert an (images, labels) batch into NumPy images and string labels.

    Args:
        data: a two-element pair; both elements must expose ``.numpy()``.

    Returns:
        A tuple ``(numpy_images, labels)`` where ``numpy_images`` is the result
        of ``images.numpy()`` and ``labels`` is a list holding ``str(...)`` of
        each entry of ``labels.numpy()``.
    """
    image_tensor, label_tensor = data
    numpy_images = image_tensor.numpy()
    label_strings = list(map(str, label_tensor.numpy()))
    return numpy_images, label_strings

def title_from_label_and_target(label, correct_label):
    """Build a plot title for a prediction and say whether it was correct.

    Args:
        label: predicted label; used as an index into ``CLASSES``.
        correct_label: ground-truth label, or ``None`` when it is unknown.

    Returns:
        A ``(title, correct)`` tuple. ``title`` is always a string so callers
        can safely take its length: just the predicted class when there is no
        ground truth, ``"<pred> [OK]"`` on a match and
        ``"<pred> [NO→<truth>]"`` on a mismatch. ``correct`` is ``True`` when
        there is no ground truth, otherwise the result of
        ``label == correct_label``.
    """
    predicted = CLASSES[label]
    if correct_label is None:
        # Previously the raw CLASSES entry was returned here, which is not a
        # string when CLASSES holds integers.
        return str(predicted), True

    correct = (label == correct_label)
    if correct:
        return "{} [OK]".format(predicted), correct
    return "{} [NO{}{}]".format(predicted, u"\u2192", CLASSES[correct_label]), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the current figure and advance the subplot position.

    Args:
        image: image data handed to ``plt.imshow``.
        title: title text; no title is drawn when its length is 0.
        subplot: ``(rows, cols, position)`` triple passed to ``plt.subplot``.
        red: when True the title is drawn in red at a reduced font size.
        titlesize: base font size; also determines the title padding.

    Returns:
        The ``(rows, cols, position + 1)`` triple for the next image.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)

    if len(title) > 0:
        if red:
            font_size, font_color = int(titlesize/1.2), 'red'
        else:
            font_size, font_color = int(titlesize), 'black'
        plt.title(
            title,
            fontsize=font_size,
            color=font_color,
            fontdict={'verticalalignment':'center'},
            pad=int(titlesize/1.5),
        )

    n_rows, n_cols, position = subplot[0], subplot[1], subplot[2]
    return (n_rows, n_cols, position + 1)
    
def display_batch_of_images(databatch, predictions=None):
    """Plot a batch of images in a square-ish grid, one title per image.

    This will work with:
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)

    Args:
        databatch: an ``(images, labels)`` pair. It is unpacked by
            ``batch_to_numpy_images_and_labels``, so a bare images tensor is
            not accepted.
        predictions: optional sequence aligned with the batch. When given,
            each title comes from ``title_from_label_and_target`` and titles
            of mismatches are drawn in red.

    Images that do not fit into the rows x cols grid are dropped. An empty
    batch draws nothing and returns immediately.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]

    # Nothing to draw. Without this guard `rows` would be 0 and the column
    # computation below would divide by zero.
    if len(images) == 0:
        return

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

    # magic formula tested to work from 1x1 to 10x10 images; it only depends
    # on the grid shape, so compute it once instead of per image
    dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3

    # display
    label = None  # keeps the layout check below defined even if nothing is iterated
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        # display_one_flower takes len(title), so never hand it None
        title = '' if label is None else label
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        # no titles were drawn, so the images can touch each other
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 image tensor shaped [*IMAGE_SIZE, 3].

    Steps: decode as a 3-channel JPEG, reverse the channel order, scale the
    values into the [0, 1] range and reshape to the fixed size (explicit size
    needed for TPU).
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    # NOTE(review): presumably this reversal compensates for the channel order
    # the records were written with -- confirm against the writer cell.
    first, second, third = tf.unstack(pixels, axis=-1)
    pixels = tf.stack([third, second, first], axis=-1)
    pixels = tf.cast(pixels, tf.float32) / 255.0
    return tf.reshape(pixels, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized Example into an (image, label) pair.

    The record is expected to hold a scalar string 'image', a scalar string
    'name' and a 6-element int64 'label' (all zeros when absent). 'name' is
    parsed but not returned.

    Returns:
        ``(image, label)``: the image decoded by ``decode_image`` and the
        6-element int64 label tensor.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "name": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([6], tf.int64, default_value=[0] * 6),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    return image, parsed['label']

def load_dataset(filenames, labeled=True, ordered=False):
    """Create a dataset of (image, label) pairs from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: kept for interface compatibility; currently ignored -- every
            record is parsed with ``read_labeled_tfrecord``.
        ordered: when False (default) deterministic ordering is switched off,
            trading record order for speed. Order does not matter for training
            since the data is shuffled anyway.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, label)`` pairs.
    """
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    # num_parallel_reads interleaves reads from multiple files
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.with_options(ignore_order)
    # Parse and decode records in parallel instead of one at a time; JPEG
    # decoding is the expensive part of this pipeline.
    dataset = dataset.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    return dataset

def get_training_dataset():
    """Build the training input pipeline over TRAINING_FILENAMES.

    The dataset repeats indefinitely (training runs for several epochs), is
    shuffled with a 2048-element buffer, batched to BATCH_SIZE and prefetched
    so the next batch is prepared while the current one is in use.
    """
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

In [None]:
# Build the training pipeline, regroup it into batches of 16 for the preview
# and plot the first batch.
training_dataset = get_training_dataset().unbatch().batch(16)
train_batch = iter(training_dataset)

first_batch = next(train_batch)
display_batch_of_images(first_batch)