# Import

In [1]:
import numpy as np
import pandas as pd
import timm
import random
import os
from PIL import Image
import torch

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import models
import tqdm
from torch import nn
from torch import optim

# Set seed (for reproducible results)

In [2]:
SEED = 123456789


def _seed_everything(seed):
    """Seed every random number generator this notebook relies on."""
    # PyTorch: CPU generator, current CUDA device, and deterministic cuDNN kernels.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Python's built-in generator and NumPy's legacy global generator.
    np.random.seed(seed)
    random.seed(seed)


_seed_everything(SEED)

# Split data

In [3]:
# Load the labelled training table, then draw one uniform number per row;
# rows whose draw falls below 0.2 become the validation split.
data = pd.read_csv('cassava-leaf-disease-classification/train.csv')
choosen_prob = np.random.rand(len(data))
val_mask = choosen_prob < 0.2
train_df, val_df = data[~val_mask], data[val_mask]

In [4]:
# Number of rows in the validation split (rows with choosen_prob < 0.2).
len(val_df)

4294

In [5]:
# Number of rows in the training split (rows with choosen_prob >= 0.2).
len(train_df)

17103

# Hyperparameters

In [6]:
# Experiment tag; it is interpolated into the log file name ('logs_<EXP_NAME>.txt').
EXP_NAME = 'exp_01'

###########################################
# Training settings
MODEL_NAME = 'tf_efficientnet_b0_ns'  # architecture name handed to timm.create_model
IM_SIZE = 300  # training images are random-resized-cropped to IM_SIZE x IM_SIZE
BATCH_SIZE = 56  # training batch size; also enters the learning-rate scaling
LEARNING_RATE = 1e-1  # base LR; the optimizer receives 8 * LEARNING_RATE * BATCH_SIZE / 512
LR_STEP = 30  # step_size of the StepLR scheduler
EPOCH = 1  # number of epochs run by the training loop
PRINT_STEP = 30  # the running training loss is logged every PRINT_STEP iterations
# Epochs with index < WARMUP_EPOCH are trained with finetune_on_bn=False.
# NOTE: while EPOCH <= WARMUP_EPOCH, every epoch falls into that warm-up branch.
WARMUP_EPOCH = 10
WEIGHT_DECAY = 1e-4  # weight_decay passed to SGD
##########################################
# Validation settings
VAL_BATCH_SIZE = 8  # validation batch size
VAL_HEIGHT = 300  # validation images are resized to (VAL_HEIGHT, VAL_WIDTH)
VAL_WIDTH = 400
############################################
# Number of worker processes used by both DataLoaders.
NUM_WORKER = 16

# Data augmentation

In [7]:
# Per-channel normalisation statistics (the widely used ImageNet values).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Both pipelines end with the same (stateless) normalisation step.
normalize = transforms.Normalize(mean, std)

# Training: tensor conversion first, then random crop / rotation / flip.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop((IM_SIZE, IM_SIZE)),
    transforms.RandomRotation(90),
    transforms.RandomHorizontalFlip(p=0.5),
    normalize,
])

# Validation: deterministic bicubic resize to a fixed (height, width).
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((VAL_HEIGHT, VAL_WIDTH), interpolation=Image.BICUBIC),
    normalize,
])

# Dataset

