In [None]:
import gc
import os
import cv2
import sys
import json
import time
import timm
import torch
import random
import sklearn.metrics

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch.nn as nn

from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2

# Expose only the GPU with index 1 to CUDA for this process.
# NOTE(review): the index is hard-coded; presumably on a host with a single
# GPU no CUDA device would be visible and `device` would silently become
# the CPU — confirm for the target machine.
os.environ["CUDA_VISIBLE_DEVICES"]="1"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: shown as the cell output.
device

In [None]:
# Run the `nvidia-smi` shell command and show its output in the notebook.
!nvidia-smi

In [None]:
# Read the training metadata table (one row per image).
metadata = pd.read_csv("../metadata/SnakeCLEF2021_train_metadata_PROD.csv")

print(metadata.shape[0])

# Change your "image_path"!
# Every stored path is re-rooted from the dataset's original prefix to the
# location of the local copy.
metadata["image_path"] = metadata["image_path"].map(
    lambda path: path.replace("/Datasets/SnakeCLEF-2021/", '/local/nahouby/Datasets/SnakeCLEF2021/')
)
# Preview the first five rows as the cell output.
metadata.head()

In [None]:
# The whole metadata table is used for training; no rows are held out here.
# Note this binds a second name to the same DataFrame object, not a copy.
train_metadata = metadata
print(len(train_metadata))

In [None]:
@contextmanager
def timer(name):
    """Log the start and the wall-clock duration of the enclosed block.

    Args:
        name: Label placed in square brackets in both log messages.

    The messages go to the module-level ``LOGGER`` at INFO level. The
    "done" message is only emitted when the block finishes without raising.
    """
    started_at = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    elapsed = time.time() - started_at
    LOGGER.info(f'[{name}] done in {elapsed:.0f} s.')

    
def init_logger(log_file='train.log'):
    """Configure and return the shared logger writing to console and to a file.

    The function is safe to call repeatedly (e.g. when the notebook cell is
    re-run): handlers installed by an earlier call are removed first, so each
    record is emitted exactly once per destination.

    Args:
        log_file: Path of the log file. Missing parent directories are created.

    Returns:
        The ``logging.Logger`` named ``'Herbarium'`` at DEBUG level with one
        stream handler and one file handler attached.
    """
    import os
    from logging import getLogger, DEBUG, FileHandler, Formatter, StreamHandler

    log_format = '%(asctime)s %(levelname)s %(message)s'

    logger = getLogger('Herbarium')
    logger.setLevel(DEBUG)

    # getLogger returns the same object for the same name, so handlers from a
    # previous call would otherwise pile up and duplicate every log line.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # FileHandler does not create directories; do it here so a fresh checkout
    # without the log directory does not crash.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    stream_handler = StreamHandler()
    stream_handler.setLevel(DEBUG)
    stream_handler.setFormatter(Formatter(log_format))

    file_handler = FileHandler(log_file)
    file_handler.setFormatter(Formatter(log_format))

    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)

    return logger

# File that receives a copy of every log record of this run.
LOG_FILE = '../logs/SnakeCLEF2021-ViT_large_patch16-384-FocalLoss-FT.log'
# Module-level logger; `timer` and the training loop log through this name.
LOGGER = init_logger(LOG_FILE)


