# In [None]:
import os

import albumentations as A
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator


class Transform:
    """Callable wrapper around a fixed albumentations pipeline.

    An instance behaves like a plain function taking one image and
    returning one image; in this file it is handed to
    ``ImageDataGenerator`` as its ``preprocessing_function``.
    """

    def __init__(self):
        """Build the augmentation pipeline and store it on ``self.aug``."""
        # Every transform is created with its library-default arguments.
        steps = [
            A.RandomRotate90(),
            A.Rotate(),
            A.Flip(),
        ]
        self.aug = A.Compose(steps)

    def __call__(self, image):
        """Run ``image`` through the pipeline and return the result.

        Args:
            image: The image to augment; forwarded to the composed
                pipeline as its ``image`` keyword argument.

        Returns:
            The value stored under the ``'image'`` key of the pipeline's
            output.
        """
        augmented = self.aug(image=image)
        return augmented['image']


# Dataset root and output directory (Kaggle-style absolute paths).
INPUT_DIR = '/kaggle/input/cassava-leaf-disease-classification'
WORKING_DIR = '/kaggle/working'

# Created relative to the current working directory, not WORKING_DIR.
os.makedirs('checkpoints', exist_ok=True)

# Load the training table and turn the label column into strings
# (presumably because the categorical image iterators built from these
# frames expect string class labels -- confirm against the Keras docs).
df = pd.read_csv(os.path.join(INPUT_DIR, 'train.csv'))
df = df.astype({'label': str})

# Hold out 10% of the rows for validation, stratified by label, with a
# fixed seed so the split is reproducible.
train_df, validation_df = train_test_split(
    df,
    stratify=df[['label']],
    test_size=0.1,
    random_state=42,
)

# Placeholder submission: every file found in the test image directory is
# given the constant label '0'. It is written before any model is trained.
test_df = pd.DataFrame(
    data={
        'image_id': os.listdir(os.path.join(INPUT_DIR, 'test_images')),
        'label': '0',
    },
)
test_df.to_csv(path_or_buf='submission.csv', index=False)

batch_size = 16

# Both iterators read from the same image directory and share every option
# except the source frame and the augmentation, so the common keyword
# arguments are collected once.
flow_options = dict(
    directory=os.path.join(INPUT_DIR, 'train_images'),
    x_col='image_id',
    y_col='label',
    batch_size=batch_size,
)

# Training batches are passed through the Transform augmentation pipeline.
train_datagen = ImageDataGenerator(
    preprocessing_function=Transform(),
).flow_from_dataframe(dataframe=train_df, **flow_options)

# Validation batches are produced without any preprocessing function.
validation_datagen = ImageDataGenerator().flow_from_dataframe(
    dataframe=validation_df,
    **flow_options,
)

# Classifier: EfficientNetB0 without its top layers, followed by global
# average pooling and a 5-unit softmax output layer.
model = Sequential()
model.add(EfficientNetB0(include_top=False))
model.add(GlobalAveragePooling2D())
model.add(Dense(units=5, activation='softmax'))

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for 25 epochs on the augmented iterator, evaluating on the held-out
# validation iterator.
model.fit(
    x=train_datagen,
    validation_data=validation_datagen,
    epochs=25,
    verbose=2,
    workers=4,
)

# Persist the trained model in HDF5 format under WORKING_DIR.
model.save(
    filepath=os.path.join(WORKING_DIR, 'cassava.h5'),
    save_format='h5',
)