In [8]:
class CassavaDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs described by a DataFrame.

    Args:
        image_dir: Directory that contains the image files.
        df: DataFrame with an ``image_id`` column (file name inside
            ``image_dir``) and a ``label`` column. A copy with a fresh
            0..n-1 index is stored, so the caller's frame is left untouched.
        transform: Optional callable applied to the decoded image, which is
            passed in as an ``H x W x 3`` uint8 NumPy array.
    """

    def __init__(self, image_dir, df, transform=None):
        self.image_dir = image_dir
        # Re-index so that positions 0..len-1 used by __getitem__ are valid
        # labels even when `df` is a filtered slice of a larger frame.
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of the DataFrame)."""
        return len(self.df)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is the output of ``transform`` when one was given,
        otherwise the raw RGB array.
        """
        row = self.df.loc[index]
        label = row.label
        image_path = os.path.join(self.image_dir, row.image_id)

        # The context manager closes the file handle deterministically, and
        # convert('RGB') guarantees three channels so grayscale / RGBA / CMYK
        # files do not break the 3-channel normalisation downstream.
        with Image.open(image_path) as image_file:
            image = np.array(image_file.convert('RGB'))

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [9]:
# Both splits read their images from the same directory; only the row
# subsets and the transforms differ.
train_image_dir = 'cassava-leaf-disease-classification/train_images'
val_image_dir = 'cassava-leaf-disease-classification/train_images'

train_dataset = CassavaDataset(train_image_dir, train_df, transform=train_transform)
val_dataset = CassavaDataset(val_image_dir, val_df, transform=val_transform)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKER,
)

# Validation batches keep the DataFrame order.
val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKER,
)

# Model

In [10]:
# Target device handle for the run.
device = torch.device("cuda")

# Build the network with pretrained weights and a 5-class output layer.
model = timm.create_model(
    MODEL_NAME,
    pretrained=True,
    num_classes=5,
)

In [11]:
# Move the model to the GPU *before* building the optimizer, as the
# torch.optim documentation advises, so the optimizer is guaranteed to hold
# references to the parameters that actually live on the device.
model.cuda()
print('Convert model to CUDA')

criterion = nn.CrossEntropyLoss()

# Scale the base learning rate with the batch size.
linear_scaled_lr = 8.0 * LEARNING_RATE * BATCH_SIZE / 512.0
optimizer = optim.SGD(model.parameters(), lr=linear_scaled_lr, momentum=0.9, weight_decay=WEIGHT_DECAY)

# StepLR only changes the learning rate when lr_scheduler.step() is called
# (normally once per epoch); creating it has no effect by itself.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP)

Convert model to CUDA


# Training and validation

In [12]:
def freeze_bn(model):
    """Put ``model`` in eval mode, then switch its head modules back to train mode.

    After the call only ``model.classifier``, ``model.conv_head`` and
    ``model.bn2`` are in training mode; every other sub-module (and the root
    module itself) is left in eval mode.
    """
    model.eval()
    for head_name in ('classifier', 'conv_head', 'bn2'):
        getattr(model, head_name).train()

In [13]:
def train_one_epoch(epoch, model, train_loader, criterion, optimizer, finetune_on_bn=True, logs_file=None):
    """Train ``model`` for one pass over ``train_loader`` with CUDA mixed precision.

    Args:
        epoch: Epoch index; only used in the log messages.
        model: Network to train; it must already live on the GPU.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the model parameters.
        finetune_on_bn: When falsy, the model is put in ``freeze_bn`` mode
            (everything in eval mode except the head modules); otherwise the
            whole model is put in train mode.
        logs_file: Optional writable text file. The running mean loss is
            written every ``PRINT_STEP`` iterations, followed by a separator
            line at the end of the epoch.

    Returns:
        The mean training loss over all batches of the epoch.
    """
    if not finetune_on_bn:
        freeze_bn(model)
    else:
        model.train()

    running_loss = 0.0
    # A fresh scaler is created per call, so the loss scale restarts each epoch.
    scaler = torch.cuda.amp.GradScaler()
    for i, (inputs, labels) in enumerate(tqdm.tqdm(train_loader)):
        inputs = inputs.cuda()
        labels = labels.cuda()

        # Clear gradients before the backward pass so that anything left over
        # from earlier work can never leak into this update.
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            # Sanity checks that autocast is active: half-precision logits,
            # full-precision loss.
            assert outputs.dtype is torch.float16
            loss = criterion(outputs, labels)
            assert loss.dtype is torch.float32

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        if logs_file is not None and i % PRINT_STEP == PRINT_STEP - 1:
            # i is zero-based, so i + 1 batches have been accumulated so far.
            logs_file.write('Training loss at epoch {}, iteration {}: {}\n'.format(epoch, i, running_loss / (i + 1)))
    if logs_file is not None:
        logs_file.write('------------------------------------------------\n')

    return running_loss / len(train_loader)

In [14]:
def validate(epoch, model, val_loader, logs_file=None):
    """Evaluate ``model`` on ``val_loader`` and return ``(mean_loss, accuracy)``.

    The model is switched to eval mode for the duration of the call, so the
    metrics come from inference-mode layers and BatchNorm running statistics
    are not updated on validation data. The previous train/eval flag of every
    sub-module is restored afterwards.

    Args:
        epoch: Epoch index; only used in the log messages.
        model: Network to evaluate; it must already live on the GPU.
        val_loader: Iterable of ``(images, labels)`` batches.
        logs_file: Optional writable text file that receives the metrics.

    Returns:
        Tuple ``(mean loss per batch, fraction of correctly classified samples)``.

    Note:
        The loss is computed with the module-level ``criterion``.
    """
    correct = 0
    total = 0
    running_loss = 0.0

    # Remember each sub-module's mode so a mixed train/eval configuration
    # (such as the one produced by freeze_bn) survives the call.
    previous_modes = {module: module.training for module in model.modules()}
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in tqdm.tqdm(val_loader):
                images = images.cuda()
                labels = labels.cuda()

                outputs = model(images)
                loss = criterion(outputs, labels)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                running_loss += loss.item()
    finally:
        # Set the flag directly: Module.train() would recurse into children.
        for module, was_training in previous_modes.items():
            module.training = was_training

    mean_loss = running_loss / len(val_loader)
    accuracy = correct / total

    if logs_file is not None:
        logs_file.write('Validating loss at epoch {}: {}\n'.format(epoch, mean_loss))
        logs_file.write('Validating accuracy at epoch {}: {}\n'.format(epoch, accuracy))
        logs_file.write('**************************************************\n')
    return mean_loss, accuracy

# def frezze_top(Model):
#     pass

In [15]:
import time

# Per-epoch history, consumed by the plotting cell.
train_loss_list = []
val_acc_list = []
val_loss_list = []
with open('logs_{}.txt'.format(EXP_NAME), 'w') as logs_file:
    for epoch in range(EPOCH):
        # During the first WARMUP_EPOCH epochs the model is trained in
        # freeze_bn mode (finetune_on_bn=False); afterwards in full train mode.
        train_loss = train_one_epoch(epoch,
                                     model,
                                     train_loader,
                                     criterion,
                                     optimizer,
                                     finetune_on_bn=(epoch >= WARMUP_EPOCH),
                                     logs_file=logs_file)
        val_loss, val_acc = validate(epoch,
                                     model,
                                     val_loader,
                                     logs_file=logs_file)

        # Advance the StepLR schedule once per epoch; without this call the
        # scheduler never changes the learning rate.
        lr_scheduler.step()

        train_loss_list.append(train_loss)
        val_acc_list.append(val_acc)
        val_loss_list.append(val_loss)

100%|██████████| 306/306 [03:18<00:00,  1.54it/s]
100%|██████████| 537/537 [00:58<00:00,  9.15it/s]
100%|██████████| 306/306 [03:14<00:00,  1.57it/s]
 34%|███▍      | 182/537 [00:21<00:41,  8.60it/s]


KeyboardInterrupt: 

# Visualize train loss and val loss

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

print(train_loss_list)
print(val_acc_list)
print(val_loss_list)

plt.plot(train_loss_list)
plt.plot(val_acc_list)
plt.plot(val_loss_list)

plt.show()