In [None]:
import warnings
warnings.filterwarnings('ignore')

from glob import glob
import pandas as pd
import numpy as np 
from tqdm import tqdm
import cv2

import os
import timm
import random

import albumentations as A
from albumentations.pytorch import transforms, ToTensorV2

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from sklearn.metrics import f1_score, accuracy_score
import json

import wandb

In [None]:
# !pip3 install timm wandb albumentations

In [None]:
# Configs
# Every tunable used by this notebook (paths, model name, hyper-parameters,
# W&B settings) is read from a single JSON file into the `config` dict.
config_path = "./config/efficient_base.json"
# The `with` block closes the file on exit, so no explicit close() is needed.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's default encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU. The
# chosen device is stored in the config so later cells can look it up.
config['DEVICE'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Start a Weights & Biases run for this training job and attach the whole
# config dict to it.
wandb.init(project=config["WANDB_PROJECT"],
           config=config,
           job_type="Train",
           anonymous='must')
# Replace the run's display name with the one given in the config.
wandb.run.name = config['WANDB_NAME']

In [None]:
# Make sure the checkpoint directory exists. `exist_ok=True` already turns the
# call into a no-op for an existing directory, so a separate os.path.exists()
# guard is redundant (and racy). Without the guard, a regular file occupying
# the path is reported here instead of failing later at the first torch.save.
os.makedirs(config["MODEL_SAVE"], exist_ok=True)

In [None]:
def setSeeds(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global generator and PyTorch (through
    both `torch.manual_seed` and `torch.cuda.manual_seed`), puts cuDNN in
    deterministic mode with benchmarking disabled, and exports the seed as
    the PYTHONHASHSEED environment variable.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: favour reproducibility over auto-tuned kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Apply the configured seed before the datasets and the model are created.
setSeeds(config["SEED"])

In [None]:
# Ask PyTorch to release its cached CUDA memory before the datasets and the
# model are built (useful when this notebook is re-run in the same kernel).
torch.cuda.empty_cache()

In [None]:
class CustomDataset(Dataset):
    """Image dataset described by a CSV file.

    The CSV must provide a `path` column with one image file per row. Two
    optional columns are honoured:

    * `kfold` - when present, rows are filtered by fold: `mode="train"` keeps
      every row whose fold differs from `fold`, `mode="validation"` keeps only
      the rows of `fold`; any other mode keeps all rows.
    * `encoded_label` - integer class labels; required for every mode except
      `"test"`.

    Args:
        data_path: Path of the CSV file.
        size: Value handed to the `transform` factory.
        transform: Optional factory; `transform(size)` must return a callable
            that is invoked as `pipeline(image=rgb_array)` and returns a
            mapping with an `'image'` entry.
        fold: Fold number used for the train/validation split.
        mode: `"train"`, `"validation"` or `"test"`.
    """

    def __init__(self,
                 data_path,
                 size,
                 transform=None,
                 fold=0,
                 mode="train"):
        self.csv = pd.read_csv(data_path)
        if 'kfold' in self.csv:
            if mode == "train":
                self.csv = self.csv[self.csv['kfold'] != fold]
            elif mode == "validation":
                self.csv = self.csv[self.csv['kfold'] == fold]

        self.path = self.csv['path'].to_list()
        # Always define `labels`, so a CSV without labels is reported with a
        # clear message in __getitem__ instead of an AttributeError.
        if 'encoded_label' in self.csv:
            self.labels = self.csv['encoded_label'].to_list()
        else:
            self.labels = None
        self.transform = transform
        self.size = size
        self.mode = mode

    def __len__(self):
        """Return the number of samples left after fold filtering."""
        return len(self.path)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            `{'image': image}` in test mode, otherwise
            `{'image': image, 'label': label}` where `label` is a scalar
            `torch.long` tensor.

        Raises:
            ValueError: labels are needed (mode is not "test") but the CSV has
                no `encoded_label` column.
            OSError: the image file could not be read by OpenCV.
        """
        is_test = self.mode == "test"
        # Fail before any image decoding when the labels are missing.
        if not is_test and self.labels is None:
            raise ValueError(
                "CSV has no 'encoded_label' column, which is required when "
                f"mode={self.mode!r}"
            )

        # Image
        image_path = self.path[idx]
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise OSError(
                f"cv2.imread could not read image {image_path!r} "
                "(missing or undecodable file)"
            )
        # OpenCV decodes to BGR; convert to RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            # The factory is called per sample, exactly as before.
            image = self.transform(self.size)(image=image)['image']

        # Only test mode
        if is_test:
            return {
                'image': image
            }

        # Label
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {
            'image': image,
            'label': label
        }

In [None]:
def create_train_transforms(size):
    """Build the augmentation pipeline applied to training images.

    Steps, in order: resize to `size` x `size`, random horizontal/vertical
    flips, mild brightness/contrast jitter, per-channel normalisation, a random
    shift/scale/rotate, small random cut-out holes, and conversion to a
    tensor.

    Args:
        size: Side length (pixels) of the square output image.

    Returns:
        An `A.Compose` pipeline, called as `pipeline(image=array)['image']`.
    """
    return A.Compose([
        A.Resize(size, size),
        A.HorizontalFlip(p=0.4),
        A.VerticalFlip(p=0.4),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.3
        ),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        A.ShiftScaleRotate(
            p=0.4,
            shift_limit=(-0.05, 0.05),
            scale_limit=(-0.3, 0.05),
            # FIX: the range used to be (-90, -90), a degenerate interval that
            # always rotated by exactly -90 degrees. A symmetric range gives
            # the random rotation that the other limits suggest was intended.
            rotate_limit=(-90, 90),
            # Named constants replace the bare value 4 used before; both
            # constants equal 4, so the behaviour is unchanged.
            interpolation=cv2.INTER_LANCZOS4,
            border_mode=cv2.BORDER_REFLECT_101,
        ),
        # NOTE(review): Cutout is deprecated upstream in favour of
        # CoarseDropout - consider migrating when upgrading albumentations.
        A.Cutout(p=0.3,
               num_holes=15,
               max_h_size=8,
               max_w_size=8
        ),
        ToTensorV2()
    ])

In [None]:
def create_validation_transforms(size):
    """Build the augmentation-free preprocessing pipeline for validation.

    The image is resized to `size` x `size`, normalised per channel with the
    same mean/std as the training pipeline, and converted to a tensor.

    Args:
        size: Side length (pixels) of the square output image.

    Returns:
        An `A.Compose` pipeline, called as `pipeline(image=array)['image']`.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        A.Resize(size, size),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
# Both datasets read the same CSV; the configured fold decides which rows are
# held out for validation and which are used for training.
train_dataset = CustomDataset(
    config["TRAIN_CSV"],
    config["SIZE"],
    transform=create_train_transforms,
    fold=config["FOLD"],
    mode="train",
)
validation_dataset = CustomDataset(
    config["TRAIN_CSV"],
    config["SIZE"],
    transform=create_validation_transforms,
    fold=config["FOLD"],
    mode="validation",
)

# Only the training loader shuffles; validation keeps the CSV order.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=True,
    num_workers=config["N_WORKERS"],
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=False,
    num_workers=config["N_WORKERS"],
)

In [None]:
def load_model(model_name = None, pretrained = True, num_classes = 88):
    """Create a timm model with a classification head of `num_classes` outputs.

    Args:
        model_name: Architecture name passed to `timm.create_model`.
        pretrained: Forwarded to timm as the `pretrained` flag.
        num_classes: Number of output classes.

    Returns:
        Whatever `timm.create_model` returns for these arguments.
    """
    model_kwargs = {"pretrained": pretrained, "num_classes": num_classes}
    return timm.create_model(model_name, **model_kwargs)

In [None]:
# Build the network described in the config and move it to the chosen device.
model = load_model(config["MODEL"], config["PRETRAINED"], config["N_CLASSES"]).to(config["DEVICE"])

In [None]:
class FocalLoss(nn.Module):
    """Focal loss built on element-wise binary cross-entropy with logits.

    Each element's BCE value is weighted by `alpha * (1 - pt) ** gamma`, where
    `pt = exp(-BCE)`.

    Args:
        alpha: Constant multiplier applied to every element.
        gamma: Exponent of the `(1 - pt)` modulating factor.
        reduce: If true, `forward` returns the mean over all elements;
            otherwise it returns the element-wise loss tensor.
        smooth: Label-smoothing amount; 0 disables smoothing.
    """

    def __init__(self, alpha=1, gamma=2, reduce=True, smooth=0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.reduce, self.smooth = reduce, smooth

    def forward(self, inputs, targets):
        """Compute the focal loss of raw logits `inputs` against `targets`.

        `targets` must have the same shape as `inputs`. When smoothing is
        enabled, the targets are blended towards `smooth / inputs.size(1)`.
        """
        if self.smooth != 0:
            n_columns = inputs.size(1)
            targets = (1 - self.smooth) * targets + self.smooth / n_columns

        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        # exp(-BCE) gives pt; 1 - pt is large for poorly predicted elements.
        pt = torch.exp(-bce)
        weighted = self.alpha * (1 - pt).pow(self.gamma) * bce

        return weighted.mean() if self.reduce else weighted

In [None]:
# The training objective is the sum of two terms: cross-entropy on the class
# indices and a focal loss on one-hot targets.
criterion = nn.CrossEntropyLoss()
criterion2 = FocalLoss()

# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=config["LEARNING_RATE"])

# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
def score_function(real, pred):
    """Return the macro-averaged F1 score of `pred` against `real`.

    Args:
        real: Ground-truth class labels.
        pred: Predicted class labels, aligned with `real`.
    """
    return f1_score(y_true=real, y_pred=pred, average="macro")

In [None]:
# Resume bookkeeping: `best_loss` / `best_pred` track the best validation
# result so far, and `pre_epoch` becomes the first epoch index to run.
best_loss, best_pred = float('inf'), 0
pre_epoch = -1
if config["RESUME"]:
    last_model = f"{config['MODEL_SAVE_PREFIX']}_last.pth"
    best_model = f"{config['MODEL_SAVE_PREFIX']}_best.pth"
    last_model_path = os.path.join(config['MODEL_SAVE'], last_model)
    best_model_path = os.path.join(config['MODEL_SAVE'], best_model)

    # FIX: test the full checkpoint paths. The bare file names were checked
    # before, so resuming silently did nothing unless the notebook happened to
    # run from inside the checkpoint directory.
    if os.path.exists(last_model_path) and os.path.exists(best_model_path):
        # NOTE: torch.load unpickles its input - only load checkpoints written
        # by this notebook or another trusted source.
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        model_data = torch.load(best_model_path, map_location=config["DEVICE"])
        best_pred = model_data['score']
        best_loss = model_data['loss']

        model_data = torch.load(last_model_path, map_location=config["DEVICE"])
        pre_epoch = model_data['epoch']
        model.load_state_dict(model_data['state_dict'])
    else:
        # Same fallback as before (train from scratch), but no longer silent.
        print(f"RESUME is set but no checkpoints were found in "
              f"{config['MODEL_SAVE']!r}; starting from epoch 0.")
# Continue with the epoch after the last completed one (0 on a fresh start).
pre_epoch += 1

In [None]:
wandb.watch(model)


def _forward_batch(batch):
    """Run the model on one batch and return `(logits, labels, loss)`.

    The loss is cross-entropy on the class indices plus focal loss on one-hot
    targets. It is computed inside the autocast region so that the loss ops are
    handled by autocast rather than run on possibly half-precision logits.
    """
    # `.to()` moves/casts the tensors coming out of the DataLoader without
    # re-wrapping them through torch.tensor().
    images = batch['image'].to(config["DEVICE"], dtype=torch.float32)
    labels = batch['label'].to(config["DEVICE"], dtype=torch.long)
    with torch.cuda.amp.autocast():
        logits = model(images)
        # Vectorised one-hot targets. This replaces a per-sample Python loop
        # that also shadowed the input batch variable.
        one_hot = F.one_hot(labels, num_classes=config["N_CLASSES"]).float()
        batch_loss = criterion(logits, labels) + criterion2(logits, one_hot)
    return logits, labels, batch_loss


def _checkpoint_path(suffix):
    """Return the checkpoint file path for `suffix` ("best" or "last")."""
    return os.path.join(config['MODEL_SAVE'],
                        f"{config['MODEL_SAVE_PREFIX']}_{suffix}.pth")


for epoch in range(pre_epoch, pre_epoch + config["EPOCHS"]):
    # ---------------------------- training ----------------------------
    train_loss = 0          # running mean of the batch losses seen so far
    total_train_loss = 0    # sum of loss / n_batches; equals the epoch mean at the end
    train_data_cnt = 0
    train_pred = []
    train_y = []
    model.train()
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, batch in pbar:
        optimizer.zero_grad()
        pred, y, loss = _forward_batch(batch)

        # Scaled backward pass and optimizer step for mixed precision.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        total_train_loss += loss_value / len(train_loader)
        train_loss = train_loss * train_data_cnt + loss_value
        train_data_cnt += 1
        train_loss /= train_data_cnt
        train_pred += pred.argmax(1).tolist()
        train_y += y.tolist()
        pbar.set_postfix({
            "epoch": f"{epoch}/{pre_epoch + config['EPOCHS']}",
            "train_loss" : f"{train_loss:.5f}",
            "total_train_loss": f"{total_train_loss:.5f}"
        })
        if step % 5 == 0:
            wandb.log({
                'train_loss': train_loss
            })

    pbar.close()
    train_f1 = score_function(train_y, train_pred)
    wandb.log({
        'train_score': train_f1
    })
    print(f'TRAIN\tf1 : {train_f1:.5f}')

    # --------------------------- validation ---------------------------
    validation_loss = 0
    total_validation_loss = 0
    validation_data_cnt = 0
    validation_pred = []
    validation_y = []
    model.eval()
    pbar = tqdm(enumerate(validation_loader), total=len(validation_loader))
    with torch.no_grad():
        for step, batch in pbar:
            pred, y, loss = _forward_batch(batch)

            loss_value = loss.item()
            total_validation_loss += loss_value / len(validation_loader)
            validation_loss = validation_loss * validation_data_cnt + loss_value
            validation_data_cnt += 1
            validation_loss /= validation_data_cnt
            validation_pred += pred.argmax(1).tolist()
            validation_y += y.tolist()
            pbar.set_postfix({
                "epoch": f"{epoch}/{pre_epoch + config['EPOCHS']}",
                "val_loss" : f"{validation_loss:.5f}",
                "total_val_loss": f"{total_validation_loss:.5f}"
            })
            if step % 5 == 0:
                wandb.log({
                    'val_loss': validation_loss
                })

        pbar.close()
    val_f1 = score_function(validation_y, validation_pred)
    wandb.log({
        'val_score': val_f1
    })
    print(f'VAL\tf1 : {val_f1:.5f}')

    # ---------------------------- checkpoints -------------------------
    # The same payload is written to both files, so build it once.
    checkpoint = {
        "epoch": epoch,
        "loss": validation_loss,
        "score": val_f1,
        "state_dict": model.state_dict()
    }

    # Update the best checkpoint whenever the validation loss improves.
    if best_loss > validation_loss:
        best_loss = validation_loss
        torch.save(checkpoint, _checkpoint_path("best"))

    # Always refresh the "last" checkpoint so training can be resumed.
    torch.save(checkpoint, _checkpoint_path("last"))

In [None]:
# Close the Weights & Biases run now that training is complete.
wandb.finish()