In [1]:
import logging
import os
import random
from glob import glob

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from easydict import EasyDict
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Custom Dataset

In [2]:
class CustomDataset(Dataset):
    """Image dataset that reads files with OpenCV and yields RGB images.

    In ``"train"`` mode each item is ``(image, label_tensor)`` with integer-encoded
    labels; in ``"test"`` mode each item is just the image.

    Args:
        img_list: image file paths.
        label_list: raw labels, one per image (required in ``"train"`` mode).
        transforms: optional callable invoked as ``transforms(image=img)['image']``.
        mode: ``"train"`` or ``"test"``.
        class_to_idx: optional ``{raw_label: int}`` mapping used to encode labels.
            When omitted, each dataset builds its own mapping from the sorted unique
            labels it was given, so two datasets (e.g. train/valid splits) only agree
            on label ids if they contain exactly the same label set. Pass one shared
            mapping to make the encoding independent of the split.

    Raises:
        ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``.
    """
    def __init__(self, img_list, label_list=None, transforms=None, mode="train", class_to_idx=None) :
        if mode not in ("train", "test"):
            raise ValueError(f'mode must be "train" or "test", got {mode!r}')
        self.img_list = img_list

        if mode == "train" :
            if class_to_idx is None:
                self.label_list = self.label_encoder(label_list)
            else:
                self.label_list = [class_to_idx[label] for label in label_list]

        self.transforms = transforms
        self.mode = mode

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transforms``, and attach the label in train mode.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        img_path = self.img_list[idx]

        img = cv2.imread(img_path)
        # cv2.imread reports a missing/unreadable/undecodable file by returning None
        # instead of raising; fail loudly with the offending path.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.mode == "train" :
            return img, torch.tensor(self.label_list[idx])
        return img

    def label_encoder(self, label_list) :
        """Encode each label as its index among the sorted unique labels of ``label_list``."""
        label_enc = {k : i for i, k in enumerate(sorted(set(label_list)))}
        return [label_enc[label] for label in label_list]

#### test code

In [3]:
# df = pd.read_csv('../data/aug_train_df.csv')
# transforms = A.Compose([
#     A.Resize(224,224),
#     A.Normalize(),
#     A.Rotate(),
#     ToTensorV2()
# ])
# db = CustomDataset(list(df['file_name']), list(df['label']), transforms, mode="train")
# db_loader = DataLoader(db, batch_size=16, shuffle=True)
# for img, label in db_loader : 
#     print(img.shape)
#     print(label.shape)
#     print(label)
#     break

# Focal Loss

In [4]:
class FocalLoss(nn.Module) :
    """Multi-class focal loss on raw logits.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * CE``, where ``CE`` is the
    unreduced cross-entropy and ``p_t = exp(-CE)`` is the predicted probability of
    the target class.

    Args:
        alpha: constant multiplier applied to every sample's loss.
        gamma: focusing exponent; ``gamma=0`` gives ``alpha * CE``.
        logits: unused; kept so existing call sites keep working. ``inputs`` are
            always treated as unnormalised logits.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'`` (per-sample losses).
            The previous default ``'none'`` actually returned the mean, so
            ``FocalLoss()`` behaves exactly as before.

    Raises:
        ValueError: for any other ``reduction``.
    """
    def __init__(self, alpha=2, gamma=2, logits=False, reduction='mean') :
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction

    def forward(self, inputs, targets) :
        """Return the focal loss of ``inputs`` (logits) against ``targets``.

        ``targets`` are class indices; integer dtypes other than int64 are
        accepted. Floating-point targets (class probabilities) are passed
        through unchanged.
        """
        if not targets.is_floating_point():
            # cross_entropy requires int64 class indices.
            targets = targets.long()
        # CE must stay per-sample: the focal weight (1 - p_t) ** gamma has to be
        # applied before reducing, not computed from an already-reduced CE.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# Custom SwinTransformer

