# Cassava Leaf Diseases Classification using DeiT
*by Pio Mendoza*

## Load Modules

In [None]:
!pip install timm

In [25]:
from datetime import datetime
from pathlib import Path
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.optim import AdamW 
from torch.optim.lr_scheduler import ExponentialLR 
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from tqdm.notebook import tqdm

import gc
import numpy as np
import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms

## Constants

In [2]:
# Side length (pixels) of the square images produced by the transforms.
IMAGE_SIZE = 224
# Batch size and worker count shared by every DataLoader in this notebook.
BATCH_SIZE = 32
NUM_WORKERS = 2

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")

Device: cpu


## Plotting Functions

In [3]:
# Converts a (C, H, W) tensor into a PIL image for matplotlib.
toPIL = transforms.ToPILImage()

def plot_batch(images):
    """
    Plots one batch of images. (B, C, H, W)

    The batch is tiled into a grid with 4 images per row. Each image is
    rescaled to [0, 1] on its own before conversion, so tensors that were
    standardised with a mean/std (as the transforms in this notebook do)
    are displayed with sensible colours.
    """
    # `scale_each` only takes effect together with `normalize=True`;
    # make_grid works on a copy, so the caller's tensor is left untouched.
    grid = make_grid(images, nrow=4, normalize=True, scale_each=True, pad_value=1)

    plt.figure(figsize=(15, 15))
    plt.imshow(toPIL(grid))
    plt.axis("off")

## Load Cassava Leaf Dataset

### Download Dataset

In [None]:
# Create ./data (no error if it already exists), download the Cassava leaf
# archive into it, then unpack the archive in place.
!mkdir -p data
!wget "https://storage.googleapis.com/emcassavadata/cassavaleafdata.zip" -P data
!unzip data/cassavaleafdata.zip -d data

### Data Transforms

In [6]:
# Training pipeline: random crop/flip/colour augmentation, then tensor
# conversion and standardisation with the ImageNet statistics from timm.
train_augs = transforms.Compose([
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])


# Evaluation pipeline: deterministic resize only. It standardises with the
# same timm constants as the training pipeline so the two cannot drift apart.
valid_augs = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])

In [7]:
train_data = ImageFolder("data/cassavaleafdata/train", transform=train_augs)

# Number of training samples per class index, counted in a single pass.
# `minlength` keeps one slot per class even if a class has no samples.
class_sample_counts = np.bincount(train_data.targets, minlength=len(train_data.classes))

# Inverse-frequency weight per class, then expanded to one weight per sample.
weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)
samples_weights = weights[train_data.targets]

# data balancer: draws (with replacement) as many samples per epoch as the
# dataset holds, picking each sample in proportion to its weight.
sampler = data.WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True)

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS)

