In [1]:
## Import Library ##
import warnings 
warnings.filterwarnings('ignore')
import os

import math
import time
import random

import numpy as np
import pandas as pd

from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

from tqdm.auto import tqdm

import cv2

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

# image processing
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, Flip, 
    IAAAdditiveGaussianNoise, Transpose, CenterCrop 
    )
from albumentations.pytorch import ToTensorV2

In [2]:
!pip install timm
import timm

In [3]:
## Config: every tunable hyper-parameter lives here ##
class CFG:
    # --- logging / data loading ---
    print_freq = 100          # print a progress line every N batches
    num_workers = 4           # DataLoader worker processes

    # --- backbone & input resolution ---
    model_name = 'swin_base_patch4_window7_224'
    size = 224

    # --- optimisation ---
    # NOTE(review): train_loop steps the scheduler once per epoch, so with
    # T_max (10) > epochs (5) the LR only follows half the cosine and never
    # reaches min_lr. Confirm this is intended.
    scheduler = 'CosineAnnealingLR'
    epochs = 5
    T_max = 10
    lr = 1e-4
    min_lr = 1e-6
    batch_size = 32
    weight_decay = 1e-6
    max_grad_norm = 1000      # gradient-clipping threshold

    # --- reproducibility, labels and cross-validation ---
    seed = 42
    num_classes = 12
    target_col = 'species'
    n_fold = 4
    trn_fold = [0, 1, 2, 3]   # which folds train_loop is run for

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def seed_torch(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    visible CUDA devices), and configures cuDNN for deterministic kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Hash randomisation of the *running* interpreter is fixed at start-up; this
    # only affects subprocesses launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not just the current one (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects kernels by timing, which can differ run to run.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before the model initialisation and DataLoader shuffling further below.
seed_torch(seed=CFG.seed)

In [5]:
## Label mapping: species folder name -> integer class id ##
# Ids follow the list order below (0 .. 11).
classes = {
    species: label
    for label, species in enumerate([
        'Black-grass',
        'Charlock',
        'Cleavers',
        'Common Chickweed',
        'Common wheat',
        'Fat Hen',
        'Loose Silky-bent',
        'Maize',
        'Scentless Mayweed',
        'Shepherds Purse',
        'Small-flowered Cranesbill',
        'Sugar beet',
    ])
}

In [6]:
## Directory path ##
# Checkpoints and train.log are written here.
OUTPUT_DIR = './'
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [7]:
## Load train dataset ##
# One row per image: absolute path, integer species label, bare file name.
# Rows are collected in a list and the DataFrame is built once, because
# DataFrame.append no longer exists in pandas 2.x and row-by-row growth is quadratic.
pathToTrainData = '../input/plant-seedlings-classification/train'

records = []
for dirname, subdirs, filenames in tqdm(os.walk(pathToTrainData)):
    # Walk in sorted order so the row order, and therefore the K-fold split,
    # is identical on every machine.
    subdirs.sort()
    if not filenames:
        continue  # e.g. the root directory, which only contains species folders
    # basename is portable across path separators; a KeyError here means an
    # unexpected folder name under the training root.
    label = classes[os.path.basename(dirname)]
    for filename in sorted(filenames):
        records.append({
            'image_path': os.path.join(dirname, filename),
            'species': label,
            'file': filename,
        })

train = pd.DataFrame(records, columns=['image_path', 'species', 'file'])
train = train.astype({'species': 'int32'})
assert len(train) > 0, f'no images found under {pathToTrainData}'
train.head(3)

In [8]:
## K-FOLD ##
# Assign each image to one of CFG.n_fold validation folds, stratified by species.
folds = train.copy()
X, Y = folds['file'], folds[CFG.target_col].astype('int')
Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

# Integer column (not float/NaN); every row is overwritten by exactly one split below.
folds['fold'] = -1
for n, (_, val_index) in enumerate(Fold.split(X, Y)):
    # split() yields positions; map them to index labels so .loc is correct for any index.
    folds.loc[folds.index[val_index], 'fold'] = n
