In [None]:
import os
import re
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf

from kaggle_datasets import KaggleDatasets

In [None]:
from tensorflow.keras.applications.efficientnet import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input as effnet_preprocess_input

In [None]:
import tensorflow_datasets as tfds
import tensorflow_hub as hub

In [None]:
print("Tensorflow version " + tf.__version__)

# Detect the TPU. The resolver is expected to raise ValueError when no TPU is
# attached, which is the case handled below.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError as err:
    # RuntimeError (not BaseException) so it behaves like an ordinary error for
    # `except Exception` handlers; `from err` keeps the resolver's original cause.
    raise RuntimeError(
        'ERROR: Not connected to a TPU runtime; enable a TPU accelerator '
        'in the notebook settings and restart the session.'
    ) from err

# Attach to the TPU, initialise it, and build the distribution strategy that
# later cells use for model creation and batch sizing.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)

### Data

In [None]:
# Resolve the GCS bucket backing the competition data; every file read below
# (TFRecords, CSV, JSON label map) is addressed relative to this path.
GCS_PATH = KaggleDatasets().get_gcs_path(
    'cassava-leaf-disease-classification'
)

In [None]:
# Sort the glob result: the train/valid split below slices this list by
# position, so its order must be deterministic across runs.
tfrec_fnames = sorted(tf.io.gfile.glob(GCS_PATH + '/train_tfrecords/ld_train*.tfrec'))
len(tfrec_fnames)

### Data Prep

In [None]:
# Lookup from numeric class label to disease name, loaded as a pandas Series;
# used below to map the `label` column of train.csv.
label_to_disease = pd.read_json(
    os.path.join(GCS_PATH, 'label_num_to_disease_map.json'),
    typ='series',
)

In [None]:
# Per-image training metadata (its `label` column is mapped to names below).
train_csv = pd.read_csv(
    os.path.join(GCS_PATH, 'train.csv')
)

In [None]:
# Add a human-readable `disease` column next to each numeric label.
train_csv = train_csv.assign(disease=train_csv['label'].map(label_to_disease))
train_csv.head()

In [None]:
# Shard-level train/valid split (75:25). Splitting by TFRecord file keeps the two
# sets disjoint without re-serialising any data. The cut point is derived from
# the shard count instead of being hard-coded, so it stays 75:25 if that changes.
TRAIN_FRACTION = 0.75
n_train_files = round(len(tfrec_fnames) * TRAIN_FRACTION)
train_fnames = tfrec_fnames[:n_train_files]
valid_fnames = tfrec_fnames[n_train_files:]
assert train_fnames and valid_fnames, 'need at least one shard on each side of the split'
print(len(train_fnames), len(valid_fnames))

### Callbacks

In [None]:
# Training callbacks, both driven by validation loss:
#  - ReduceLROnPlateau multiplies the learning rate by 0.3 after 2 epochs
#    without an improvement of at least 0.001;
#  - EarlyStopping stops after 5 such epochs and restores the best weights.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.3,
    patience=2,
    min_delta=0.001,
    verbose=1,
)
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.001,
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

### Datasets

In [None]:
# Spatial size of the decoded images stored in the TFRecords; _parse_function
# reshapes every image to [*IMAGE_SIZE, 3].
IMAGE_SIZE = [512, 512]
# Global batch size: 16 examples per replica times the number of replicas.
BATCH_SIZE = tpu_strategy.num_replicas_in_sync * 16
# Let tf.data tune parallelism and prefetch depth itself.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
def _parse_function(proto):
    """Decode one serialized training record into ``(image, target)``.

    Args:
        proto: one serialized record as produced by ``tf.data.TFRecordDataset``.

    Returns:
        image: float32 tensor of shape ``[*IMAGE_SIZE, 3]`` with values in [0, 255].
        target: float32 one-hot vector of length 5. A record without ``target``
            falls back to -1, which ``tf.one_hot`` turns into all zeros.
    """
    # tf.data traces this function into a graph, so the feature schema must be
    # declared explicitly; it fixes the dtype and shape of every parsed field.
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'image_name': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'target': tf.io.FixedLenFeature([], tf.int64, default_value=-1),
    }
    example = tf.io.parse_single_example(proto, schema)

    # Decode as RGB and cast to float. The reshape pins the static shape and
    # fails if the decoded pixel count doesn't match IMAGE_SIZE.
    pixels = tf.cast(tf.image.decode_jpeg(example['image'], channels=3), tf.float32)
    pixels = tf.reshape(pixels, [*IMAGE_SIZE, 3])

    return pixels, tf.one_hot(example['target'], depth=5)

