In [None]:
!pip install -q efficientnet_pytorch

In [None]:
import os
import time

import numpy as np 
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import albumentations as A
from albumentations import pytorch as ATorch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from efficientnet_pytorch import EfficientNet
from sklearn.model_selection import train_test_split

In [None]:
# TODO: add configs

# Run settings shared by the cells below. Each key appears exactly once: a
# repeated key in a dict literal silently overrides the earlier entry.
CFG = {
    "seed": 42,  # random_state of the train/validation split
    "batch_size": 16,  # samples per DataLoader batch
    "test_size": 0.2,  # fraction of the annotations held out for validation
    "learning_rate": 0.001,  # lr passed to the optimizer
    "epochs": 25,  # number of epochs passed to Trainer.run
}

In [None]:
# Directory containing the training images; CassavaImageDataset joins it with each file name.
IMG_DIR = "../input/cassava-leaf-disease-classification/train_images/"
# NOTE: despite the name, this is the DataFrame returned by pd.read_csv, not a file path.
ANNOTATIONS_FILE = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")
# File that Trainer.save_model writes the checkpoint to.
MODEL_SAVE_PATH = "best_model.torch"

In [None]:
# TODO: Remove preprocess

class CassavaImageDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs for the rows of an annotations table.

    Args:
        annotations_file: Annotations table. Despite the name it is used as a
            pandas DataFrame (not a path) and is indexed positionally:
            column 0 holds the image file name, column 1 the label.
        img_dir: Directory that the file names in column 0 are relative to.
        transform: Optional callable invoked as ``transform(image=image)``;
            the ``'image'`` entry of its result replaces the loaded image.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = annotations_file
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load row ``idx``'s image as RGB, apply the transform and return it with its label.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = cv2.imread(img_path)
        # cv2.imread reports a missing or unreadable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque error inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image=image)['image']

        return image, label

In [None]:
# TODO: Change to effnet
class EfficientNetModel(nn.Module):
    """EfficientNet-B0 with pretrained features and an ``n_classes``-way output layer.

    Args:
        n_classes: Number of logits produced by ``forward`` for each image.
    """

    def __init__(self, n_classes=5):
        super().__init__()
        # efficientnet_pytorch keeps its classification layer in ``_fc``, so
        # assigning a ``classifier`` attribute only registers an unused module
        # and leaves the 1000-way ImageNet head in the forward path. Passing
        # ``num_classes`` lets the library build the head at the requested
        # size while still loading the pretrained weights for the other layers.
        self.net = EfficientNet.from_pretrained('efficientnet-b0', num_classes=n_classes)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.net(x)

In [None]:
class Trainer:
    """Minimal train/validate loop that checkpoints the best-scoring model.

    Args:
        model: ``nn.Module`` to optimise. The trainer never moves it, so it
            must already live on ``device``.
        device: Device every batch is moved to before the forward pass.
        score: Callable ``score(pred, y)`` returning the batch-mean metric as
            a number. It receives NumPy arrays during training and tensors
            during validation, so it has to accept both.
        loss: Callable ``loss(pred, y)`` returning a scalar tensor with the
            batch-mean loss.
        optimizer: Optimizer updating ``model``'s parameters. Required by
            ``train_epoch``/``run``; ``val_epoch`` works without it.
    """

    def __init__(self, model, device, score, loss, optimizer = None):
        self.model = model
        self.device = device
        self.score = score
        self.loss = loss
        self.optimizer = optimizer

    def run(self, train_dataloader, val_dataloader, epochs, save_path):
        """Train for ``epochs`` epochs and save the model whenever validation improves.

        Args:
            train_dataloader: Iterable of ``(X, y)`` training batches.
            val_dataloader: Iterable of ``(X, y)`` validation batches.
            epochs: Number of train + validate rounds.
            save_path: File the best model's state dict is written to.
        """
        # Start below every reachable score so the first epoch always writes a
        # checkpoint, even if its validation score is exactly 0.
        best_score = float("-inf")

        for epoch in range(epochs):
            train_loss, train_score, train_time = self.train_epoch(train_dataloader)
            val_loss, val_score, valid_time = self.val_epoch(val_dataloader)

            print(
                f"Epoch {epoch+1}",
                f"Train Loss: {train_loss:.3f}, Train Accuracy: {train_score:.3f}, Time: {train_time} sec.",
                f"Validation Loss: {val_loss:.3f}, Validation Accuracy: {val_score:.3f}, Time: {valid_time} sec.",
                f"------------------------------",
                sep="\n",
            )

            if best_score < val_score:
                best_score = val_score
                self.save_model(save_path)

    def train_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader``.

        Returns:
            Tuple ``(mean_loss, mean_score, seconds)``; the means are taken
            over samples and ``seconds`` is the elapsed time truncated to int.

        Raises:
            ValueError: If the trainer was built without an optimizer.
        """
        if self.optimizer is None:
            # Fail before any forward pass instead of at the first zero_grad().
            raise ValueError("Trainer needs an optimizer to run train_epoch")

        self.model.train()
        t = time.time()
        loss_sum, score_sum, n_seen = 0.0, 0.0, 0

        for X, y in dataloader:
            X, y = X.to(self.device), y.to(self.device)
            pred = self.model(X)
            loss = self.loss(pred, y)
            accuracy = self.score(pred.detach().cpu().numpy(), y.detach().cpu().numpy())

            # Weight the batch means by batch size so that a short final batch
            # does not count as much as a full one in the epoch average.
            n = y.shape[0]
            loss_sum += loss.item() * n
            score_sum += accuracy * n
            n_seen += n

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return self._epoch_summary(loss_sum, score_sum, n_seen, t)

    def val_epoch(self, dataloader):
        """Evaluate the model on ``dataloader`` without tracking gradients.

        Returns:
            Tuple ``(mean_loss, mean_score, seconds)``; the means are taken
            over samples and ``seconds`` is the elapsed time truncated to int.
        """
        self.model.eval()
        t = time.time()
        loss_sum, score_sum, n_seen = 0.0, 0.0, 0

        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                loss = self.loss(pred, y)
                accuracy = self.score(pred, y)

                # Same sample weighting as in train_epoch.
                n = y.shape[0]
                loss_sum += loss.item() * n
                score_sum += accuracy * n
                n_seen += n

        return self._epoch_summary(loss_sum, score_sum, n_seen, t)

    @staticmethod
    def _epoch_summary(loss_sum, score_sum, n_seen, started):
        """Turn sample-weighted sums into ``(mean_loss, mean_score, seconds)``."""
        # An empty dataloader yields zero means instead of dividing by zero.
        n = max(n_seen, 1)
        return loss_sum / n, score_sum / n, int(time.time() - started)

    def save_model(self, save_path):
        """Write the model's state dict to ``save_path`` under ``"model_state_dict"``."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
            },
            save_path,
        )
        
        
        