def seed_torch(seed=777):
    """Seed every random number generator the training run relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and torch
    (CPU and all CUDA devices), sets ``PYTHONHASHSEED`` and puts cuDNN into
    deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not only the current one; ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run, so it
    # is switched off explicitly in case something enabled it earlier.
    torch.backends.cudnn.benchmark = False

SEED = 777
seed_torch(SEED)

In [None]:
# Output size of the replacement classifier head built by `getModel`.
N_CLASSES = 772

class CustomDataset(Dataset):
    """Dataset that loads one image per row of a metadata DataFrame.

    The DataFrame must provide the columns ``image_path``, ``continent``,
    ``country`` and ``class_id``.

    Args:
        df: Metadata table; rows are addressed by position.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping whose ``'image'`` entry is the transformed image.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the metadata table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            Tuple ``(image, file_path, class_id, country, continent)``; the
            image is converted from BGR to RGB and, when a transform is set,
            replaced by the transform's ``'image'`` output.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        file_path = self.df['image_path'].values[idx]
        continent = self.df['continent'].values[idx]
        country = self.df['country'].values[idx]
        class_id = self.df['class_id'].values[idx]

        image = cv2.imread(file_path)
        # cv2.imread reports a missing/corrupt file by returning None instead
        # of raising; fail here with the path rather than letting a None image
        # reach the transform or the collate function.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, file_path, class_id, country, continent

In [None]:
def getModel(architecture_name, target_size, pretrained = False):
    """Create a timm model and swap its classifier for a new linear layer.

    Args:
        architecture_name: Model name passed to ``timm.create_model``.
        target_size: Number of outputs of the new classification layer.
        pretrained: Forwarded to ``timm.create_model``.

    Returns:
        The created model whose classifier attribute (named by
        ``default_cfg['classifier']``) is a freshly initialised ``nn.Linear``
        with the same input width as the layer it replaces.
    """
    backbone = timm.create_model(architecture_name, pretrained=pretrained)
    head_name = backbone.default_cfg['classifier']
    # Keep the input width of the existing head so the new layer fits.
    in_features = getattr(backbone, head_name).in_features
    new_head = nn.Linear(in_features, target_size)
    setattr(backbone, head_name, new_head)
    return backbone

In [None]:
# %%
MODEL_NAME = 'vit_large_patch16_384'
# NOTE(review): the pretrained weights requested here are fully replaced by
# the checkpoint loaded below; `pretrained=True` is kept as in the original.
model = getModel(MODEL_NAME, N_CLASSES, pretrained=True)
# Normalisation statistics the architecture was configured with; used by the
# transform pipelines.
model_mean = list(model.default_cfg['mean'])
model_std = list(model.default_cfg['std'])

# Load the tensors onto the CPU regardless of the device they were saved
# from, so the checkpoint also opens on hosts without that GPU. The model is
# moved to `device` later, before training.
model.load_state_dict(torch.load('SnakeCLEF2021-ViT_large_patch16-384-50E.pth', map_location='cpu'))

In [None]:
# Spatial size (in pixels) that `get_transforms` crops/resizes every image to.
HEIGHT = 384
WIDTH = 384

from albumentations import HueSaturationValue, RandomCrop, HorizontalFlip, VerticalFlip, RandomBrightnessContrast, CenterCrop, PadIfNeeded, RandomResizedCrop, ShiftScaleRotate, Blur, JpegCompression, RandomShadow

