# Cassava Starter (Inference)

**This is the inference kernel for [this notebook](https://www.kaggle.com/tuckerarrants/cassava-tensorflow-starter-training). I also included the ability to validate the pre-trained fold models, but validation requires a TPU connection and internet access, so it will not work in the submission run. To validate, run with internet on and `VALIDATE = True`; then set `VALIDATE = False` and turn internet off before submitting.**

In [None]:
import sys
# Make the offline copy of the `efficientnet` package (attached as a Kaggle
# dataset) importable, since pip installs are unavailable with internet off.
sys.path.extend(['/kaggle/input/efficientnet-keras-dataset/efficientnet_kaggle'])

In [None]:
import numpy as np
import pandas as pd 
import math, re, os
import random
import gc
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
from tensorflow import keras
from functools import partial
from tensorflow.keras import backend as K
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
print("Tensorflow version " + tf.__version__)
from sklearn.metrics import accuracy_score
import efficientnet.tfkeras as efn
from collections import Counter
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# Prefer a TPU when one is attached; otherwise fall back to the default
# (single-device) strategy so the same notebook also runs offline for submission.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    DEVICE = "TPU"
except Exception as err:
    # Catch ordinary errors only (e.g. no TPU attached), not KeyboardInterrupt /
    # SystemExit, and report why the TPU path was skipped instead of hiding it.
    print(f'TPU not available ({type(err).__name__}: {err}); using default strategy.')
    DEVICE = "notTPU"
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

# The global batch is 128 images per replica of the active strategy.
REPLICAS = strategy.num_replicas_in_sync
BATCH_SIZE = 128 * REPLICAS
IMAGE_SIZE = [512, 512]
classes = ['0', '1', '2', '3', '4']

# NOTE(review): FILENAMES is not read by the JPEG-based inference path below.
FILENAMES = tf.io.gfile.glob("../input/cassava-leaf-disease-classification" + '/test_tfrecords/ld_test*.tfrec')

# Seed every RNG with 34. NOTE: PYTHONHASHSEED is read at interpreter startup,
# so setting it here only affects child processes, not this kernel's hashing.
os.environ.update({
    'PYTHONHASHSEED': str(34),
    'TF_CUDNN_DETERMINISTIC': '1',
})
random.seed(34)
np.random.seed(34)
tf.random.set_seed(34)

# Helper Functions

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 RGB tensor.

    Pixel values are scaled from [0, 255] to [0, 1], and the result is given
    the static shape ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    return tf.reshape(pixels / 255.0, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into ``(image, label)``.

    The record must carry a JPEG ``image`` bytestring and an int64 ``target``.
    The target is returned as int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),   # encoded JPEG bytes
        "target": tf.io.FixedLenFeature([], tf.int64),   # scalar class index
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['target'], tf.int32)

def read_unlabeled_tfrecord(example, return_image_name):
    """Parse one serialized unlabeled example.

    Returns ``(image, id)`` when ``return_image_name`` is truthy, otherwise
    ``(image, 0)``. The id is the record's ``id`` bytestring.
    """
    parsed = tf.io.parse_single_example(example, {
        "image": tf.io.FixedLenFeature([], tf.string),   # encoded JPEG bytes
        "id": tf.io.FixedLenFeature([], tf.string),      # scalar image identifier
    })
    image = decode_image(parsed['image'])
    second = parsed['id'] if return_image_name else 0
    return (image, second)