In [8]:
# Validation split, read with the evaluation transforms and served in
# shuffled batches.
validation_data = ImageFolder(
    root="data/cassavaleafdata/validation",
    transform=valid_augs,
)
validation_dataloader = DataLoader(
    dataset=validation_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [9]:
# Test split, read with the evaluation transforms and served in shuffled
# batches.
test_data = ImageFolder(
    root="data/cassavaleafdata/test",
    transform=valid_augs,
)
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [None]:
# Remove the notebook-level dataset names; from here on the datasets are
# reached through the loaders' `.dataset` attribute (as `fit` does).
del train_data
del validation_data
del test_data
gc.collect()

102

## Train and Evaluate Functions

In [26]:
def train_one_epoch(model, device, train_loader, optimizer, criterion, epoch, epochs, scheduler=None):
    """
    Trains `model` for one pass over `train_loader`.

    Args:
        model: network to optimise; put into train mode here.
        device: device the input and label tensors are moved to.
        train_loader: iterable of batches supporting len(); element 0 of a
            batch holds the inputs and element 1 the labels.
        optimizer: optimizer, zeroed and stepped once per batch.
        criterion: loss callable invoked as criterion(outputs, labels).
        epoch: zero-based index of the current epoch (shown as epoch+1).
        epochs: total number of epochs (progress-bar label only).
        scheduler: optional LR scheduler, stepped once after the last batch.

    Returns:
        Tuple (running_train_loss, running_train_correct_predictions): the
        sum of the per-batch loss values and the number of samples whose
        arg-max prediction matched the label.
    """
    running_train_loss = 0.0
    running_train_correct_predictions = 0
    num_items = 0
    num_batches = 0
    model.train()
    # Size the progress bar from the loader that was passed in rather than
    # from a notebook-level global.
    with tqdm(train_loader, total=len(train_loader)) as loop:
        for batch in loop:
            optimizer.zero_grad()

            inputs, labels = batch[0].to(device), batch[1].to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            num_items += inputs.size(0)
            num_batches += 1
            running_train_loss += loss.item()
            # loss.item() is one scalar per batch, so the running average is
            # taken over batches seen so far, not over samples.
            avg_loss = running_train_loss / num_batches

            pred = outputs.argmax(dim=1, keepdim=True)
            running_train_correct_predictions += pred.eq(labels.view_as(pred)).sum().item()
            # Accuracy is per sample, reported as a percentage.
            avg_accuracy = running_train_correct_predictions * 100 / num_items

            loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
            loop.set_postfix(loss=avg_loss, acc=avg_accuracy, lr=optimizer.param_groups[0]['lr'])

        # The learning rate is updated once per epoch, not once per batch.
        if scheduler is not None:
            scheduler.step()

    return running_train_loss, running_train_correct_predictions

def validate_one_epoch(model, device, validation_loader, criterion, epoch, epochs):
    """
    Evaluates `model` on one pass over `validation_loader` without gradients.

    Args:
        model: network to evaluate; put into eval mode here.
        device: device the input and label tensors are moved to.
        validation_loader: iterable of batches supporting len(); element 0
            of a batch holds the inputs and element 1 the labels.
        criterion: loss callable invoked as criterion(outputs, labels).
        epoch: zero-based index of the current epoch (shown as epoch+1).
        epochs: total number of epochs (progress-bar label only).

    Returns:
        Tuple (running_validation_loss, running_validation_correct_predictions):
        the sum of the per-batch loss values and the number of samples whose
        arg-max prediction matched the label.
    """
    running_validation_loss = 0.0
    running_validation_correct_predictions = 0
    num_items = 0
    num_batches = 0
    model.eval()
    with tqdm(validation_loader, total=len(validation_loader)) as loop, torch.no_grad():
        for batch in loop:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            num_items += inputs.size(0)
            num_batches += 1
            running_validation_loss += loss.item()
            # loss.item() is one scalar per batch, so the running average is
            # taken over batches seen so far, not over samples.
            avg_loss = running_validation_loss / num_batches

            pred = outputs.argmax(dim=1, keepdim=True)
            running_validation_correct_predictions += pred.eq(labels.view_as(pred)).sum().item()
            # Accuracy is per sample, reported as a percentage.
            avg_accuracy = running_validation_correct_predictions * 100 / num_items

            loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
            loop.set_postfix(val_loss=avg_loss, val_acc=avg_accuracy)

    return running_validation_loss, running_validation_correct_predictions

def fit(model,  epochs, device, train_loader, validation_loader, optimizer, criterion, writer, scheduler=None):
    """
    Runs the train/validate loop for `epochs` epochs.

    After every epoch the average loss and accuracy of both splits are
    written to `writer`, and the model weights are saved whenever the
    average validation loss improves on the best value seen so far.

    Args:
        model: network to train.
        epochs: number of epochs to run.
        device: device passed through to the per-epoch functions.
        train_loader: training loader; needs len() and a `.dataset`.
        validation_loader: validation loader; needs len() and a `.dataset`.
        optimizer: optimizer passed to `train_one_epoch`.
        criterion: loss callable passed to both per-epoch functions.
        writer: object exposing `add_scalars(tag, dict, step)` and `flush()`.
        scheduler: optional LR scheduler passed to `train_one_epoch`.

    Checkpoints are written to `models/model_<TIME_STAMP>_<epoch>`, where
    TIME_STAMP is a notebook-level global that must exist before calling.
    """
    # torch.save does not create missing parent directories.
    checkpoint_dir = Path("models")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    best_validation_loss = float('inf')
    for epoch in range(epochs):
        running_train_loss, running_train_accuracy = train_one_epoch(model, device, train_loader, optimizer, criterion, epoch, epochs, scheduler)
        running_validation_loss, running_validation_accuracy = validate_one_epoch(model, device, validation_loader, criterion, epoch, epochs)

        # The running losses are sums of one value per batch, so they are
        # averaged over the number of batches (len of the loader).
        avg_train_loss = running_train_loss / len(train_loader)
        avg_validation_loss = running_validation_loss / len(validation_loader)
        # The running accuracies are counts of correct samples, so they are
        # averaged over the number of samples; logged as fractions in [0, 1].
        avg_train_accuracy = running_train_accuracy / len(train_loader.dataset)
        avg_validation_accuracy = running_validation_accuracy / len(validation_loader.dataset)

        writer.add_scalars("Training vs Validation Loss",
            {"Training": avg_train_loss, "Validation": avg_validation_loss},
            epoch+1
        )
        writer.add_scalars("Training vs Validation Accuracy",
            {"Training": avg_train_accuracy, "Validation": avg_validation_accuracy},
            epoch+1
        )
        writer.flush()

        # Keep a checkpoint only for epochs that improve the validation loss.
        if avg_validation_loss < best_validation_loss:
            best_validation_loss = avg_validation_loss
            model_path = checkpoint_dir / 'model_{}_{}'.format(TIME_STAMP, epoch+1)
            torch.save(model.state_dict(), model_path)


## Create Model

In [11]:
# Pretrained DeiT-small (16x16 patches, 224x224 input) with a 5-way
# classification head, moved to the selected device.
model = timm.create_model(
    "deit_small_patch16_224",
    pretrained=True,
    num_classes=5,
).to(DEVICE)

In [17]:
# One timestamp per run: names the TensorBoard log directory here and the
# checkpoint files written by `fit`.
TIME_STAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(f"runs/{TIME_STAMP}")

# Hyper-parameters.
epochs = 10
lr = 1e-4

# Loss, optimizer, and a scheduler that multiplies the learning rate by
# 0.965 on every `scheduler.step()` call.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), weight_decay=1e-3, lr=lr)
scheduler = ExponentialLR(optimizer, gamma=0.965)

In [27]:
# Launch training with the objects configured in the cells above.
fit(
    model=model,
    epochs=epochs,
    device=DEVICE,
    train_loader=train_dataloader,
    validation_loader=validation_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    writer=writer,
    scheduler=scheduler,
)

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html