In [None]:
def load_dataset(tfrecords_fnames):
    """Read TFRecord shards and parse each record into ``(image, target)``.

    Args:
        tfrecords_fnames: list of TFRecord shard paths.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``_parse_function`` outputs.
    """
    records = tf.data.TFRecordDataset(tfrecords_fnames, num_parallel_reads=AUTO)
    return records.map(_parse_function, num_parallel_calls=AUTO)

In [None]:
def build_train_ds(train_fnames, with_aug=False):
    """Build the infinite, shuffled, batched training pipeline.

    Args:
        train_fnames: TFRecord shard paths to read.
        with_aug: if True, apply random flips, brightness and saturation to each
            image before shuffling/batching.

    Returns:
        A ``tf.data.Dataset`` of ``(images, one_hot_targets)`` batches of
        ``BATCH_SIZE`` (the partial final batch is dropped). It repeats forever,
        so ``fit`` needs an explicit ``steps_per_epoch``.
    """
    ds = load_dataset(train_fnames)

    def data_augment(image, target):
        # Each op consumes the previous op's output so the augmentations compose
        # (previously the up/down flip read `image`, discarding the left/right flip).
        modified = tf.image.random_flip_left_right(image)
        modified = tf.image.random_flip_up_down(modified)
        # Pixels are in [0, 255] here, but random_brightness adds an absolute
        # delta; express the intended 0.2 as a fraction of that full range.
        modified = tf.image.random_brightness(modified, 0.2 * 255.0)
        # NOTE(review): saturation factors in [5, 10) are very strong; confirm intended.
        modified = tf.image.random_saturation(modified, 5, 10)
        # Keep pixel values inside the original [0, 255] range after augmentation.
        modified = tf.clip_by_value(modified, 0.0, 255.0)
        return modified, target

    if with_aug:
        ds = ds.map(data_augment, num_parallel_calls=AUTO)

    return ds.repeat().shuffle(2048).batch(BATCH_SIZE, drop_remainder=True).prefetch(AUTO)

In [None]:
def build_valid_ds(valid_fnames):
    """Build the single-pass, batched validation pipeline (no shuffle, no repeat).

    Args:
        valid_fnames: TFRecord shard paths to read.

    Returns:
        A ``tf.data.Dataset`` of ``(images, one_hot_targets)`` batches of ``BATCH_SIZE``.
    """
    # NOTE(review): drop_remainder=True discards the final partial batch, so up to
    # BATCH_SIZE - 1 validation examples are never evaluated.
    return (
        load_dataset(valid_fnames)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(AUTO)
    )

In [None]:
def count_data_items(filenames):
    """Return the total number of examples across TFRecord shards.

    The count is read from each shard's file name, which must carry a
    ``-<count>.`` segment (e.g. ``ld_train00-1338.tfrec``), so no records are read.

    Args:
        filenames: iterable of shard paths (local or ``gs://``).

    Returns:
        int: sum of the per-shard counts (0 for an empty iterable).

    Raises:
        ValueError: if a file name has no ``-<count>.`` segment.
    """
    pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for fname in filenames:
        # Match only the basename so dashes in bucket/directory names can't be
        # mistaken for the record count.
        match = pattern.search(os.path.basename(fname))
        if match is None:
            raise ValueError(f"cannot infer record count from filename: {fname!r}")
        total += int(match.group(1))
    return total

n_train = count_data_items(train_fnames)
n_valid = count_data_items(valid_fnames)
# Whole batches per epoch; the training pipeline drops the partial remainder.
# Reuses n_train rather than counting the shards a second time.
train_steps = n_train // BATCH_SIZE

print("TRAINING IMAGES:", n_train, ", STEPS PER EPOCH:", train_steps)
print("VALIDATION IMAGES:", n_valid)

