In [None]:
# Upgrade pip for the conda Python 3.7 interpreter, then quietly install the
# `efficientnet` package (imported in the next cell as `efn`).
!/opt/conda/bin/python3.7 -m pip install --upgrade pip
! pip install -q efficientnet

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import os
import re
import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from kaggle_datasets import KaggleDatasets
import efficientnet.tfkeras as efn
import json
import csv
# Show the TensorFlow version in use as the cell output.
tf.__version__

In [None]:
# Ask KaggleDatasets for the GCS location of the competition dataset; the
# TFRecord file patterns used later are built from this path.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')
GCS_DS_PATH

In [None]:
# Load the label-number -> disease-name mapping from the competition JSON file.
# The context manager closes the file handle as soon as parsing is done
# (the bare open() used previously left it open for the kernel's lifetime).
with open("../input/cassava-leaf-disease-classification/label_num_to_disease_map.json","r") as f:
    disease_names = json.load(f)
disease_names

In [None]:
# Try to resolve a TPU cluster; a ValueError from the resolver is treated as
# "no TPU available" and leaves `tpu` set to None.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() 
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    # Connect to the cluster and initialise the TPU system before creating
    # the TPU distribution strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # No TPU: fall back to whatever strategy TensorFlow currently reports.
    strategy = tf.distribute.get_strategy() 

# The replica count is used later to scale the global batch size.
print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# Target image height/width, tf.data autotune sentinel, number of label
# classes and number of colour channels used throughout the notebook.
IMAGE_SIZE = [300,300]
AUTO = tf.data.experimental.AUTOTUNE
CLASSES = 5
CHANNELS=3

# The train/validation split below is positional (last two files become the
# validation set). tf.io.gfile.glob does not document a result order, so sort
# the names to make the split reproducible from run to run.
FILENAMES = sorted(tf.io.gfile.glob(GCS_DS_PATH + '/train_tfrecords/*.tfrec'))
TRAINING_FILENAMES = FILENAMES[:-2]
VALIDATION_FILENAMES = FILENAMES[-2:]
TEST_FILENAMES = sorted(tf.io.gfile.glob(GCS_DS_PATH + '/test_tfrecords/*.tfrec'))

In [None]:
# Count how many training rows carry each label. The label is read from the
# second CSV column and kept as a string key.
count_dict={}
with open("../input/cassava-leaf-disease-classification/train.csv") as file:
    reader = csv.reader(file,skipinitialspace=True)
    next(reader)  # skip the header row
    for row in reader:
        count_dict[row[1]]=count_dict.get(row[1],0)+1

# Look the counts up by class index. dict.values() is ordered by first
# appearance in the CSV, which need not be 0, 1, 2, ...; indexing that list by
# class id would attach weights to the wrong classes.
samples_ls = [count_dict[str(i)] for i in range(CLASSES)]
tot_samples = sum(count_dict.values())

# Inverse-frequency weights: total / (number_of_classes * samples_in_class).
class_weights={}
for i in range(CLASSES):
    class_weights[i] = tot_samples/(CLASSES*samples_ls[i])

class_weights

In [None]:
def decode_image(image_data):
    """Turn encoded JPEG bytes into a float32 image resized to IMAGE_SIZE.

    The image is decoded with 3 channels, cast to float32 and divided by 255
    before being resized.
    """
    target_size = list(IMAGE_SIZE)
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    return tf.image.resize(scaled, target_size)


