In [None]:
import os
import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.utils import to_categorical
import shutil
from glob import glob
import pandas_profiling as pp
import cv2
from google.cloud import storage
from kaggle_datasets import KaggleDatasets
from random import seed, randint, random, choice
from PIL import Image
import tensorflow_addons as tfa
import sys
from tensorflow.keras.optimizers import RMSprop, Adam, SGD
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Log the TensorFlow build in use so the run can be reproduced later.
print(f"Tensorflow version {tf.__version__}")

In [None]:
# ViT: extend the module search path with this input directory *before* the
# import below, so that `vit_keras` can be resolved from it.
# NOTE(review): presumably this directory ships the vit_keras package (and its
# `validators` dependency) as a Kaggle dataset -- confirm against the dataset.
sys.path.append( '/kaggle/input/vit-keras-validators/vit_keras_validators' )
from vit_keras import vit, utils

In [None]:
def auto_select_accelerator():
    """Choose the tf.distribute strategy for this session.

    A TPU cluster is resolved, connected and initialised first; if any of
    those steps raises ValueError the currently active default strategy
    (``tf.distribute.get_strategy()``) is used instead.

    Returns:
        The selected strategy object. Its replica count is printed as a
        side effect.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        selected = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
    except ValueError:
        # TPU setup failed with ValueError: fall back to the default strategy.
        selected = tf.distribute.get_strategy()

    replica_count = selected.num_replicas_in_sync
    print(f"Running on {replica_count} replicas")

    return selected


def build_decoder(with_labels=True, target_size=(256, 256), ext='jpg'):
    """Create a function that turns an image file path into a float image tensor.

    Args:
        with_labels: When true, the returned function takes ``(path, label)``
            and returns ``(image, label)``; otherwise it takes ``path`` only.
        target_size: ``(height, width)`` passed to ``tf.image.resize``.
        ext: Image format: ``'png'``, ``'jpg'`` or ``'jpeg'``. Matching is
            case-insensitive and a leading dot is ignored.

    Returns:
        ``decode_with_labels`` if ``with_labels`` else ``decode``.

    Raises:
        ValueError: Immediately, if ``ext`` is not a supported format. The
            check is done here rather than inside ``decode`` so a bad
            extension fails before any dataset pipeline is traced.
    """
    fmt = ext.lower().lstrip('.') if isinstance(ext, str) else ext
    if fmt not in ('png', 'jpg', 'jpeg'):
        raise ValueError("Image extension not supported")

    def decode(path):
        """Read, decode (3 channels), scale to [0, 1] and resize one image."""
        file_bytes = tf.io.read_file(path)
        if fmt == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        else:
            # Only 'jpg' / 'jpeg' can reach this branch after the check above.
            img = tf.image.decode_jpeg(file_bytes, channels=3)

        # Convert uint8 pixel values to float32 in [0, 1] before resizing.
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, target_size)

        return img

    def decode_with_labels(path, label):
        """Decode the image and pass the label through untouched."""
        return decode(path), label

    return decode_with_labels if with_labels else decode


def build_augmenter(with_labels=True, crop_size=None):
    """Create the training-time image augmentation function.

    Args:
        with_labels: When true, the returned function takes ``(img, label)``
            and returns ``(augmented_img, label)``; otherwise it takes and
            returns the image only.
        crop_size: Side length (pixels) of the final square crop. When
            ``None`` the notebook-level global ``im_size`` is read at call
            time, which is what the function always did before this
            parameter existed.

    Returns:
        ``augment_with_labels`` if ``with_labels`` else ``augment``.
    """
    def augment(img):
        """Apply flips, rotations, blur, colour jitter and a random crop."""
        side = im_size if crop_size is None else crop_size

        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)

        # The rotation parameters are drawn with TF ops, not Python's
        # `random`: inside `Dataset.map` this function is traced once, so a
        # Python-level draw would be frozen into the graph and every image
        # would receive the same rotation.
        k90 = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
        img = tf.image.rot90(img, k=k90)

        # Whole degrees in [-45, 45], converted to radians with the exact pi.
        degrees = tf.cast(
            tf.random.uniform([], minval=-45, maxval=46, dtype=tf.int32),
            tf.float32)
        img = tfa.image.transform_ops.rotate(img, degrees * (np.pi / 180.0), fill_mode='nearest')

        # NOTE(review): the blur parameters are still Python-level draws and
        # are therefore fixed once per trace. They are kept that way because
        # tfa's gaussian_filter2d appears to need static Python values for
        # `filter_shape` (and checks `sigma` in Python) -- confirm against
        # the tfa version in use before changing.
        sigma = np.random.uniform(0.1, 0.9)
        fshape = np.random.randint(2, 7)
        img = tfa.image.gaussian_filter2d(img, sigma=sigma, filter_shape=[fshape, fshape])

        img = tf.image.random_brightness(img, 0.3)
        img = tf.image.random_contrast(img, 0.7, 1.3)
        img = tf.image.random_saturation(img, 0.7, 1.3)
        img = tf.image.random_hue(img, 0.1)

        # Upscale to twice the target side, then crop a random target-sized
        # window out of it.
        img = tf.image.resize(img, (int(side * 2), int(side * 2)), method='nearest')
        img = tf.image.random_crop(img, [side, side, 3])

        return img

    def augment_with_labels(img, label):
        """Augment the image and pass the label through untouched."""
        return augment(img), label

    return augment_with_labels if with_labels else augment


def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024,
                  cache_dir=""):
    """Assemble a batched, prefetched ``tf.data`` pipeline from image paths.

    Stages, in order: slice -> decode -> (cache) -> (augment) -> (repeat)
    -> (shuffle) -> batch -> prefetch. Stages in parentheses are optional.

    Args:
        paths: Image file paths.
        labels: Optional labels aligned with ``paths``; when given, elements
            are ``(image, label)`` pairs.
        bsize: Batch size.
        cache: Whether to cache the decoded elements.
        decode_fn: Path decoder; defaults to ``build_decoder(...)``.
        augment_fn: Augmenter; defaults to ``build_augmenter(...)``.
        augment: Whether to apply ``augment_fn``.
        repeat: Whether to repeat the dataset indefinitely.
        shuffle: Shuffle buffer size; a falsy value disables shuffling.
        cache_dir: Passed to ``Dataset.cache``; created if non-empty and
            ``cache is True``.

    Returns:
        The resulting dataset.
    """
    labelled = labels is not None

    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    if decode_fn is None:
        decode_fn = build_decoder(labelled)

    if augment_fn is None:
        augment_fn = build_augmenter(labelled)

    autotune = tf.data.experimental.AUTOTUNE
    source = (paths, labels) if labelled else paths

    pipeline = tf.data.Dataset.from_tensor_slices(source)
    pipeline = pipeline.map(decode_fn, num_parallel_calls=autotune)
    if cache:
        pipeline = pipeline.cache(cache_dir)
    if augment:
        pipeline = pipeline.map(augment_fn, num_parallel_calls=autotune)
    if repeat:
        pipeline = pipeline.repeat()
    if shuffle:
        pipeline = pipeline.shuffle(shuffle)

    return pipeline.batch(bsize).prefetch(autotune)

In [None]:
# Pick the accelerator strategy first; the global batch size scales with it.
strategy = auto_select_accelerator()

# GCS location of the competition dataset, as reported by KaggleDatasets.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('plant-pathology-2021-fgvc8')

# 16 samples per replica.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

In [None]:
# Candidate input resolutions; the middle one (304) is the one in use.
IMSIZES = (224, 304, 384)
im_size = IMSIZES[1]

load_dir = "/kaggle/input/plant-pathology-2021-fgvc8/"

df = pd.read_csv('/kaggle/input/plant-pathology-2021-traincsv-cleaned/train.csv')
# Optional row reordering, currently disabled:
#   reverse: df = df.iloc[::-1].reset_index(drop=True)
#   shuffle: df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Multi-hot target per label string.
# Column order: [complex, scab, frog_eye_leaf_spot, rust, powdery_mildew];
# 'healthy' is the all-zero vector.
one_hot = {
    'healthy':                         [0, 0, 0, 0, 0],
    'scab':                            [0, 1, 0, 0, 0],
    'scab frog_eye_leaf_spot':         [0, 1, 1, 0, 0],
    'frog_eye_leaf_spot':              [0, 0, 1, 0, 0],
    'rust':                            [0, 0, 0, 1, 0],
    'complex':                         [1, 0, 0, 0, 0],
    'powdery_mildew':                  [0, 0, 0, 0, 1],
    'rust frog_eye_leaf_spot':         [0, 0, 1, 1, 0],
    'frog_eye_leaf_spot complex':      [1, 0, 1, 0, 0],
    'scab frog_eye_leaf_spot complex': [1, 1, 1, 0, 0],
    'powdery_mildew complex':          [1, 0, 0, 0, 1],
    'rust complex':                    [1, 0, 0, 1, 0],
}

n_labels = 5

# Distinct label strings present in the CSV, in order of first appearance.
class_name = df['labels'].unique().tolist()
class_num = len(class_name)

print(class_name)
print(class_num)

# One sub-frame per label string, aligned with `class_name`.
img_class = [df[df.labels == label_text] for label_text in class_name]

df

In [None]:
# Fraction of each class (taken from the head of its sub-frame) held out
# for validation.
TESTSIZE = 0.1

train_paths, valid_paths = [], []
train_labels, valid_labels = [], []

for idx in range(class_num):
    class_frame = img_class[idx]
    n_valid = int(class_frame.shape[0] * TESTSIZE)
    held_out = class_frame[:n_valid]
    kept = class_frame[n_valid:]

    valid_paths.extend(GCS_DS_PATH + '/train_images/' + held_out['image'])
    train_paths.extend(GCS_DS_PATH + '/train_images/' + kept['image'])

    # Label strings are mapped straight to their multi-hot vectors.
    valid_labels.extend(one_hot[text] for text in held_out['labels'])
    train_labels.extend(one_hot[text] for text in kept['labels'])

print()
print(len(train_paths))
print(len(valid_paths))
print('sum_len:', len(train_paths) + len(valid_paths))

In [None]:
# One labelled decoder, shared by both pipelines, resizing to the model input.
decoder = build_decoder(target_size=(im_size, im_size), with_labels=True)

# Training pipeline: build_dataset defaults (cache, augment, repeat, shuffle).
train_dataset = build_dataset(
    paths=train_paths,
    labels=train_labels,
    bsize=BATCH_SIZE,
    decode_fn=decoder,
)

# Validation pipeline: single deterministic pass, no augmentation.
valid_dataset = build_dataset(
    paths=valid_paths,
    labels=valid_labels,
    bsize=BATCH_SIZE,
    decode_fn=decoder,
    augment=False,
    repeat=False,
    shuffle=False,
)

In [None]:
from tensorflow.keras.optimizers import RMSprop, Adam, SGD
from keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Input, GlobalAveragePooling2D, ReLU, Flatten, Dense, Dropout, BatchNormalization, MaxPooling2D, GlobalMaxPooling2D
import tensorflow_addons

# Model, metrics and optimizer are all created inside the strategy scope.
with strategy.scope():
    # ViT-L/16 backbone with pretrained weights and a fresh (non-pretrained)
    # sigmoid head of `n_labels` units for multi-label output.
    model = vit.vit_l16(
        image_size=im_size,
        activation='sigmoid',
        pretrained=True,
        include_top=True,
        pretrained_top=False,
        classes=n_labels,
        weights='imagenet21k',
    )

    # threshold=0.5 binarises each sigmoid output independently. With
    # threshold=None the metric marks only the argmax as positive, which
    # cannot represent multi-label targets or the all-zero 'healthy' row.
    f1_score = tensorflow_addons.metrics.FBetaScore(num_classes=n_labels,
                                                    threshold=0.5, beta=1.0,
                                                    name='f1_score')

    # `learning_rate=` replaces the deprecated `lr=` alias; value unchanged.
    # NOTE(review): 1e-1 is kept from the original but is unusually high for
    # fine-tuning a pretrained model with Adam -- confirm it is intended.
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(learning_rate=1e-1),
                  metrics=[f1_score, tf.keras.metrics.AUC(multi_label=True)])

    model.summary()

In [None]:
# Keep only the weights of the epoch with the lowest validation loss.
checkpoint = ModelCheckpoint(
    "bestvit.h5",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Multiply the learning rate by 0.8 after 40 epochs without a val_loss
# improvement, never going below 1e-3.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    patience=40,
    factor=0.8,
    min_lr=1e-3,
    verbose=1,
)

# Early stopping is available but currently disabled:
# es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=30)

In [None]:
EPOCHS = 180
# The training dataset repeats indefinitely, so the epoch length is set
# explicitly: one pass over the training paths in whole batches.
steps_per_epoch = len(train_paths) // BATCH_SIZE

his = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    callbacks=[checkpoint, reduce_lr],
    verbose=1,
)

# Weights after the final epoch (the best-val_loss weights are in bestvit.h5).
model.save_weights('vit.h5')