def get_val_dataset(files, one_hot=False,
                    shuffle=False, repeat=False,
                    labeled=True, return_image_names=True,
                    batch_size=BATCH_SIZE, dim=IMAGE_SIZE[0]):
    """Build a batched, prefetched ``tf.data`` pipeline over TFRecord ``files``.

    Args:
        files: TFRecord paths to read.
        one_hot: Map labels to one-hot vectors via ``onehot`` (labeled data).
        shuffle: Shuffle with a 2048-element buffer and allow nondeterministic order.
        repeat: Repeat the data indefinitely.
        labeled: Parse with ``read_labeled_tfrecord``; otherwise with
            ``read_unlabeled_tfrecord``.
        return_image_names: For unlabeled data, emit the record id instead of 0.
        batch_size: Elements per batch.
        dim: Accepted for interface compatibility; not used.
    """
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=AUTOTUNE)
    if repeat:
        dataset = dataset.repeat()
    if shuffle:
        # Order does not matter when shuffling, so let tf.data trade
        # determinism for throughput.
        options = tf.data.Options()
        options.experimental_deterministic = False
        dataset = dataset.shuffle(2048).with_options(options)

    if labeled:
        parse = read_labeled_tfrecord
    else:
        def parse(example):
            return read_unlabeled_tfrecord(example, return_image_names)
    dataset = dataset.map(parse, num_parallel_calls=AUTOTUNE)

    if one_hot:
        dataset = dataset.map(onehot, num_parallel_calls=AUTOTUNE)

    return dataset.batch(batch_size).prefetch(AUTOTUNE)

def count_data_items(filenames):
    """Return the total number of records encoded in TFRecord file names.

    Each file name is expected to contain its record count as ``-<digits>.``
    (for example ``train00-1338.tfrec``). The counts are summed. Only the
    basename is inspected, so dashes or digits in a directory or bucket name
    cannot be mistaken for a count.

    Args:
        filenames: Iterable of file paths (local or ``gs://``).

    Returns:
        The summed count as an ``int`` (``0`` for an empty iterable).

    Raises:
        ValueError: If a file name carries no ``-<digits>.`` count.
    """
    # Compile once instead of once per file; require at least one digit.
    pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        match = pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"cannot read item count from file name: {filename!r}")
        total += int(match.group(1))
    return total

def onehot(image, label):
    """Return ``(image, one_hot(label))`` with depth ``len(classes)``."""
    return image, tf.one_hot(label, len(classes))

# Validation

In [None]:
# Set to True to score the fold models out-of-fold. That reads the training
# TFRecords from GCS (needs internet; intended for a TPU session).
# Keep False for the offline submission run.
VALIDATE = False

In [None]:
from kaggle_datasets import KaggleDatasets

# Cross-validation layout of the training notebook: number of folds and RNG seed.
FOLDS, SEED = 5, 34

