In [None]:
# Install the `efficientnet` package and its Keras_Applications dependency from wheel
# files bundled in the attached `efficientnetcassava` dataset. --no-dependencies stops
# pip from trying to resolve/download anything else.
# NOTE(review): presumably because the Kaggle kernel runs without internet access — confirm.
!pip3 install --no-dependencies ../input/efficientnetcassava/Keras_Applications-1.0.8-py3-none-any.whl
!pip3 install --no-dependencies ../input/efficientnetcassava/efficientnet-1.1.1-py3-none-any.whl

In [None]:
import re
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import json

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras import losses
from sklearn.model_selection import train_test_split
from efficientnet.keras import EfficientNetB3 as EfficientNet

In [None]:
# tf.data tuning knob, used below for both num_parallel_calls and prefetch depth.
AUTO = tf.data.experimental.AUTOTUNE

# Spatial size every image is resized to before it reaches the model (height x width).
ORIGINAL_HEIGHT, ORIGINAL_WIDTH = 600, 800
CHANNELS = 3  # RGB
BATCH_SIZE = 32

# NOTE(review): SIZE is not referenced anywhere else in this notebook.
SIZE = 600

In [None]:
def decode_image(path):
    """Read a JPEG file and return it as a float32 image at the model input size.

    Args:
        path: scalar string tensor holding the image file path.

    Returns:
        float32 tensor of shape (ORIGINAL_HEIGHT, ORIGINAL_WIDTH, CHANNELS) with
        pixel values scaled from [0, 255] to [0, 1].
    """
    raw_bytes = tf.io.read_file(path)
    image = tf.image.decode_jpeg(raw_bytes, channels=CHANNELS)
    image = tf.cast(image, tf.float32) / 255.0
    # Resize + static-shape pinning is shared with the non-augmented pipeline.
    return normalize(image)

def normalize(x):
    """Resize an image tensor to the model input size and pin its static shape.

    Despite the name, pixel values are not rescaled here; that happens in
    `decode_image`.

    Returns:
        Tensor of shape (ORIGINAL_HEIGHT, ORIGINAL_WIDTH, CHANNELS).
    """
    target_hw = [ORIGINAL_HEIGHT, ORIGINAL_WIDTH]
    resized = tf.image.resize(x, target_hw)
    return tf.reshape(resized, target_hw + [CHANNELS])

In [None]:
def data_aug(x: tf.Tensor) -> tf.Tensor:
    """Randomly augment one decoded image for test-time augmentation (TTA).

    Each transform is gated by its own uniform draw, so any subset may fire.

    Args:
        x: float32 image tensor as produced by `decode_image`, i.e. of shape
            (ORIGINAL_HEIGHT, ORIGINAL_WIDTH, CHANNELS).

    Returns:
        The augmented image, guaranteed to have shape
        (ORIGINAL_HEIGHT, ORIGINAL_WIDTH, CHANNELS) so it can be batched with the
        other images and fed to the model's fixed-size input.
    """
    # Random 80% crop (applied with p=0.8), then resize back to the input size.
    crop_shape = [int(ORIGINAL_HEIGHT * 0.8), int(ORIGINAL_WIDTH * 0.8), CHANNELS]
    x = tf.cond(tf.random.uniform([], 0, 1) > 0.2, lambda: tf.image.random_crop(x, crop_shape), lambda: x)
    x = normalize(x)

    # Flips: gated with p=0.9, and the random_flip op itself flips half the time.
    x = tf.cond(tf.random.uniform([], 0, 1) > 0.1, lambda: tf.image.random_flip_left_right(x), lambda: x)
    x = tf.cond(tf.random.uniform([], 0, 1) > 0.1, lambda: tf.image.random_flip_up_down(x), lambda: x)

    # Colour jitter, each applied with p=0.3.
    # NOTE(review): brightness/contrast can push values slightly outside [0, 1];
    # left unclipped to stay close to the original — confirm against training aug.
    x = tf.cond(tf.random.uniform([], 0, 1) > 0.7, lambda: tf.image.random_saturation(x, 0.6, 1.6), lambda: x)
    x = tf.cond(tf.random.uniform([], 0, 1) > 0.7, lambda: tf.image.random_brightness(x, 0.05), lambda: x)
    x = tf.cond(tf.random.uniform([], 0, 1) > 0.7, lambda: tf.image.random_contrast(x, 0.7, 1.3), lambda: x)

    # Rotation (p=0.5). A 90/270-degree turn swaps height and width of a
    # non-square image, which breaks `.batch()` in load_dataset (mixed element
    # shapes) and no longer matches the model input; only rotate by 0/180 then.
    if ORIGINAL_HEIGHT == ORIGINAL_WIDTH:
        k = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    else:
        k = 2 * tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.int32)
    x = tf.cond(tf.random.uniform([], 0, 1) > 0.5, lambda: tf.image.rot90(x, k), lambda: x)

    # rot90 with a tensor `k` loses the static height/width; restore and check it.
    return tf.ensure_shape(x, [ORIGINAL_HEIGHT, ORIGINAL_WIDTH, CHANNELS])