In [5]:
class BackBone(nn.Module) :
    """Thin wrapper that holds a timm classification model under ``self.model``.

    The attribute name is significant: this module's ``state_dict`` keys are
    prefixed with ``model.``, and checkpoints loaded into it must match.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        backbone_output: ``num_classes`` of the timm model's classifier.
    """
    def __init__(self, model_name, backbone_output) :
        super().__init__()
        # NOTE(review): pretrained=True may make timm fetch weights on first use.
        self.model = timm.create_model(
            model_name=model_name,
            pretrained=True,
            num_classes=backbone_output,
        )

    def forward(self, x) :
        """Return the wrapped model's output for ``x``."""
        return self.model(x)
    
class MLP(nn.Module) :
    """Two-layer classification head: Linear -> GELU -> Dropout -> Linear.

    Original author's note: the input comes from the backbone's
    ``forward_features``, which has already passed through the final LayerNorm,
    so no extra LayerNorm or AdaptiveAvgPool1d is applied here.
    NOTE(review): this assumes ``forward_features`` returns pooled
    ``(..., in_features)`` features; that depends on the timm version — confirm.

    Args:
        in_features: size of the last input dimension; the hidden layer uses half of it.
        dropout_rate: dropout probability between the two linear layers.
        num_state: number of output logits.
    """
    def __init__(self, in_features, dropout_rate, num_state) :
        super().__init__()
        hidden_features = in_features // 2
        # Attribute names form the state_dict keys; keep them stable for checkpoints.
        self.linear_1 = nn.Linear(in_features, hidden_features, bias=True)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(p=dropout_rate, inplace=False)
        self.linear_2 = nn.Linear(hidden_features, num_state, bias=True)

    def forward(self, x):
        """Map features ``x`` to ``num_state`` logits."""
        hidden = self.dropout(self.gelu(self.linear_1(x)))
        return self.linear_2(hidden)
    
class CustomSwinTransformer(nn.Module) :
    """Frozen backbone(s) feeding one MLP head per attribute group.

    Each backbone is a ``BackBone`` checkpoint loaded from disk and frozen. With
    ``ensemble_backbone`` the ``forward_features`` outputs of all checkpoints in
    ``ensemble_model_list`` are averaged; otherwise the single checkpoint at
    ``model_path`` is used. The features are fed to ``num_class`` MLP heads, and
    head ``i`` outputs ``num_state[i] + 1`` logits.

    Args:
        model_path: checkpoint for single-backbone mode (unused when ensembling).
        model_name: timm model name used to rebuild each backbone.
        backbone_output: ``num_classes`` the backbone checkpoints were built with.
        num_class: number of MLP heads.
        num_state: real classes per head; needs at least ``num_class`` entries.
        label_decoder: stored on the instance as-is; not used by this class.
        ensemble_backbone: average several backbones instead of using one.
        ensemble_model_list: checkpoint paths used when ``ensemble_backbone``.
        dropout_rate: dropout inside each MLP head.
        in_features: size of the backbone feature vector fed to the heads
            (default 1024, the value previously hard-coded).
    """
    def __init__(self,
                 model_path,
                 model_name,
                 backbone_output,
                 num_class,
                 num_state,
                 label_decoder,
                 ensemble_backbone,
                 ensemble_model_list,
                 dropout_rate=0.5,
                 in_features=1024) :
        super().__init__()
        self.label_decoder = label_decoder
        self.ensemble_backbone = ensemble_backbone
        # Legacy misspelt attribute, kept for any code that still reads it.
        self.ensembel_backbone = ensemble_backbone

        if self.ensemble_backbone :
            self.backbones = self.get_ensemble_backbone(ensemble_model_list,
                                                        model_name,
                                                        backbone_output)
        else :
            self.backbone = self.get_backbone(model_path,
                                              model_name,
                                              backbone_output)

        # One extra output per head (num_state + 1): a "None" class for samples
        # whose label belongs to a different head.
        self.mlps = nn.ModuleList([
            MLP(in_features=in_features,
                dropout_rate=dropout_rate,
                num_state=num_state[i] + 1)
            for i in range(num_class)
        ])

    def forward(self, x) :
        """Return a list with one logits tensor per head (``num_state[i] + 1`` columns each)."""
        if self.ensemble_backbone :
            # Mean of the per-backbone features.
            feature_map = torch.stack(
                [backbone.forward_features(x) for backbone in self.backbones]
            ).mean(dim=0)
        else :
            feature_map = self.backbone.forward_features(x)
        return [mlp(feature_map) for mlp in self.mlps]

    def WeightFreeze(self, model) :
        """Set ``requires_grad = False`` on every parameter of ``model`` and return it."""
        # parameters() also covers parameters registered directly on `model`,
        # which iterating over model.children() skipped.
        for param in model.parameters() :
            param.requires_grad = False
        return model

    def get_backbone(self, model_path, model_name, backbone_output) :
        """Rebuild a ``BackBone``, load ``checkpoint["model_state_dict"]``, freeze it.

        Returns the frozen inner timm model (``BackBone.model``), not the wrapper.
        """
        # map_location="cpu": load regardless of the device the checkpoint was
        # saved from, and avoid materialising any extra saved state on a GPU; the
        # assembled model is moved to its device by the caller.
        checkpoint = torch.load(model_path, map_location="cpu")
        backbone = BackBone(model_name, backbone_output)
        backbone.load_state_dict(checkpoint["model_state_dict"])
        return self.WeightFreeze(backbone.model)

    def get_ensemble_backbone(self, ensemble_model_list,  model_name, backbone_output) :
        """Load and freeze one backbone per checkpoint path; return them as an ``nn.ModuleList``."""
        return nn.ModuleList([
            self.get_backbone(model_path, model_name, backbone_output)
            for model_path in ensemble_model_list
        ])
    