if VALIDATE:
    # Training TFRecords live on GCS; presumably only reachable with internet on.
    GCS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-tfrecords-512x512')
    TRAINING_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/*.tfrec')
    AUG_TYPE = 'CUTMIXUP'

In [None]:
from sklearn.model_selection import KFold

if VALIDATE:
    # Re-score each fold model on the TFRecord files held out for that fold.
    # NOTE(review): relies on this KFold call reproducing the training
    # notebook's split exactly — confirm against that notebook.
    oof_pred = []
    oof_labels = []
    kfold = KFold(FOLDS, shuffle=True, random_state=SEED)

    for f, (_, val_index) in enumerate(kfold.split(TRAINING_FILENAMES)):

        print('#'*25); print('FOLD', f+1); print('#'*25); print('')
        print('Getting datasets...'); print('')

        val_files = [TRAINING_FILENAMES[i] for i in val_index]
        # shuffle=False keeps the order deterministic, so the label pass and the
        # prediction pass below see the examples in the same order.
        val_ds = get_val_dataset(val_files, one_hot=True, labeled=True,
                                 return_image_names=False, repeat=False, shuffle=False)

        print('Loading model...'); print(''); print('Predicting out-of-fold...'); print('')

        # Drop the previous fold's graph, then load under the active strategy so
        # prediction runs on the TPU replicas when a TPU is attached.
        K.clear_session()
        with strategy.scope():
            model = tf.keras.models.load_model(f'../input/cassava-tensorflow-starter-training/EFFNET_{f}_34_CUTMIXUP_512_full.h5')

        # One-hot labels gathered batch-wise: shape (n_val, len(classes)).
        oof_labels.append(np.concatenate([target.numpy() for _, target in val_ds]))
        x_oof = val_ds.map(lambda image, label: image)
        oof_pred.append(np.argmax(model.predict(x_oof), axis=-1))

        del model
        gc.collect()

In [None]:
if VALIDATE:
    y_true = np.concatenate(oof_labels)
    y_preds = np.concatenate(oof_pred)

    # Labels come from a one_hot=True dataset, so collapse one-hot rows back to
    # class indices. Integer labels (1-D) are used as-is. This replaces the old
    # `AUG_TYPE is 'CUTMIXUP'` check: `is` compares identity, not string equality.
    if y_true.ndim > 1:
        y_true = np.argmax(y_true, axis=1)

    print(classification_report(y_true, y_preds))
    print(f"OOF accuracy score: {accuracy_score(y_true, y_preds)}")

# Inference

In [None]:
# Image folders of the competition dataset: test images (submitted) and train images.
JPEG_PATH, JPEG_PATH_TR = (
    "../input/cassava-leaf-disease-classification/test_images",
    "../input/cassava-leaf-disease-classification/train_images",
)

import cv2
from tqdm.notebook import tqdm

def load_image(jpeg_path, image_id, size=(512, 512)):
    """Read one image from disk as an RGB float array scaled to [0, 1].

    Args:
        jpeg_path: Directory that contains the image file.
        image_id: File name of the image inside ``jpeg_path``.
        size: ``(width, height)`` passed to ``cv2.resize``. The default matches
            the 512x512 resolution used throughout this notebook.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)`` in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    full_path = os.path.join(jpeg_path, image_id)
    img = cv2.imread(full_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {full_path}")
    img = cv2.resize(img / 255.0, size)
    # OpenCV decodes as BGR; flip to RGB to match tf.image.decode_jpeg.
    return img[:, :, ::-1]
def generator(filepath, paths, batch_size=32, loader=None):
    """Yield stacked image batches for ``model.predict``.

    Args:
        filepath: Directory passed to the loader.
        paths: Sequence of image file names (list or NumPy array).
        batch_size: Images per batch. The last batch may be smaller.
        loader: ``loader(filepath, name) -> np.ndarray``. Defaults to ``load_image``.

    Yields:
        ``np.ndarray`` batches stacked along a new leading axis. Nothing is
        yielded for empty ``paths``. Previously an empty final batch was stacked
        (and crashed) whenever ``len(paths)`` was a multiple of ``batch_size``.
    """
    if loader is None:
        loader = load_image
    print(len(paths))
    for start in range(0, len(paths), batch_size):
        chunk = paths[start:start + batch_size]
        yield np.stack([loader(filepath, name) for name in chunk])

In [None]:
# The sample submission supplies the test image ids to predict; train.csv is
# loaded alongside it for reference.
submission, tr = (
    pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv'),
    pd.read_csv('../input/cassava-leaf-disease-classification/train.csv'),
)

In [None]:
# Average the class probabilities of every fold model over the test images.
preds_all = []

preds_model = []
for fold in range(FOLDS):
    print(f"## FOLD: {fold}")

    # A fresh generator per fold: model.predict exhausts it.
    ds_test = generator(JPEG_PATH,submission.image_id.values)

    K.clear_session()

    print('Loading and inferring...')
    model = tf.keras.models.load_model(f'../input/cassava-tensorflow-starter-training/EFFNET_{fold}_34_CUTMIXUP_512_full.h5')

    preds = model.predict(ds_test, verbose=True)
    preds_model.append(preds)

    # Release this fold's model before the next one is loaded (as in validation).
    del model
    gc.collect()

# Mean over folds, then wrap as a single ensemble entry: shape (1, n_images, n_outputs).
preds_model = np.stack(preds_model).mean(0)
preds_all.append(preds_model)

preds_all = np.stack(preds_all)

In [None]:
# Sanity check: expected (1, n_test_images, n_outputs), i.e. one ensemble entry
# holding the fold-averaged model outputs (presumably class probabilities).
print(preds_all.shape)
# Bare final expression so the notebook displays the raw array.
preds_all

In [None]:
# Average over the ensemble axis, then take the highest-scoring class per image.
submission["label"] = np.argmax(np.mean(preds_all, axis=0), axis=1)
submission.to_csv(path_or_buf="submission.csv", index=False)

In [None]:
# Display the final submission frame (the notebook renders the last expression).
submission