In [None]:
load_dir = "/kaggle/input/cassava-leaf-disease-classification"
# The sample submission lists every test image id; attach each image's full file path.
sub_df = pd.read_csv(f"{load_dir}/sample_submission.csv")
sub_df['paths'] = f"{load_dir}/test_images/" + sub_df['image_id']

In [None]:
def load_dataset(augment=False, paths=None):
    """Build the batched, prefetched test-image pipeline.

    Args:
        augment: if True, every image goes through the random `data_aug`
            transform (test-time augmentation); otherwise it is only passed
            through `normalize`.
        paths: optional sequence of image file paths. Defaults to
            ``sub_df.paths.values``, the test images listed in the submission.

    Returns:
        A `tf.data.Dataset` of float32 image batches of size BATCH_SIZE.
    """
    if paths is None:
        paths = sub_df.paths.values
    dataset = tf.data.Dataset.from_tensor_slices(paths).map(decode_image, num_parallel_calls=AUTO)
    # Both branches now map in parallel. tf.data keeps element order for parallel
    # maps unless determinism is explicitly disabled, so predictions stay aligned
    # with the rows of `sub_df`.
    transform = data_aug if augment else normalize
    dataset = dataset.map(transform, num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)

In [None]:
def load_model(i):
    """Build the EfficientNetB3 classifier and load saved checkpoint `i`.

    Args:
        i: index embedded in the weight file name
            ``EfficientNetB3_tl_best_weights_{i}.h5``.

    Returns:
        A compiled Keras model that outputs 5-way softmax probabilities.
    """
    inputs = layers.Input(shape=(ORIGINAL_HEIGHT, ORIGINAL_WIDTH, CHANNELS))
    model = Sequential([
        # weights=None: all weights come from the checkpoint loaded below.
        EfficientNet(include_top=False, weights=None, input_tensor=inputs),
        layers.GlobalAveragePooling2D(name="avg_pool"),
        layers.BatchNormalization(),
        layers.Dropout(0.3, name="top_dropout"),
        layers.Dense(5, activation="softmax", name="pred"),
    ])
    model.load_weights(f"../input/efficientnetcassava/EfficientNetB3_tl_best_weights_{i}.h5")
    # `learning_rate` is the supported argument name; `lr` is deprecated in
    # tf.keras optimizers. Nothing in this notebook trains, so the optimizer is unused.
    model.compile(
        loss=losses.SparseCategoricalCrossentropy(),
        optimizer=tf.optimizers.Adam(learning_rate=0.001),
        metrics=['accuracy'],
    )
    return model

In [None]:
n_models = 5
# One model per saved checkpoint, indices 0 .. n_models - 1.
# NOTE(review): this list shadows the `tensorflow.keras.models` module imported earlier.
models = [load_model(i) for i in range(n_models)]

In [None]:
N_TTA_ROUNDS = 10  # augmented passes over the test set, per model

# One clean (non-augmented) pass per model.
clean_dataset = load_dataset()
pass_preds = [model.predict(clean_dataset, verbose=1) for model in models]

# Test-time augmentation: rebuild the augmented pipeline each round and run
# every model over it. `_` avoids shadowing the per-model loop variable.
for _ in range(N_TTA_ROUNDS):
    augmented_dataset = load_dataset(augment=True)
    pass_preds.extend(model.predict(augmented_dataset, verbose=1) for model in models)

# Equal-weight average of class probabilities over all
# len(models) * (1 + N_TTA_ROUNDS) passes.
preds = np.mean(pass_preds, axis=0)
assert preds.shape == (len(sub_df), 5), preds.shape
preds[:5]

In [None]:
sub_df['label'] = preds.argmax(axis=1)
sub_df.drop(columns='paths').to_csv('submission.csv', index=False)
!head submission.csv