# Training

In [6]:
def rand_bbox(size, lam):
    """Sample a random CutMix box covering roughly ``1 - lam`` of the spatial area.

    Args:
        size: tensor shape ``(N, C, D2, D3)``; the box spans dims 2 and 3.
        lam: mixing ratio in [0, 1]; each box side is ``sqrt(1 - lam)`` of its dim.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: ``bbx1:bbx2`` indexes dim 2 and
        ``bby1:bby2`` indexes dim 3. Both are clipped to the image, so the box
        may be smaller than requested when its centre is near an edge.
    """
    # NOTE: dim 2 is called W here although NCHW tensors keep height there; the
    # naming is swapped but the returned indices are used consistently by dim.
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre, uniform over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def cutmix(imgs, labels):
    """CutMix a batch: paste a random box from a shuffled copy of the batch.

    ``imgs`` is modified in place and also returned.

    Args:
        imgs: image batch; the box is cut over dims 2 and 3 (see ``rand_bbox``).
        labels: per-sample targets indexable by a permutation of the batch.

    Returns:
        ``(imgs, target_a, target_b, lam)`` where ``target_a`` are the original
        labels, ``target_b`` the labels of the pasted samples, and ``lam`` (float)
        the fraction of each image that still comes from ``target_a``.
    """
    lam = np.random.beta(1.0, 1.0)
    # Permute on the batch's own device (previously hard-coded to .cuda()).
    rand_index = torch.randperm(imgs.size(0), device=imgs.device)
    target_a = labels
    target_b = labels[rand_index.to(labels.device)]
    bbx1, bby1, bbx2, bby2 = rand_bbox(imgs.size(), lam)
    imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]

    # Recompute lambda from the clipped box so it matches the pixel ratio exactly.
    lam = float(1 - ((bbx2 - bbx1) * (bby2 - bby1) / (imgs.size()[-1] * imgs.size()[-2])))

    return imgs, target_a, target_b, lam