def get_transforms(*, data):
    """Build the albumentations pipeline for the given dataset split.

    Args:
        data: ``'train'`` for the augmenting pipeline or ``'valid'`` for the
            deterministic resize-only pipeline.

    Returns:
        A ``Compose`` that ends with normalisation (using ``model_mean`` /
        ``model_std``) and conversion to a tensor.

    Raises:
        AssertionError: If ``data`` is neither ``'train'`` nor ``'valid'``.
    """
    assert data in ('train', 'valid')

    if data == 'train':
        return Compose([
            # Size arguments are (height, width), in that order.
            RandomResizedCrop(HEIGHT, WIDTH, scale=(0.7, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.25, rotate_limit=45, p=.75),
            JpegCompression(quality_lower=50, quality_upper=100),
            #RandomShadow(),
            Blur(blur_limit=2),
            RandomBrightnessContrast(p=0.3),
            HueSaturationValue(p=0.2),
            Normalize(
                mean=model_mean,
                std=model_std,
            ),
            ToTensorV2(),
        ])

    elif data == 'valid':
        return Compose([
            # Size arguments are (height, width), in that order.
            Resize(HEIGHT, WIDTH),
            Normalize(mean = model_mean, std = model_std),
            ToTensorV2(),
        ])

In [None]:
# Adjust BATCH_SIZE and ACCUMULATION_STEPS so that their product (the
# effective batch size) is 64.
# NOTE(review): the current values give 4 * 64 = 256, not 64 — confirm which
# of the comment or the values reflects the intended setup.
BATCH_SIZE = 4
ACCUMULATION_STEPS = 64
EPOCHS = 20
WORKERS = 16

# The training set uses the augmenting 'train' transform pipeline.
train_dataset = CustomDataset(train_metadata, transform=get_transforms(data='train'))
# shuffle=True reorders the samples; num_workers sets the loader worker count.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss computed from raw logits.

    Each sample's cross-entropy ``ce`` is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t = exp(-ce)`` is the probability assigned to the true class, so
    well-classified samples contribute less to the total.

    Args:
        weight: Optional 1-D tensor with one weight per class (the "alpha"
            term); each sample's loss is multiplied by the weight of its
            target class.
        gamma: Focusing exponent; ``gamma=0`` reduces to cross-entropy.
        reduction: ``'mean'`` (plain mean over samples), ``'sum'`` or
            ``'none'`` (per-sample losses).
    """

    def __init__(self, weight=None, gamma=2, reduction='mean'):
        # The parent class registers `weight` as a buffer, so it follows the
        # module across devices; no separate assignment is needed.
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma

    def forward(self, logits, target):
        """Return the focal loss of ``logits`` against class indices ``target``."""
        # Per-sample, unweighted cross-entropy: reducing before the focal
        # factor would modulate one batch-level scalar instead of each sample,
        # and class weights must not leak into the p_t estimate.
        ce_loss = F.cross_entropy(logits, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.weight is not None:
            focal_loss = focal_loss * self.weight[target]

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score
import tqdm
from torch.utils.tensorboard import SummaryWriter

# Fine-tune `model` with gradient accumulation: gradients of
# `accumulation_steps` consecutive mini-batches are summed before each
# optimizer update, and the learning rate follows a one-cycle schedule that
# advances once per optimizer update.
with timer('Train model'):
    accumulation_steps = ACCUMULATION_STEPS
    n_epochs = EPOCHS
    lr = 0.0001

    model.to(device)

    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
    # The schedule is sized in optimizer updates, not mini-batches. The floor
    # division equals the number of times the update condition below fires in
    # one epoch, so the scheduler is never stepped past its total step count.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.01,
        steps_per_epoch=len(train_loader) // accumulation_steps,
        epochs=n_epochs,
    )

    criterion = FocalLoss()
    best_score = 0.
    best_loss = np.inf
    best_f1 = 0.

    for epoch in range(n_epochs):

        start_time = time.time()

        model.train()
        avg_loss = 0.

        # Also discards gradients of a trailing, incomplete accumulation
        # window left over from the previous epoch.
        optimizer.zero_grad()

        # Wrapping the loader (not the enumerate object) lets tqdm read its
        # length and show a proper progress bar.
        for i, (images, _, labels, _, _) in enumerate(tqdm.tqdm(train_loader)):

            images = images.to(device)
            labels = labels.to(device)

            y_preds = model(images)
            loss = criterion(y_preds, labels)

            # Running mean of the unscaled per-batch loss over the epoch.
            avg_loss += loss.item() / len(train_loader)
            # Scale so the accumulated gradient is the mean over the window.
            loss = loss / accumulation_steps
            loss.backward()

            # Update once a full window of `accumulation_steps` batches has
            # been accumulated (i is zero-based, hence i + 1).
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

                scheduler.step()

        elapsed = time.time() - start_time
        LOGGER.debug(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} time: {elapsed:.0f}s')

In [None]:
# Persist only the model's weights (its state_dict) to a file in the current
# working directory.
torch.save(model.state_dict(), f'SnakeCLEF2021-ViT_large_patch16-384-FT-FL-OCLR-20E.pth')