assert (folds['fold'] >= 0).all(), 'some rows were not assigned to a fold'

folds['species'] = folds['species'].astype('int32')

In [9]:
## Dataset ##
class TrainDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of file paths and labels.

    ``df`` must provide ``image_path``, ``file`` and ``species`` columns. Each item
    is ``(image, label)``:
      * ``image`` is the RGB image read from disk. If ``transform`` is set it is
        called as ``transform(image=...)['image']`` (albumentations style).
      * ``label`` is a 0-d ``torch.long`` tensor.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_path = df['image_path'].values
        self.file_name = df['file'].values
        self.labels = df['species'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_path = str(self.file_path[idx])
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising;
            # report the offending path instead of failing later inside cvtColor.
            raise FileNotFoundError(f'could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        label = torch.tensor(self.labels[idx]).long()
        return image, label

In [10]:
## Transform ##
def get_transforms(*, data):
    """Return the albumentations pipeline for the ``data`` split ('train' or 'valid').

    Both splits resize to ``CFG.size`` x ``CFG.size``, normalise with ImageNet
    channel statistics and convert to a torch tensor. No training-time
    augmentation is applied at the moment.

    Raises:
        ValueError: if ``data`` is neither 'train' nor 'valid'.
    """
    if data not in ('train', 'valid'):
        # An unknown split would otherwise yield None and silently disable transforms.
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")
    # Train and valid share CFG.size so the input resolution cannot drift apart.
    return Compose([
        Resize(CFG.size, CFG.size),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])

In [11]:
# NOTE(review): this module-level dataset is not used by train_loop, which builds its
# own per-fold TrainDataset instances; presumably kept as a quick wiring sanity check.
train_dataset = TrainDataset(train, transform=get_transforms(data='train'))

In [12]:
## Model ##
class customModel(nn.Module):
    """A timm backbone (with its default head) followed by a linear task head.

    ``fc_layer`` maps the backbone's output logits to the task's classes.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: whether timm loads pretrained weights.
        num_classes: output width of ``fc_layer``; defaults to ``CFG.num_classes``.
    """

    def __init__(self, model_name, pretrained=True, num_classes=None):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Read the backbone's output width instead of hard-coding 1000, so
        # backbones whose default head is not 1000-wide also work.
        backbone_out = getattr(self.model, 'num_classes', 1000)
        if num_classes is None:
            num_classes = CFG.num_classes
        self.fc_layer = nn.Linear(backbone_out, num_classes)

    def forward(self, x):
        return self.fc_layer(self.model(x))

In [13]:
## util functions ##
class AverageMeter:
    """Tracks the latest reported value and a sample-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running average, sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
def get_score(y_true, y_pred):
    """Return the classification accuracy of ``y_pred`` against ``y_true``.

    Errors raised by ``accuracy_score`` (e.g. length or label-type mismatches)
    propagate to the caller unchanged.
    """
    # No try/except: sklearn's ValueError already describes the problem, and
    # calling sys.exit() from a notebook would terminate the kernel.
    return accuracy_score(y_true, y_pred)

def init_logger(log_file=OUTPUT_DIR+'train.log'):
    """Configure and return this module's logger.

    Messages are written, unformatted, to the console (StreamHandler) and to
    ``log_file``. It is safe to call repeatedly, e.g. when the notebook cell is
    re-run: handlers attached by an earlier call are closed and removed first,
    so each message is emitted only once.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Drop handlers from a previous call; otherwise every line is logged N times
    # and the old FileHandler's file descriptor leaks.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = Formatter("%(message)s")
    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

LOGGER = init_logger()

In [14]:
## Train Function ##
def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Run one training epoch and return the sample-weighted mean training loss.

    Args:
        train_loader: iterable of ``(images, labels)`` batches that supports ``len()``.
        model: module to train; it is switched to train mode here.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: stepped once per batch, after gradient clipping.
        epoch: zero-based epoch index, used only in progress messages.
        scheduler: unused here; the LR scheduler is stepped by the caller.
        device: device that batch tensors are moved to.
    """
    losses = AverageMeter()
    n_steps = len(train_loader)

    model.train()
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        y_preds = model(images)
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        loss.backward()
        # Clips in place; the returned total norm is not needed.
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

        if step % CFG.print_freq == 0 or step == n_steps - 1:
            print('Epoch: [{0}][{1}/{2}] '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(epoch + 1, step, n_steps, loss=losses))
    return losses.avg

In [15]:
## Validation Function ##
def valid_fn(valid_loader, model, criterion, device):
    """Evaluate ``model`` on ``valid_loader`` without tracking gradients.

    Returns:
        ``(avg_loss, predictions)``: the sample-weighted mean loss, and a NumPy
        array of per-sample softmax probabilities with rows in loader order.
    """
    losses = AverageMeter()
    preds = []
    n_steps = len(valid_loader)

    model.eval()
    # Forward pass, loss and softmax all run under no_grad; nothing here needs autograd.
    with torch.no_grad():
        for step, (images, labels) in enumerate(valid_loader):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = labels.size(0)

            y_preds = model(images)
            loss = criterion(y_preds, labels)
            losses.update(loss.item(), batch_size)
            preds.append(y_preds.softmax(1).to('cpu').numpy())

            if step % CFG.print_freq == 0 or step == n_steps - 1:
                print('EVAL: [{0}/{1}] '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      .format(step, n_steps, loss=losses))
    predictions = np.concatenate(preds)
    return losses.avg, predictions

In [16]:
def _make_fold_loaders(folds, fold):
    """Split ``folds`` into train/valid frames for ``fold`` and wrap them in DataLoaders.

    Returns:
        ``(valid_folds, train_loader, valid_loader)``. ``valid_folds`` has a fresh
        RangeIndex, so its row order matches the unshuffled validation loader.
    """
    is_valid = folds['fold'] == fold
    train_folds = folds[~is_valid].reset_index(drop=True)
    valid_folds = folds[is_valid].reset_index(drop=True)

    train_dataset = TrainDataset(train_folds,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds,
                                 transform=get_transforms(data='valid'))

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    return valid_folds, train_loader, valid_loader


def train_loop(folds, fold):
    """Train one CV fold, checkpoint its best-accuracy epoch and return OOF predictions.

    The best epoch's weights and validation probabilities are saved to
    ``OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best.pth'``.

    Returns:
        The validation rows of ``fold``, with a ``pred`` column holding the best
        epoch's predicted class. The column is omitted if no epoch ran.
    """
    LOGGER.info(f"========== fold: {fold} training ==========")

    valid_folds, train_loader, valid_loader = _make_fold_loaders(folds, fold)
    # The ground truth is fixed for the fold, so read it once instead of every epoch.
    valid_labels = valid_folds[CFG.target_col].values

    # model & scheduler
    model = customModel(CFG.model_name, pretrained=True)
    model.to(device)

    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    criterion = nn.CrossEntropyLoss()

    # Start below any achievable accuracy so the first epoch is always checkpointed.
    # Otherwise an all-zero-accuracy run would leave no checkpoint file at all.
    best_score = -1.0
    best_preds = None
    checkpoint_path = OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best.pth'

    for epoch in range(CFG.epochs):
        start_time = time.time()

        avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        avg_val_loss, preds = valid_fn(valid_loader, model, criterion, device)

        scheduler.step()

        score = get_score(valid_labels, preds.argmax(1))
        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Accuracy: {score}')

        if score > best_score:
            best_score = score
            best_preds = preds
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'preds': preds},
                       checkpoint_path)

    # The best predictions are kept in memory, so the checkpoint is not reloaded.
    # NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
    # which may reject the NumPy array stored under 'preds'.
    if best_preds is not None:
        valid_folds['pred'] = best_preds.argmax(1)
    return valid_folds

In [17]:
if __name__ == '__main__':
    # Train every fold listed in CFG.trn_fold, in ascending fold order.
    for fold in range(CFG.n_fold):
        if fold not in CFG.trn_fold:
            continue
        train_loop(folds, fold)