def mixup(imgs, labels, alpha=1.0) :
    """Mixup a batch with a shuffled copy of itself.

    Args:
        imgs: image batch (first dim is the batch).
        labels: per-sample targets.
        alpha: Beta(alpha, alpha) concentration for the mixing ratio.

    Returns:
        ``(mixed_imgs, lam, target_a, target_b, rand_index)`` with
        ``mixed_imgs = lam * imgs + (1 - lam) * imgs[rand_index]``,
        ``target_a = labels`` and ``target_b = labels[rand_index]``.
    """
    lam = float(np.random.beta(alpha, alpha))
    # Permute on the batch's own device (previously hard-coded to .cuda()).
    rand_index = torch.randperm(imgs.size(0), device=imgs.device)
    mixed_imgs = lam * imgs + (1 - lam) * imgs[rand_index, :]
    target_a, target_b = labels, labels[rand_index.to(labels.device)]

    return mixed_imgs, lam, target_a, target_b, rand_index

def accuracy_function(real, pred):
    """Macro-averaged F1 of ``argmax(pred, dim=1)`` against ``real``.

    Despite the name this returns macro F1 (``f1_score(..., average='macro')``),
    not accuracy. Both inputs are moved to the CPU before scoring.

    Args:
        real: integer class targets, one per row of ``pred``.
        pred: per-class scores; the predicted class is the argmax over dim 1.
    """
    predicted_cpu = pred.argmax(dim=1).cpu()
    real_cpu = real.cpu()
    return f1_score(real_cpu, predicted_cpu, average='macro')

def mlp_label_split(num_state, labels, specific_index=None) :
    """Split flat class labels into one target tensor per MLP head.

    Head ``i`` owns the ``num_state[i]`` consecutive flat labels starting at
    ``sum(num_state[:i])``. A label owned by head ``i`` becomes its local index
    ``label - sum(num_state[:i])``. Every other label becomes ``num_state[i]``,
    the extra "None" class of that head.

    Args:
        num_state: number of real classes per head (the number of heads is
            ``len(num_state)``; previously fixed at 15).
        labels: 1-D tensor (any device) or sequence of integer flat labels.
        specific_index: optional batch position. If it is a valid index into
            ``labels``, that sample's per-head targets are also returned.

    Returns:
        ``(split, specific)``: two dicts keyed by head index ``0..len(num_state)-1``.
        ``split[i]`` is an int64 CPU tensor with one target per sample.
        ``specific[i]`` is a 1-element int64 tensor for ``specific_index``, or an
        empty int32 tensor when ``specific_index`` is None or out of range.
    """
    flat = torch.as_tensor(labels).detach().to("cpu", torch.int64).reshape(-1)
    has_specific = specific_index is not None and 0 <= specific_index < flat.numel()

    split, specific = {}, {}
    offset = 0
    for head, n_state in enumerate(num_state):
        local = flat - offset
        owned = (local >= 0) & (local < n_state)
        # Labels belonging to another head map to this head's "None" class.
        split[head] = torch.where(owned, local, torch.full_like(local, n_state))
        if has_specific:
            specific[head] = split[head][specific_index:specific_index + 1]
        else:
            specific[head] = torch.tensor([], dtype=torch.int32)
        offset += n_state

    return split, specific
    
    
def _split_targets(num_state, labels, device):
    """Per-head targets of ``labels`` (from ``mlp_label_split``), moved to ``device``, in head order."""
    split_label, _ = mlp_label_split(num_state, labels)
    return [split_label[head].to(device) for head in sorted(split_label)]


def _heads_loss_and_score(criterion, outputs, targets_a, targets_b=None, lam=1.0):
    """Sum the loss and the macro-F1 over all heads.

    Each head's loss is ``criterion(out, a)``. When ``targets_b`` is given
    (mixup / cutmix) it is instead ``lam * criterion(out, a) + (1 - lam) * criterion(out, b)``.
    F1 is always scored against ``targets_a``. Returns ``(loss, summed_f1)``.
    """
    loss, score = 0, 0
    for idx, output in enumerate(outputs):
        if targets_b is None:
            loss += criterion(output, targets_a[idx])
        else:
            loss += criterion(output, targets_a[idx]) * lam + criterion(output, targets_b[idx]) * (1. - lam)
        score += accuracy_function(targets_a[idx], output)
    return loss, score


