In [None]:
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import numpy as np
import shutil
import os
from kaggle_datasets import KaggleDatasets
import sys
import random
import cv2

In [None]:
# Target classes. The order here fixes the column order of the multi-hot
# label matrix built from the training CSV.
classes = [
    'complex',
    'frog_eye_leaf_spot',
    'powdery_mildew',
    'rust',
    'scab',
    'healthy',
]

# Side length, in pixels, of the square images stored in the TFRecords.
img_size = 720

In [None]:
# Load the training labels; the `image` column (the file name) becomes the index.
df = pd.read_csv(
    '../input/plant-pathology-2021-fgvc8/train.csv',
    index_col='image',
)
init_len = len(df)

# Snapshot of the raw, space-separated label strings before the column is replaced.
original_labels = df['labels'].values.copy()

# Split every label string into a list of class names, then multi-hot encode
# the lists with one column per entry of `classes` (in that order).
df['labels'] = df['labels'].map(lambda label_string: label_string.split(' '))
labels = MultiLabelBinarizer(classes=classes).fit_transform(df['labels'].values)

# Replace the frame with the encoded matrix, keeping the image-name index.
df = pd.DataFrame(labels, index=df.index, columns=classes)

# Persist the encoded labels to the working directory and preview them.
df.to_csv('train.csv')
display(df.head())

In [None]:
def serialize_image(path):
    """Read a JPEG file, resize it to a square and re-encode it as JPEG bytes.

    The image is resized to ``img_size`` x ``img_size`` (module-level constant)
    with the default settings of ``tf.image.resize``.

    Args:
        path: Path of the JPEG file to read.

    Returns:
        The re-encoded JPEG as returned by ``tf.image.encode_jpeg(...).numpy()``.
    """
    raw_bytes = tf.io.read_file(path)
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    # tf.image.resize returns float32 for its default (bilinear) method, so the
    # pixel values are fractional at this point.
    image = tf.image.resize(image, [img_size, img_size])
    # Round before converting to uint8: a bare tf.cast drops the fractional
    # part, which biases every interpolated pixel downward by up to one level.
    image = tf.cast(tf.round(image), tf.uint8)
    return tf.image.encode_jpeg(image).numpy()

def serialize_sample(image, image_name, label):
    """Pack one image and its per-class flags into a serialized tf.train.Example.

    The record holds the bytes features ``image`` and ``image_name`` plus one
    int64 feature per entry of the module-level ``classes`` list, keyed by the
    class name.

    Args:
        image: Encoded image bytes (e.g. the output of ``serialize_image``).
        image_name: Image file name as bytes.
        label: Integer flags, one per entry of ``classes`` and in the same
            order. Any iterable works (list, NumPy array, pandas Series);
            entries beyond ``len(classes)`` are ignored.

    Returns:
        The serialized ``tf.train.Example`` as bytes.

    Raises:
        IndexError: If ``label`` has fewer entries than ``classes``.
    """
    # Materialise the flags by position. Integer indexing such as label[0] on a
    # pandas Series whose index holds class names relies on a deprecated
    # positional fallback, so iterate over the values instead.
    flags = [int(flag) for flag in label]

    feature = {
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
        'image_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_name])),
    }
    # Derive the label features from `classes` so the keys always match the
    # column order used when the labels were encoded.
    for position, class_name in enumerate(classes):
        feature[class_name] = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[flags[position]]))

    sample = tf.train.Example(features=tf.train.Features(feature=feature))
    return sample.SerializeToString()

In [None]:
# Fixed seed so the shuffled train/test split is identical on every run.
SPLIT_SEED = 42

# Serialize every training image together with its encoded labels.
samples = []
# df.iterrows() is a generator, so tqdm needs the total passed explicitly to
# show real progress. The row variable is deliberately not named `labels`, so
# it does not overwrite the encoded label matrix built when loading the CSV.
for image_name, row_labels in tqdm(df.iterrows(), total=len(df)):
    path = os.path.join('../input/plant-pathology-2021-fgvc8/train_images', image_name)
    image = serialize_image(path)
    # Hand over plain positional values rather than a Series indexed by class
    # name, so integer indexing downstream is unambiguous.
    samples.append(serialize_sample(image, image_name.encode(), row_labels.to_numpy()))

# Shuffle with a dedicated, seeded generator; the global `random` state is untouched.
random.Random(SPLIT_SEED).shuffle(samples)

# 80% of the shuffled samples go to train.tfrec, the remainder to test.tfrec.
train_size = int(0.8*len(samples))

with tf.io.TFRecordWriter('train.tfrec') as writer:
    for sample in samples[:train_size]:
        writer.write(sample)

with tf.io.TFRecordWriter('test.tfrec') as writer:
    for sample in samples[train_size:]:
        writer.write(sample)