def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an (image, one-hot label) pair.

    Args:
        example: scalar string tensor holding a serialized tf.train.Example
            with an "image" (JPEG bytes) and a "target" (int64) feature.

    Returns:
        Tuple of the decoded image (see decode_image) and a float32 one-hot
        vector of length CLASSES.
    """
    LABELED_TFREC_FORMAT = {
        "image"              : tf.io.FixedLenFeature([], tf.string),
        "target"              : tf.io.FixedLenFeature([], tf.int64),
    }

    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    # Take the one-hot width from CLASSES instead of a literal 5 so the label
    # shape cannot drift from the model's output size. Requesting float32
    # directly makes the follow-up cast unnecessary.
    label = tf.one_hot(example["target"], depth=CLASSES, dtype=tf.float32)
    return image, label


def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an (image, image_name) pair.

    The "image" feature is decoded with decode_image; the "image_name" string
    feature is returned unchanged.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['image_name']

def load_dataset(filenames, labeled=True, ordered=False):
    """Create a parsed tf.data pipeline over the given TFRecord files.

    Args:
        filenames: TFRecord paths to read (in parallel, AUTO-tuned).
        labeled: when true, records are parsed with read_labeled_tfrecord,
            otherwise with read_unlabeled_tfrecord.
        ordered: when false, deterministic ordering is switched off in the
            dataset options.

    Returns:
        A dataset of parsed (image, label) or (image, image_name) pairs.
    """
    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(read_options).map(parse_fn, num_parallel_calls=AUTO)

In [None]:
def data_augment(image, label):
    """Apply random flips, rotations, colour jitter and crops to one image.

    Six independent uniform draws in [0, 1) decide which transforms are
    applied. Whatever was applied, the image is resized back to IMAGE_SIZE at
    the end. The label is returned unchanged.

    Args:
        image: a single image tensor; presumably the float image produced by
            decode_image at IMAGE_SIZE resolution -- confirm against caller.
        label: passed through untouched.

    Returns:
        (augmented_image, label)
    """
    # One random draw per augmentation family.
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
            
    # Random horizontal and vertical flips are always attempted.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    
    # Transpose when the draw exceeds 0.75.
    if p_spatial > .75:
        image = tf.image.transpose(image)
        
    # Rotation: the draw selects k=3, k=2, k=1 quarter turns, or no rotation,
    # one outcome per quarter of the [0, 1) range.
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3)
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) 
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1)
        
    # Colour jitter: each of saturation, contrast and brightness is applied
    # independently when its own draw is >= 0.4.
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.9, upper=1.1)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
        
    # Cropping: draws above 0.7 take a central crop (fraction 0.7, 0.8 or 0.9
    # depending on the draw); draws in (0.4, 0.7] take a random square crop;
    # lower draws leave the image uncropped.
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.7)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.8)
        else:
            image = tf.image.central_crop(image, central_fraction=.9)
    elif p_crop > .4:
        # Side length is drawn between 80% of IMAGE_SIZE[0] and IMAGE_SIZE[0];
        # the same value is used for both sides, so the crop is square.
        crop_size = tf.random.uniform([], int(IMAGE_SIZE[0]*.8),IMAGE_SIZE[0], dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
    
    # Restore the fixed spatial size after any crop.
    image = tf.image.resize(image, [*IMAGE_SIZE])
    return image,label

def get_training_dataset():
    """Build the training pipeline over TRAINING_FILENAMES.

    Labeled records are augmented, repeated indefinitely, shuffled with a
    2048-element buffer, batched to BATCH_SIZE and prefetched.
    """
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .map(data_augment, num_parallel_calls=AUTO)
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def get_validation_dataset(ordered=False):
    """Build the validation pipeline over VALIDATION_FILENAMES.

    Labeled records are batched to BATCH_SIZE, cached and prefetched; no
    augmentation is applied. `ordered` is forwarded to load_dataset.
    """
    return (
        load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
        .batch(BATCH_SIZE)
        .cache()
        .prefetch(AUTO)
    )

def get_test_dataset(ordered=False):
    """Build the test pipeline over TEST_FILENAMES.

    Unlabeled records (image, image_name) are batched to BATCH_SIZE and
    prefetched. `ordered` is forwarded to load_dataset.
    """
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def count_data_items(filenames):
    """Sum the record counts that are encoded in TFRecord file names.

    Every name must contain a ``-<digits>.`` component; the digits are taken
    as the number of records in that file (a name such as
    ``train00-1338.tfrec`` contributes 1338).

    Args:
        filenames: iterable of TFRecord file names or paths.

    Returns:
        Integer total over all files; 0 (an integer, not 0.0) when
        ``filenames`` is empty.

    Raises:
        ValueError: if a name carries no ``-<digits>.`` component.
    """
    # Compile once instead of once per file name; require at least one digit
    # so an empty count can never reach int().
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(
                "Cannot read a record count from file name: {!r}".format(filename)
            )
        counts.append(int(match.group(1)))
    # An explicit integer dtype keeps the empty-list result an integer.
    return np.sum(counts, dtype=int)

# Dataset sizes, derived from the counts embedded in the TFRecord file names.
NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

In [None]:
# Global batch size: 16 samples per replica of the active strategy.
# This must be defined before the get_*_dataset() helpers are called, because
# they read BATCH_SIZE at call time.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

train_ds = get_training_dataset()
valid_ds = get_validation_dataset()
test_ds = get_test_dataset()

# Printing a dataset shows its element spec, not its contents.
print("Training:", train_ds)
print("Validation:",valid_ds)
print("Test:", test_ds)

In [None]:
# Sanity check: pull three training batches and print image/label shapes.
print("Training data shapes:")
for image, label in train_ds.take(3):
    print(image.numpy().shape, label.numpy().shape)

In [None]:
# Sanity check: pull three test batches and print image/image-name shapes.
print("Test data shapes:")
for image, idnum in test_ds.take(3):
    print(image.numpy().shape, idnum.numpy().shape)

In [None]:
# Scratch note kept as a string literal (never executed as code): presumably
# additional backbone entries that could be added to the `efficient_net` dict
# in create_model() -- confirm before relying on it.
'''2 : efn.EfficientNetB2(weights="imagenet",include_top=False ,input_shape=[*IMAGE_SIZE, 3]),
        3 : tf.keras.applications.EfficientNetB3(weights="imagenet",include_top=False ,input_shape=[*IMAGE_SIZE, 3]),
        4 : tf.keras.applications.EfficientNetB4(weights="imagenet",include_top=False ,input_shape=[*IMAGE_SIZE, 3]),
        5 : tf.keras.applications.EfficientNetB5(weights="imagenet",include_top=False ,input_shape=[*IMAGE_SIZE, 3]),
        6 : tf.keras.applications.EfficientNetB6(weights="imagenet",include_top=False ,input_shape=[*IMAGE_SIZE, 3]),
        7 : tf.keras.applications.EfficientNetB7(weights="imagenet",include_top=False ,input_shape=[*IMAGE_SIZE, 3]),
        8 : tf.keras.applications.Xception(include_top=False, weights='imagenet',input_shape=[*IMAGE_SIZE, 3]),
        9 : tf.keras.applications.ResNet50(include_top=False, weights='imagenet',input_shape=[*IMAGE_SIZE, 3]),
        10: tf.keras.applications.ResNet101(include_top=False, weights='imagenet',input_shape=[*IMAGE_SIZE, 3]),
        11: tf.keras.applications.ResNet152(include_top=False, weights='imagenet',input_shape=[*IMAGE_SIZE, 3])'''

In [None]:
def create_model():
    """Build an averaging ensemble of ImageNet-pretrained EfficientNetB0 and B1.

    Both backbones receive the same input. Each is followed by global average
    pooling and a CLASSES-way softmax head, and the per-backbone class
    probabilities are averaged to form the model output.

    Returns:
        An uncompiled tf.keras.Model that maps images of shape
        (*IMAGE_SIZE, CHANNELS) with pixel values scaled to [0, 1] to averaged
        class probabilities of shape (CLASSES,).
    """
    input_shape = [*IMAGE_SIZE, CHANNELS]
    efficient_net = {
        0 : tf.keras.applications.EfficientNetB0(weights="imagenet",include_top=False ,input_shape=input_shape),
        1 : tf.keras.applications.EfficientNetB1(weights="imagenet",include_top=False ,input_shape=input_shape),
    }
    selected = [0, 1]  # keys of `efficient_net` that take part in the ensemble

    inputs = tf.keras.Input(shape=(*IMAGE_SIZE, CHANNELS))

    # decode_image() divides pixels by 255, but the tf.keras.applications
    # EfficientNet models embed their own preprocessing and expect float
    # pixels in the [0, 255] range. Undo the scaling once, in front of the
    # backbones, so the pretrained weights see the value range they expect.
    rescaling_cls = getattr(tf.keras.layers, "Rescaling", None)
    if rescaling_cls is None:
        # Older TF 2.x releases only expose the layer under `experimental`.
        rescaling_cls = tf.keras.layers.experimental.preprocessing.Rescaling
    pixels = rescaling_cls(255.0, name="to_0_255")(inputs)

    head_outputs = []
    for i in selected:
        features = efficient_net[i](pixels)
        pooled = tf.keras.layers.GlobalAveragePooling2D(name = "average_"+str(i))(features)
        # Softmax, not sigmoid: the targets are one-hot single-label vectors
        # trained with categorical cross-entropy, so each head should emit a
        # probability distribution over the classes.
        head_outputs.append(
            tf.keras.layers.Dense(CLASSES,activation="softmax", dtype='float32',name="dense_"+str(i))(pooled)
        )

    outputs = tf.keras.layers.average(head_outputs)
    model = tf.keras.Model(inputs, outputs)

    return model

In [None]:
# Build the ensemble once and write the full model to an HDF5 file.
model = create_model()
model.save('trial.h5')

In [None]:
def compile_model(model, lr=0.0001):
    """Compile `model` with Adam, categorical cross-entropy and accuracy.

    Args:
        model: the Keras model to compile (compiled in place).
        lr: learning rate handed to the Adam optimizer.

    Returns:
        The same model instance, now compiled.
    """
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    loss = tf.keras.losses.CategoricalCrossentropy()

    # The metric name determines the history/monitor keys
    # 'categorical_accuracy' and 'val_categorical_accuracy'.
    metrics = [
       tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy')
    ]

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    print('Compiled! ')
    return model

In [None]:
# Round-trip check: reload the model saved to 'trial.h5' and compile it with
# the default learning rate.
model = tf.keras.models.load_model('trial.h5')
model = compile_model(model)

In [None]:
def create_callbacks():
    """Return the training callbacks: checkpoint, LR reduction, early stop.

    All three watch validation categorical accuracy and treat higher values
    as better. The returned list is ordered [checkpoint, reducelr, earlystop].
    """
    monitored_metric = 'val_categorical_accuracy'

    # Stop once the monitored metric has not improved for 15 epochs.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor=monitored_metric,
        mode='max',
        patience=15,
        verbose=1,
    )

    # Multiply the learning rate by 0.1 after 3 epochs without improvement.
    lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=monitored_metric,
        mode='max',
        factor=0.1,
        patience=3,
        verbose=0,
    )

    # Keep only the best weights. Because just the weights are written, the
    # model has to be rebuilt wherever predictions are made from this file.
    best_weights = tf.keras.callbacks.ModelCheckpoint(
        filepath='./best_model.h5',
        monitor=monitored_metric,
        mode='max',
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    )

    return [best_weights, lr_schedule, early_stopping]

In [None]:
# Training configuration.
EPOCHS= 3
VERBOSE =1

# The training dataset repeats indefinitely, so the length of an epoch is set
# explicitly as the number of full batches in the training set.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES//(BATCH_SIZE)

# Drop Keras state left over from the models built in earlier cells.
tf.keras.backend.clear_session()
# Build, compile and fit inside the distribution strategy's scope.
with strategy.scope():
    
    model = create_model()
    model = compile_model(model, lr=0.0001)
   
    callbacks = create_callbacks()
    
    history = model.fit(train_ds, 
                        epochs=EPOCHS,
                        callbacks=callbacks,
                        validation_data = valid_ds,
                        steps_per_epoch = STEPS_PER_EPOCH,
                        verbose=VERBOSE)

In [None]:
# The checkpoint callback stores weights only, so rebuild the architecture
# first and then load the best weights into it.
model2 = create_model()

model2.load_weights('./best_model.h5')



In [None]:
# Run the restored model over the test dataset; the returned predictions are
# shown as the cell output.
model2.predict(test_ds)

In [None]:
# Print the layer-by-layer summary of the model trained above.
model.summary()

In [None]:
# Per-epoch curves recorded by model.fit(); the keys follow the metric name
# set in compile_model().
acc = history.history['categorical_accuracy']
val_acc = history.history['val_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# One x position per recorded epoch.
epochs_range = range(len(history.history['val_loss']))
plt.figure(figsize=(8, 8))
# Left panel: training vs. validation accuracy.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Categorical Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Categorical Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Categorical Accuracy')
# Right panel: training vs. validation loss.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()