def _train_one_epoch(model, train_loader, optimizer, criterion, opt, epoch):
    """Run one optimisation pass over ``train_loader`` with a tqdm bar of running means."""
    model.train()
    tqdm_train = tqdm(train_loader)
    train_loss, train_f1 = 0, 0
    for batch, (img, label) in enumerate(tqdm_train, start=1):
        optimizer.zero_grad()

        img = img.to(opt.device)
        label = label.to(opt.device)

        if opt.cutmix:
            imgs, target_a, target_b, lam = cutmix(img, label)
            outputs = model(imgs)
        elif opt.mixup:
            mixed_imgs, lam, target_a, target_b, _ = mixup(img, label)
            outputs = model(mixed_imgs)
        else:
            outputs = model(img)
            target_a, target_b, lam = label, None, 1.0

        # Both mix targets are split per head the same way as the plain labels.
        split_a = _split_targets(opt.num_state, target_a, opt.device)
        split_b = None if target_b is None else _split_targets(opt.num_state, target_b, opt.device)
        loss, score = _heads_loss_and_score(criterion, outputs, split_a, split_b, lam)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_f1 += (score / len(opt.num_state))
        tqdm_train.set_postfix({"Epoch" : epoch,
                                "Mean train loss" : "{:06f}".format(train_loss / (batch)),
                                "Mean train f1" : "{:06f}".format(train_f1 / (batch))
                               })


