In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found (recursively) under the input mount.
for current_dir, _, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(current_dir, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# 1 | Import Libraries

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image

# 2 | Load Data

In [None]:
root_dir = "/kaggle/input/cassava-leaf-disease-merged/train/"
# Absolute path of every entry in the training image directory, as a numpy array.
files = np.array([root_dir + entry for entry in os.listdir(root_dir)])

In [None]:
# Only the image file name and its integer class id are needed downstream.
df = pd.read_csv(
    '/kaggle/input/cassava-leaf-disease-merged/merged.csv',
    usecols=['image_id', 'label'],
)
df.head()

In [None]:
# Summarize columns, dtypes and non-null counts to spot missing ids or labels.
df.info()

In [None]:
# Turn bare file names into full paths under root_dir. Rows that already start
# with root_dir are left alone, so re-running this cell does not prefix twice.
df['image_id'] = df['image_id'].where(
    df['image_id'].str.startswith(root_dir),
    root_dir + df['image_id'],
)

In [None]:
# Check that image_id now holds full paths.
df.head(n=5)

In [None]:
# Replace integer class ids with readable names. Only map while the column is
# still numeric: re-running the cell would otherwise map the already-converted
# strings to NaN.
if pd.api.types.is_numeric_dtype(df['label']):
    df['label'] = df['label'].map({
        0: "Bacterial Blight",
        1: "Brown Streak",
        2: "Green Mottle",
        3: "Mosaic Disease",
        4: "Healthy"
    })
# Any id outside 0-4 (or a missing label) would have silently become NaN.
assert df['label'].notna().all(), "found label ids outside 0-4 or missing labels"

# 3 | Visualization

In [None]:
def visualize_df(df: pd.DataFrame, n_rows: int = 4, n_cols: int = 4):
    """Show the first ``n_rows * n_cols`` images of ``df`` in a grid, titled by label.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with an ``image_id`` column holding image file paths and a
        ``label`` column used as each subplot's title.
    n_rows, n_cols : int
        Grid shape; the defaults reproduce the original 4 x 4 layout.

    Grid cells beyond ``len(df)`` are left blank.
    """
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
    # np.ravel also copes with a 1x1 grid, where subplots returns a bare Axes.
    for i, ax in enumerate(np.ravel(axes)):
        if i < len(df):
            # Positional access: shuffled/split frames keep their original
            # index labels, so label-based .loc[i] may miss or pick wrong rows.
            row = df.iloc[i]
            # Context manager releases the file handle once pixels are converted.
            with Image.open(row['image_id']) as img:
                image = img.convert('RGB')
            ax.imshow(image)
            ax.set_title(row['label'])
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first grid of images with their disease labels.
visualize_df(df=df)

# 4 | Train / Validation / Test Split

In [None]:
def train_validate_test_split(df, train_frac=0.6, validate_frac=0.2, seed=42):
    """Shuffle ``df`` and cut it into train / validation / test partitions.

    Parameters
    ----------
    df : pd.DataFrame
        Rows to split.
    train_frac, validate_frac : float
        Fractions of rows assigned to the train and validation parts; the test
        part receives the remainder. Defaults give the original 60/20/20 split.
    seed : int or None
        Passed to ``DataFrame.sample`` as ``random_state`` so the split is
        reproducible across runs; ``None`` draws a fresh random split.

    Returns
    -------
    tuple of pd.DataFrame
        ``(train, validate, test)``: disjoint and together covering every row.
        The original index labels are preserved.

    Raises
    ------
    ValueError
        If a fraction is negative or the two fractions sum to more than 1.
    """
    if train_frac < 0 or validate_frac < 0 or train_frac + validate_frac > 1:
        raise ValueError("train_frac and validate_frac must be >= 0 and sum to <= 1")
    shuffled = df.sample(frac=1, random_state=seed)
    n_rows = len(shuffled.index)
    train_end = int(train_frac * n_rows)
    validate_end = int(validate_frac * n_rows) + train_end
    train = shuffled.iloc[:train_end]
    validate = shuffled.iloc[train_end:validate_end]
    test = shuffled.iloc[validate_end:]
    return train, validate, test

df_train, df_validation, df_test = train_validate_test_split(df)
# Invariants: the three parts cover every row exactly once.
assert len(df_train) + len(df_validation) + len(df_test) == len(df)
assert pd.concat([df_train, df_validation, df_test]).index.is_unique
print("Train: ", len(df_train), "\nValidation: ", len(df_validation), "\nTest: ", len(df_test))

# 5 | Preprocess

In [None]:
def preprocess_func(data, image_size=(224, 224)):
    """Replace the file path in ``data['image_id']`` with a decoded, resized image.

    Intended for ``tf.data.Dataset.map`` over elements built from a DataFrame
    via ``to_dict('list')``.

    Parameters
    ----------
    data : dict
        One dataset element with an ``image_id`` entry holding the image file
        path. Other entries (e.g. ``label``) pass through unchanged.
    image_size : tuple of int
        Target ``(height, width)``; the default keeps the original 224 x 224.

    Returns
    -------
    dict
        The same dict, with ``image_id`` now a float32 ``(height, width, 3)``
        image scaled by 1/255.
    """
    image_bytes = tf.io.read_file(data['image_id'])
    # expand_animations=False makes decode_image always return a 3-D tensor
    # (first frame of a GIF), so the static shape set below is actually true.
    image = tf.image.decode_image(image_bytes, channels=3, expand_animations=False)
    # Inside Dataset.map the decoded shape is unknown; resize needs a known rank.
    image.set_shape([None, None, 3])
    image = tf.image.resize(image, image_size)

    image = tf.cast(image, tf.float32) / 255.0

    data['image_id'] = image
    return data

In [None]:
# One tf.data pipeline per split; each element is a dict keyed by the frame's columns.
train_dataset, test_dataset, validation_dataset = (
    tf.data.Dataset.from_tensor_slices(split.to_dict('list'))
    for split in (df_train, df_test, df_validation)
)

In [None]:
# Decode + batch the validation split and pull its first batch (up to 25 images).
processed_validation = (
    validation_dataset
    .map(preprocess_func)
    .batch(25)
    .as_numpy_iterator()
)
examples = next(processed_validation)

# 6 | Model

In [None]:
# Pretrained CropNet cassava classifier from TF Hub, applied to the example batch.
classifier = hub.KerasLayer(
    'https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2'
)
probabilities = classifier(examples['image_id'])
# Highest-scoring class index per image.
# NOTE(review): the model card lists 6 output classes (5 diseases + "unknown");
# confirm the display-name list covers index 5.
predictions = tf.math.argmax(probabilities, axis=-1)

# 7 | Predictions

In [None]:
# Display names indexed by the CropNet classifier's output. Besides the 5 dataset
# classes the model has a 6th "unknown" output (index 5); without that entry,
# cassava_labels[prediction] raises IndexError for images classified as unknown.
cassava_labels = ["Bacterial Blight", "Brown Streak", "Green Mottle",
                  "Mosaic Disease", "Healthy", "Unknown"]

In [None]:
def plot_predictions(data, preds=None, class_names=None, n_rows=5, n_cols=5):
    """Plot a grid of images titled with their true and predicted labels.

    Parameters
    ----------
    data : dict
        A batch from ``as_numpy_iterator()`` with ``image_id`` (images) and
        ``label`` (true label strings, yielded as ``bytes``).
    preds : sequence of int, optional
        Predicted class index per image; defaults to the notebook-level
        ``predictions``.
    class_names : sequence of str, optional
        Names indexed by predicted class; defaults to ``cassava_labels``.
    n_rows, n_cols : int
        Grid shape (default 5 x 5). Cells beyond the batch size stay blank.
    """
    if preds is None:
        preds = predictions
    if class_names is None:
        class_names = cassava_labels
    images = data['image_id']
    labels = data['label']
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
    for i, ax in enumerate(np.ravel(axes)):
        if i < len(images):
            ax.imshow(images[i])
            true_label = labels[i]
            # String features come back as bytes; decode so the title reads
            # "Healthy" instead of "b'Healthy'".
            if isinstance(true_label, bytes):
                true_label = true_label.decode('utf-8')
            pred_idx = int(preds[i])
            # Guard against model outputs the name list does not cover.
            if 0 <= pred_idx < len(class_names):
                pred_label = class_names[pred_idx]
            else:
                pred_label = f"class {pred_idx}"
            ax.set_title(f"Original: {true_label}\nPredicted: {pred_label}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Compare ground-truth labels with the classifier's predictions for the batch.
plot_predictions(data=examples)