In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import pandas as pd
import cv2
import numpy as np
!pip install -q ../input/keras-efficientnet-whl/Keras_Applications-1.0.8-py3-none-any.whl
!pip install -q ../input/keras-efficientnet-whl/efficientnet-1.1.1-py3-none-any.whl
import efficientnet.keras
import efficientnet
import time
import gc

In [None]:
# Directory holding the competition's test JPEGs; the submission cell lists file names from here.
test_img_dir = '../input/cassava-leaf-disease-classification/test_images/'

In [None]:
model_root_path = '../input/cassava-gpu-preprocessed-trainer/'
model_name = 'effnet5'

# One checkpoint per fold, named fold_<k>_<model_name>_aug.hdf5 for k = 0..4.
fold_models = [
    tf.keras.models.load_model(model_root_path + 'fold_' + str(fold) + '_' + model_name + '_aug.hdf5')
    for fold in range(5)
]
# Keep the individual names: make_prediction reads model_0 .. model_4 directly.
model_0, model_1, model_2, model_3, model_4 = fold_models

In [None]:

# pipeline:
#  image_name -> predict -> store to csv
#  predict = image_generator -> make_prediction
#  image_generator = read_img -> augment (random test-time augmentation)


def read_img(image_name):
    """Read and decode one test image by file name.

    Args:
        image_name: file name (not a full path) inside ``test_img_dir``.

    Returns:
        uint8 image tensor of shape (H, W, 3).
    """
    # Use the same directory the submission cell lists, instead of a second
    # hardcoded copy of the path.
    filepath = test_img_dir + image_name
    # channels=3 forces RGB output even for grayscale JPEGs, so the later
    # reshape to (1, 299, 299, 3) cannot fail on the channel count.
    return tf.io.decode_jpeg(tf.io.read_file(filepath), channels=3)


def get_augment_list():
    """Draw six independent fair coin flips, one per augmentation.

    Returns:
        NumPy bool array of shape (6,); each entry is True with probability 0.5.
    """
    coin_flips = np.random.randint(2, size=6)
    # randint(2) yields 0 or 1; "heads" is the 0 outcome.
    return coin_flips == 0


@tf.function
def resize_image(image):
    """Resize one image to 299x299 and add a leading batch axis.

    Args:
        image: image tensor whose resized form has exactly 299*299*3 elements,
            i.e. (H, W, 3); any other shape fails in the reshape.

    Returns:
        float32 tensor of shape (1, 299, 299, 3). Values keep the input's scale
        (no normalisation is done here).
    """
    resized = tf.image.resize(image, [299, 299])
    batched = tf.reshape(resized, (1, 299, 299, 3))
    return tf.cast(batched, tf.float32)

    
    
@tf.function
def augment_img_randomly(img):
    '''Resize one decoded image and apply a random subset of augmentations.

    Augmentations, each applied independently with probability 0.5:

        random saturation   factor in (0.7, 1.3)
        random contrast     factor in (0.8, 1.2)
        random brightness   max delta 0.3 (on the [0, 1] scale)
        random hue          max delta 0.2
        random left/right flip
        random up/down flip

    The on/off decisions are drawn with ``tf.random.uniform`` inside the graph.
    A NumPy draw (as in ``get_augment_list``) inside a ``tf.function`` executes
    only once, at trace time. That would freeze the same subset of
    augmentations for every later call.

    NOTE(review): the tf.image.random_* ops here are the stateful variants, not
    the stateless ones, so results are not seed-reproducible.

    Args:
        img: decoded image tensor (e.g. from ``read_img``). It must reshape to
            (1, 299, 299, 3) after resizing.

    Returns:
        float32 tensor of shape (1, 299, 299, 3) with values in [0, 1].
    '''
    image = resize_image(img)  # (1, 299, 299, 3) float32, still on the 0..255 scale

    # Rescale *before* colour jitter. random_brightness adds an absolute delta,
    # so a delta of 0.3 only has a visible effect on [0, 1] data.
    image = tf.math.divide(image, 255.0)

    do_aug = tf.random.uniform([6]) < 0.5  # one fair coin per augmentation, re-drawn every call

    image = tf.cond(do_aug[0], lambda: tf.image.random_saturation(image, 0.7, 1.3), lambda: image)
    image = tf.cond(do_aug[1], lambda: tf.image.random_contrast(image, 0.8, 1.2), lambda: image)
    image = tf.cond(do_aug[2], lambda: tf.image.random_brightness(image, 0.3), lambda: image)
    image = tf.cond(do_aug[3], lambda: tf.image.random_hue(image, 0.2), lambda: image)
    image = tf.cond(do_aug[4], lambda: tf.image.random_flip_left_right(image), lambda: image)
    image = tf.cond(do_aug[5], lambda: tf.image.random_flip_up_down(image), lambda: image)

    # Brightness/contrast jitter can push values outside [0, 1]; keep the model input in range.
    return tf.clip_by_value(image, 0.0, 1.0)