def _validate(model, valid_loader, criterion, opt):
    """Evaluate without gradients and return the mean per-batch validation loss.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    model.eval()
    tqdm_valid = tqdm(valid_loader)
    valid_loss, valid_f1, n_batches = 0, 0, 0
    for batch, (img, label) in enumerate(tqdm_valid, start=1):
        img = img.to(opt.device)
        label = label.to(opt.device)

        with torch.no_grad():
            outputs = model(img)
            targets = _split_targets(opt.num_state, label, opt.device)
            loss, score = _heads_loss_and_score(criterion, outputs, targets)

        valid_loss += loss.item()
        valid_f1 += (score / len(opt.num_state))
        n_batches = batch
        tqdm_valid.set_postfix({
            "Mean valid loss": "{:06f}".format(valid_loss / (batch)),
            "Mean valid f1": "{:06f}".format(valid_f1 / (batch))
            })

    if n_batches == 0:
        raise ValueError("valid_loader produced no batches")
    return valid_loss / n_batches


def training(model, train_loader, valid_loader, opt) :
    """Train the trainable parameters of ``model`` with AdamW, cosine LR and early stopping.

    The loss is ``FocalLoss`` summed over all heads, and per-head macro-F1 is
    reported. A checkpoint (epoch, model, optimizer and scheduler state) is written
    to ``opt.save_path`` whenever the mean validation loss improves. Training stops
    after ``opt.early_stopping`` epochs whose loss is worse than the best so far.

    Reads ``opt`` fields: learning_rate, weight_decay, cosine_lr_Tmax,
    cosine_lr_eta_min, resume, model_path, epochs, device, cutmix, mixup,
    num_state, early_stopping, save_path, model_name. Sets ``opt.start_epoch``.
    """
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                  lr= opt.learning_rate,
                                  weight_decay=opt.weight_decay)
    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=opt.cosine_lr_Tmax,
                                  eta_min=opt.cosine_lr_eta_min)
    criterion = FocalLoss()

    if opt.resume :
        model_checkpoint = torch.load(opt.model_path, map_location="cpu")
        model.load_state_dict(model_checkpoint["model_state_dict"])
        optimizer.load_state_dict(model_checkpoint["optimizer_state_dict"])
        if "scheduler_state_dict" in model_checkpoint:
            scheduler.load_state_dict(model_checkpoint["scheduler_state_dict"])
        else:
            # Older checkpoints carry no scheduler state: rebuild it on the
            # restored optimizer, as before.
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=opt.cosine_lr_Tmax,
                                          eta_min=opt.cosine_lr_eta_min)
        opt.start_epoch = model_checkpoint["epoch"]
    else :
        opt.start_epoch = 0

    early_stop_step = 0
    # The loss is summed over all heads and can exceed any small fixed bound
    # (previously 10); start at +inf so the first epoch is always checkpointed.
    best_loss = float("inf")
    for E in range(opt.start_epoch + 1, opt.epochs + 1) :
        _train_one_epoch(model, train_loader, optimizer, criterion, opt, E)
        mean_valid_loss = _validate(model, valid_loader, criterion, opt)

        scheduler.step()

        if mean_valid_loss < best_loss :
            early_stop_step = 0
            best_loss = mean_valid_loss
            torch.save({
                "epoch" : E,
                "model_state_dict" : model.state_dict(),
                "optimizer_state_dict" : optimizer.state_dict(),
                "scheduler_state_dict" : scheduler.state_dict()
            },
                       os.path.join(opt.save_path, f'{E}E_{mean_valid_loss:0.4f}_{opt.model_name}.pt'))

        elif mean_valid_loss > best_loss :
            early_stop_step += 1
            print(f"Early Stopping Step : [{early_stop_step} / {opt.early_stopping}]")

        if early_stop_step == opt.early_stopping :
            print("=== Early Stop ===")
            break

# Weight Freeze

In [7]:
def WeightFreeze(model) :
    """Freeze the backbone(s) of ``model`` (or of each model in a list) and return ``model``.

    Every parameter under ``.backbone`` gets ``requires_grad = False``; for models
    built with an ensemble of backbones, ``.backbones`` is used instead. Parameters
    outside the backbone (e.g. the heads) are left untouched.

    Raises:
        AttributeError: if a model has neither ``.backbone`` nor ``.backbones``.
    """
    models = model if isinstance(model, list) else [model]
    for single in models :
        backbone = single.backbone if hasattr(single, "backbone") else single.backbones
        # parameters() also covers parameters registered directly on the backbone
        # module, which iterating over backbone.children() skipped.
        for param in backbone.parameters() :
            param.requires_grad = False
    return model

# Label Decoder

In [8]:
def label_decoder(labels) :
    """Map each label to its position in ``labels``, i.e. ``{label: index}``.

    NOTE: despite the name this is a label -> index table. If a label
    occurs more than once, its last position wins.
    """
    return dict((label, position) for position, label in enumerate(labels))

# Option logging

In [9]:
def create_log(save_path, name=None) :
    """Return logger ``name`` that writes DEBUG-and-above records to ``<save_path>/train.log``.

    With ``name=None`` this is the root logger, so records propagated to root
    from other libraries are also written to the file.

    Calling this again for the same logger and file (e.g. re-running the notebook
    cell) reuses the existing file handler instead of attaching a second one,
    which would write every record twice.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    log_file = os.path.abspath(os.path.join(save_path, "train.log"))
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file:
            return logger

    file_handler = logging.FileHandler(filename=log_file)
    file_handler.setLevel(logging.DEBUG)
    # Year-month-day timestamps (the previous '%Y-%d-%m' had day and month swapped).
    file_handler.setFormatter(logging.Formatter(fmt="%(asctime)s - %(message)s",
                                                datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(file_handler)

    return logger

In [None]:
opt = {
    "df_path" : "../data/aug_v4_train_df.csv",
    "save_path" : "../model/custom_ensemble_swin_aug_v4-5_CEL-mixup",
    "model_name" : "swin_base_patch4_window7_224_in22k",
    "num_classes" : 88,
    'num_state' : [4, 9, 6, 6, 6, 5, 6, 5, 8, 6, 6, 2, 5, 6, 8],
    "resize" : 224,
    "device" : "cuda:0",
    "early_stopping" : 5,
    "epochs" : 30,
    "batch_size" : 128,
    "learning_rate" : 1e-4,
    "weight_decay" : 0.01,
    "cosine_lr_Tmax" : 20,
    "cosine_lr_eta_min" : 1e-5,
    "cutmix" : False,
    "mixup" : False,
    "resume" : False,
    "model_path" : "../model/swin_aug_v4_mixup/21E_0.0382_swin_base_patch4_window7_224_in22k.pt",
    "logging" : True
}
opt = EasyDict(opt)
os.makedirs(opt.save_path, exist_ok=True)

# Seed the RNGs used below: numpy (mixup/cutmix ratios), torch (head init,
# DataLoader shuffling, randperm) and Python's `random` (presumably used by the
# augmentations — confirm for the installed albumentations version).
random.seed(51)
np.random.seed(51)
torch.manual_seed(51)

model_opt = {
    'model_path' : '../model/swin_aug_v4_mixup/19E_0.0394_swin_base_patch4_window7_224_in22k.pt',
    'model_name' : 'swin_base_patch4_window7_224_in22k',
    'backbone_output' : 88,
    'num_class' : 15,
    # Shared with `opt` so the head layout and the label split cannot drift apart.
    'num_state' : opt.num_state,
    "ensemble_backbone" : True,
    "ensemble_model_list" : [
        "../model/ensemble_aug_v4-5_CEL-mixup/19E_0.0394_swin_base_patch4_window7_224_in22k.pt",
        "../model/ensemble_aug_v4-5_CEL-mixup/22E_0.0269_swin_base_patch4_window7_224_in22k.pt",
        "../model/ensemble_aug_v4-5_CEL-mixup/30E_0.0114_swin_base_patch4_window7_224_in22k.pt"
    ],
    'dropout_rate' : 0.5,
}
model_opt = EasyDict(model_opt)

# Config sanity checks: one head per num_state entry, and every flat label
# belongs to exactly one head.
assert len(model_opt.num_state) == model_opt.num_class
assert sum(opt.num_state) == opt.num_classes

# Record both option dicts for this run in train.log.
log = create_log(opt.save_path)
log.info(opt)
log.info(model_opt)

t_transforms = A.Compose([
    A.Normalize(),
    A.Resize(opt.resize, opt.resize),
    A.Blur(p=0.7),
    A.Rotate(limit=(45), p=1),
    A.OneOf([
        A.HorizontalFlip(),
        A.VerticalFlip()
    ], p=1),
    ToTensorV2()
])

v_transforms = A.Compose([
    A.Normalize(),
    A.Resize(opt.resize, opt.resize),
    ToTensorV2()
])

train_df = pd.read_csv(opt.df_path)
t_imgs, v_imgs, t_labels, v_labels = train_test_split(
    list(train_df['file_name']),
    list(train_df['label']),
    train_size=0.8,
    shuffle=True,
    random_state=51,
    stratify=list(train_df['label']))

# CustomDataset derives its label ids from the labels it receives, so the two
# splits only share one encoding if they contain exactly the same classes.
assert set(t_labels) == set(v_labels), \
    "train/valid splits have different label sets; their label encodings would disagree"

model_opt.label_decoder = label_decoder(sorted(train_df['label'].unique()))

train_data = CustomDataset(t_imgs, t_labels, t_transforms)
valid_data = CustomDataset(v_imgs, v_labels, v_transforms)

train_loader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)
# Validation order does not need shuffling; a fixed order keeps per-batch metrics reproducible.
valid_loader = DataLoader(valid_data, batch_size=opt.batch_size, shuffle=False)

model = CustomSwinTransformer(**model_opt).to(opt.device)

training(model, train_loader, valid_loader, opt)

print("==== Complete ====")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  7%|█▌                   | 14/194 [01:23<17:40,  5.89s/it, Epoch=1, Mean train loss=28.253568, Mean train f1=0.227283]