In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import PIL
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random
from tqdm import tqdm
import tensorflow_addons as tfa
import random
from sklearn.preprocessing import MultiLabelBinarizer

# Show every column when a DataFrame is displayed (no column limit).
pd.options.display.max_columns = None


In [None]:
# Load the training table and preview its first rows.
train_csv_path = '../input/plant-pathology-2021-fgvc8/train.csv'
train = pd.read_csv(train_csv_path)
train.head(5)

In [None]:
# Turn each space-separated label string into a list of label names.
# The isinstance guard keeps this cell idempotent: on a second run the column
# already holds lists, and calling .split on a list would raise AttributeError.
train['labels'] = train['labels'].apply(
    lambda value: value.split(' ') if isinstance(value, str) else value
)
train

In [None]:
# Multi-hot encode the per-image label lists: one 0/1 column per distinct label,
# rows aligned with `train` through its index.
s = train['labels'].tolist()
mlb = MultiLabelBinarizer()
encoded_labels = mlb.fit_transform(s)
trainx = pd.DataFrame(encoded_labels, index=train.index, columns=mlb.classes_)

In [None]:
# Image generator used for training: pixel values are multiplied by 1/255 and
# a small random rotation and width/height shift are applied to each image.
augmentation_options = dict(
    rescale=1/255.0,
    rotation_range=2,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
datagen = keras.preprocessing.image.ImageDataGenerator(**augmentation_options)

# Stream shuffled batches of 32 RGB images at 224x224 from the pre-resized
# image folder; file names come from the "image" column and the targets from
# the "labels" column of `train`.
train_flow_options = dict(
    directory='../input/resized-plant2021/img_sz_512',
    x_col="image",
    y_col='labels',
    color_mode="rgb",
    target_size=(224, 224),
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=40,
)
train_data = datagen.flow_from_dataframe(train, **train_flow_options)

# HEIGHT = 224
# WIDTH=224
# SEED = 40
# BATCH_SIZE=32


# train_ds = datagen.flow_from_dataframe(
#     train,
#     directory = '../input/resized-plant2021/img_sz_256',
#     subset='training',
#     x_col='image',
#     y_col='labels',
#     target_size=(HEIGHT,WIDTH),
#     color_mode='rgb',
#     class_mode='categorical',
#     batch_size=BATCH_SIZE,
#     shuffle=True,
#     seed=SEED
# )

In [None]:
# Seed TensorFlow's global random generator; `seed` is also reused when the
# classification head is initialised.
seed = 1200
tf.random.set_seed(seed)

# InceptionResNetV2 backbone without its classification top, loaded from a
# local weights file and built for 224x224 RGB inputs.
weights_path = '../input/keras-pretrained-models/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5'
model = keras.applications.InceptionResNetV2(
    include_top=False,
    weights=weights_path,
    input_shape=(224, 224, 3),
)

# Print the backbone's input and output tensors.
for backbone_tensor in (model.input, model.output):
    print(backbone_tensor)

In [None]:
# Classifier = backbone -> global average pooling -> 6-unit sigmoid head.
pooling_layer = keras.layers.GlobalAveragePooling2D()
dense_top = keras.layers.Dense(
    6,
    activation='sigmoid',
    kernel_initializer=keras.initializers.RandomUniform(seed=seed),
    bias_initializer=keras.initializers.Zeros(),
    name='dense_top',
)
new_model = tf.keras.Sequential([model, pooling_layer, dense_top])

# Freeze everything except the final dense layer so only the head is trained.
frozen_layers = new_model.layers[:-1]
for frozen_layer in frozen_layers:
    frozen_layer.trainable = False

new_model.summary()

In [None]:
# Macro-averaged F1 over the 6 output units. The name is set explicitly because
# it is the key under which Keras reports this metric in the epoch logs.
f1 = tfa.metrics.F1Score(num_classes=6, average='macro', name='f1_score')

# EarlyStopping looks `monitor` up by *name* in the epoch logs. Passing the
# metric object itself never matches a log key, so early stopping would never
# trigger; pass the metric's name instead. No validation data is supplied to
# fit(), so the training-set value is the one monitored.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor=f1.name,
        patience=3,
        mode='max',
        restore_best_weights=True,
    )
]

# `learning_rate` replaces the deprecated `lr` keyword of the optimizer.
new_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=[f1],
)

new_model.fit(train_data, epochs=30, callbacks=callbacks)

In [None]:
# The sample submission lists the test image file names in its "image" column.
test = pd.read_csv('../input/plant-pathology-2021-fgvc8/sample_submission.csv')

# Write a 256x256 copy of every test image into the working directory.
test_images_dir = '../input/plant-pathology-2021-fgvc8/test_images/'
for file_name in tqdm(test['image']):
    with PIL.Image.open(test_images_dir + str(file_name)) as original_image:
        resized_image = original_image.resize((256, 256))
        resized_image.save(f'./{file_name}')

In [None]:
# Inference-only generator: apply the same 1/255 rescale that the training
# generator uses, but none of its random rotation/shift augmentation, so the
# predictions are deterministic.
test_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1/255.0)

# Feed images at the spatial size the network was built for (taken from the
# model itself) instead of a hard-coded 256x256, which does not match the
# 224x224 input the backbone was constructed with.
test_target_size = tuple(new_model.input_shape[1:3])

test_data = test_datagen.flow_from_dataframe(
    test,
    directory='./',
    x_col="image",
    y_col=None,
    color_mode="rgb",
    target_size=test_target_size,
    classes=None,
    class_mode=None,   # no labels: the generator yields image batches only
    batch_size=32,
    shuffle=False,     # keep row order aligned with the submission file
)

preds = new_model.predict(test_data)
print(preds)
def _select_label_indices(scores, threshold=0.3):
    """Return the positions in `scores` whose value is >= `threshold`.

    If no score reaches the threshold, the position of the highest score is
    returned instead, so every image gets at least one label.

    Positions are taken from `enumerate` rather than `list.index(value)`:
    `index` returns the first match, so two classes with equal scores would
    produce a duplicated index and silently drop the second class.
    """
    selected = [position for position, score in enumerate(scores) if score >= threshold]
    if not selected:
        selected = [int(np.argmax(scores))]
    return selected


preds = preds.tolist()

# One list of selected class indices per test image.
indices = [_select_label_indices(pred) for pred in preds]

print(indices)


In [None]:
# Invert the generator's {label name: class index} mapping so that predicted
# indices can be translated back to label names.
labels = {class_index: class_name
          for class_name, class_index in train_data.class_indices.items()}
print(labels)

# For every test image, join the names of its selected classes with spaces.
testlabels = [
    ' '.join(str(labels[class_index]) for class_index in image_indices)
    for image_indices in indices
]

print(testlabels)

In [None]:
# Delete every .jpg file in the working directory.
delfiles = tf.io.gfile.glob('./*.jpg')

for jpg_path in delfiles:
    os.remove(jpg_path)

In [None]:
# Fill the sample submission's "labels" column with the predicted label
# strings and write the result to submission.csv without the index column.
sample_submission_path = '../input/plant-pathology-2021-fgvc8/sample_submission.csv'
sub = pd.read_csv(sample_submission_path)
sub['labels'] = testlabels
sub.to_csv('submission.csv', index=False)
sub