def image_generator(image_name):
    """Load one test image by file name and return its randomly augmented model input.

    Args:
        image_name: file name inside the test image directory.

    Returns:
        Whatever ``augment_img_randomly`` produces for the decoded image.
    """
    return augment_img_randomly(read_img(image_name))

In [None]:


def make_image(image_name):
    """Load one test image as a model-ready batch of one, without augmentation.

    Args:
        image_name: file name (not a full path) inside ``test_img_dir``.

    Returns:
        float32 tensor of shape (1, 299, 299, 3) with values scaled to [0, 1].
    """
    # Read from the same directory the submission cell lists, rather than a
    # second hardcoded copy of the path.
    raw = tf.io.read_file(test_img_dir + image_name)
    # channels=3 forces RGB even for grayscale JPEGs, so the model always gets 3 channels.
    img = tf.image.decode_jpeg(raw, channels=3)
    img = tf.image.resize(img, [299, 299])
    img = tf.math.divide(tf.cast(img, tf.float32), tf.constant(255.0))
    return tf.expand_dims(img, axis=0)  # (299, 299, 3) -> (1, 299, 299, 3)


def make_prediction(img):
    """Ensemble the five fold models on one batch and return the argmax class.

    The class scores of model_0 .. model_4 are summed (for argmax this is the
    same as averaging), and the highest-scoring class index is returned.

    Args:
        img: input batch accepted by the models, e.g. the (1, 299, 299, 3)
            float32 tensor from ``make_image`` or ``image_generator``.

    Returns:
        int64 tensor of class indices, argmax over the last axis of the summed
        scores. This assumes the models output (batch, n_classes), giving
        shape (batch,) (TODO confirm).
    """
    fold_models = (model_0, model_1, model_2, model_3, model_4)
    # For a single small batch, calling the model directly is much cheaper than
    # Model.predict, which sets up a batched input pipeline on every call.
    # training=False keeps layers such as BatchNormalization/Dropout in inference mode.
    label_arr = tf.add_n([model(img, training=False) for model in fold_models])
    return tf.math.argmax(label_arr, axis=-1)

def predict(img_name):
    """Predict the class index of one test image with test-time augmentation.

    Args:
        img_name: file name inside the test image directory.

    Returns:
        The predicted class index for the image (first element of the batch).
    """
    augmented_batch = image_generator(img_name)
    batch_classes = make_prediction(augmented_batch)
    return batch_classes.numpy()[0]

In [None]:
# Predict every test image (no test-time augmentation) and write the submission.
# os.listdir order is arbitrary; sorting makes the submission row order reproducible.
images = sorted(os.listdir(test_img_dir))
labels = [make_prediction(make_image(image_name)).numpy()[0] for image_name in images]

df = pd.DataFrame(data={'image_id': images, 'label': labels})

df.to_csv('submission.csv', index=False)
    

In [None]:
# Timing check: the first call includes model/graph warm-up, the second shows steady-state cost.
# make_image expects a bare file name; it already prefixes the test image directory.
# start = time.time()
# print(make_prediction(make_image('2216849948.jpg')))
# mid = time.time()
# print("Elapsed: ",mid-start)

# print(make_prediction(make_image('2216849948.jpg')))
# endd = time.time()
# print("Elapsed: ",endd-mid)