### Plain Model

In [None]:
def preprocess_fn(image, label):
    """Adapt one parsed example to the TF-Hub cassava classifier's inputs.

    Scales pixels from [0, 255] to [0, 1] and resizes to 224x224. It also appends
    a trailing 0 to the 5-way one-hot label, so it lines up with the model's
    6-way output.

    Args:
        image: float tensor ``[H, W, 3]`` with values in [0, 255].
        label: rank-1 one-hot tensor of length 5.

    Returns:
        ``(image, label)`` with image ``[224, 224, 3]`` in [0, 1] and label of length 6.
    """
    # NOTE(review): assumes the hub model's extra 6th output is a class absent from
    # this dataset (e.g. "unknown"); confirm against the model card.
    resized = tf.image.resize(image / 255.0, (224, 224))
    padded_label = tf.pad(label, [[0, 1]])
    return resized, padded_label

In [None]:
# Plain-model pipelines: the same shards as the split above, passed through
# preprocess_fn to match the hub classifier's input size/range and 6-way label.
# train_fnames / valid_fnames / n_train come from the earlier cells, so the
# split is defined in exactly one place instead of being re-sliced here.
train_ds = (
    load_dataset(train_fnames)
    .map(preprocess_fn, num_parallel_calls=AUTO)
    .repeat()
    .shuffle(2048)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(AUTO)
)

# NOTE(review): drop_remainder=True also drops the last partial validation batch,
# so up to BATCH_SIZE - 1 validation images are never evaluated.
valid_ds = (
    load_dataset(valid_fnames)
    .map(preprocess_fn, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(AUTO)
)

# train_ds repeats forever, so fit() needs an explicit number of steps per epoch.
train_steps = n_train // BATCH_SIZE

In [None]:
# Sanity-check one preprocessed training batch: after preprocess_fn the max pixel
# is expected to be <= 1.0, with shape (BATCH_SIZE, 224, 224, 3).
img, label = next(iter(train_ds))
print(img.numpy().max(), img.shape, img.dtype)

In [None]:
# Point the TF-Hub download cache at the Kaggle working directory.
# NOTE(review): /kaggle/working is also the notebook's output dir, so cached model
# files may end up in the saved outputs; confirm that is acceptable.
os.environ["TFHUB_CACHE_DIR"] = "/kaggle/working"
# Build the model inside the TPU strategy scope so its variables are created
# under tpu_strategy.
with tpu_strategy.scope():
    # Load the SavedModel via the local host ('/job:localhost') rather than the TPU
    # workers; presumably because the hub cache lives on the VM's local disk, so
    # confirm this against the setup.
    load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    # Pretrained CropNet cassava classifier, fine-tuned end to end (trainable=True).
    cassava = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2', trainable=True, load_options=load_locally)
    # The 224x224x3 input matches the resize done in preprocess_fn.
    model = tf.keras.Sequential([tf.keras.Input(shape=(224,224,3)),
                                 cassava])

In [None]:
# Compile under the same strategy scope the model was built in, as the TF
# "Distributed training with Keras" guide recommends. This keeps the optimizer
# and its slot variables tied to tpu_strategy regardless of Keras version.
with tpu_strategy.scope():
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                  # Labels are one-hot (see preprocess_fn). from_logits=False treats
                  # model outputs as probabilities.
                  # NOTE(review): assumes the hub classifier emits probabilities, not
                  # logits; confirm against the model card.
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                  metrics=['accuracy'])

In [None]:
# `epochs` is only an upper bound: EarlyStopping ends the run and ReduceLROnPlateau
# lowers the learning rate once val_loss stalls. steps_per_epoch is required
# because train_ds repeats indefinitely.
model.fit(
    train_ds,
    epochs=500,
    steps_per_epoch=train_steps,
    validation_data=valid_ds,
    callbacks=[reduce_lr, early_stop],
)

In [None]:
# Persist the fine-tuned model; the .h5 extension selects Keras' HDF5 format.
# NOTE(review): reloading probably needs custom_objects={'KerasLayer': hub.KerasLayer}; verify.
model.save('cassava_model.h5')