In [None]:
# TODO: change to albumentations
# Use only A library
# TODO: get_train_transforms, get_valid_transforms

def get_train_transforms():
    """Build the training pipeline: resize, random augmentations, normalise, convert to tensor."""
    resize = A.Resize(224, 224)
    # Each augmentation fires independently with its own probability ``p``.
    augmentations = [
        A.Rotate(limit=30, border_mode=cv2.BORDER_REPLICATE, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Blur(p=0.25),
    ]
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], p=1.0)
    to_tensor = ATorch.transforms.ToTensorV2(p=1.0)

    return A.Compose([resize, *augmentations, normalize, to_tensor], p=1.0)

def get_valid_transforms():
    """Build the validation pipeline: resize, normalise, convert to tensor (no augmentation)."""
    resize = A.Resize(224, 224)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], p=1.0)
    to_tensor = ATorch.transforms.ToTensorV2(p=1.0)

    return A.Compose([resize, normalize, to_tensor], p=1.0)

In [None]:
def accuracy_fn(y_pred, y):
    """Return the fraction of rows in ``y_pred`` whose arg-max along axis 1 equals ``y``."""
    predicted_classes = y_pred.argmax(1)
    n_correct = (predicted_classes == y).sum().item()
    n_samples = y.shape[0]
    return n_correct / n_samples

In [None]:
def update_metrics(mean_score, score, mean_loss, loss, step):
    """Fold one batch into the running means of score and loss.

    ``step`` is the zero-based index of the batch, i.e. the number of batches
    already contained in ``mean_score`` and ``mean_loss``. ``loss`` is a
    scalar tensor; ``score`` is a plain number.

    Returns:
        Tuple ``(mean_score, mean_loss)`` after including this batch.
    """
    batches_seen = step + 1
    loss_value = loss.detach().cpu().item()

    updated_score = (mean_score * step + score) / batches_seen
    updated_loss = (mean_loss * step + loss_value) / batches_seen
    return updated_score, updated_loss

In [None]:
# Shuffle the annotations and hold out CFG["test_size"] of the rows for validation.
train, val = train_test_split(
    ANNOTATIONS_FILE,
    shuffle=True,
    random_state=CFG["seed"],
    test_size=CFG["test_size"],
)

In [None]:
# Training images are augmented; validation images are only resized and normalised.
train_dataset = CassavaImageDataset(
    annotations_file=train,
    img_dir=IMG_DIR,
    transform=get_train_transforms(),
)
val_dataset = CassavaImageDataset(
    annotations_file=val,
    img_dir=IMG_DIR,
    transform=get_valid_transforms(),
)

In [None]:
# Only the training loader shuffles; the validation loader keeps dataset order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=CFG["batch_size"],
    shuffle=True,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=CFG["batch_size"],
)

In [None]:
# Seed torch's global RNG so that layer initialisation and DataLoader shuffling
# repeat across kernel restarts; CFG["seed"] otherwise only drives the split.
# The augmentation library's own randomness is not covered by this call.
torch.manual_seed(CFG["seed"])

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = EfficientNetModel().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=CFG["learning_rate"])

In [None]:
# Wire the pieces together and train, keeping the best checkpoint at MODEL_SAVE_PATH.
trainer = Trainer(
    model=model,
    device=device,
    score=accuracy_fn,
    loss=loss_fn,
    optimizer=optimizer,
)
trainer.run(
    train_dataloader,
    val_dataloader,
    epochs=CFG["epochs"],
    save_path=MODEL_SAVE_PATH,
)