In [1]:
import warnings
warnings.filterwarnings('ignore')

from glob import glob
import pandas as pd
import numpy as np 
from tqdm import tqdm
import cv2

import os
import timm
import random

import albumentations as A
from albumentations.pytorch import transforms, ToTensorV2

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.transforms as transforms

from sklearn.metrics import f1_score, accuracy_score

In [2]:
# !pip3 install timm albumentations

In [3]:
# Run configuration: every tunable value used by this notebook lives here.
config = dict(
    # Seed, square image size and the k-fold CSV
    SEED=777,
    SIZE=512,
    CSV="./train_kfold.csv",

    # Held-out fold and optimisation schedule
    FOLD=0,
    BATCH_SIZE=8,
    LEARNING_RATE=0.001,
    EPOCHS=30,
    N_WORKERS=4,

    # timm architecture name and checkpoint location / filename prefix
    MODEL="tf_efficientnet_b7",
    MODEL_SAVE="./b7_model",
    MODEL_SAVE_PREFIX="b7_777_",

    # Use the GPU whenever CUDA is available, otherwise fall back to the CPU
    DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

In [4]:
# Make sure the checkpoint directory exists. `exist_ok=True` already turns the
# call into a no-op when the directory is present, so a separate
# `os.path.exists` check is redundant (and would be a check-then-act race if
# another process created the directory in between).
os.makedirs(config["MODEL_SAVE"], exist_ok=True)

In [5]:
# Seed every random number generator this notebook imports so that a fresh
# "Restart & Run All" starts from the same state.
random.seed(config["SEED"])
# NumPy's global RNG was previously left unseeded; seed it too so any
# NumPy-based randomness (e.g. inside augmentation code) is repeatable.
np.random.seed(config["SEED"])
torch.manual_seed(config["SEED"])
# Seed all visible CUDA devices, not only the current one.
torch.cuda.manual_seed_all(config["SEED"])
torch.cuda.empty_cache()

In [6]:
class CustomDataset(Dataset):
    """Image dataset driven by a k-fold CSV file.

    The CSV must contain a ``path`` column with image file locations. A
    ``kfold`` column is required for the "train" and "validation" modes, and
    an ``encoded_label`` column is required for every mode except "test".

    Args:
        data_path: Path of the CSV file to read.
        size: Value handed to the ``transform`` factory.
        transform: Optional factory called as ``transform(size)``; the object
            it returns is invoked as ``pipeline(image=img)['image']``.
            ``None`` leaves images untransformed.
        fold: Fold id used to split rows between training and validation.
        mode: "train" keeps rows whose ``kfold`` differs from ``fold``,
            "validation" keeps rows whose ``kfold`` equals ``fold``; any other
            value (e.g. "test") keeps every row.
    """

    def __init__(self,
                 data_path,
                 size,
                 transform=None,
                 fold=0,
                 mode="train"):
        frame = pd.read_csv(data_path)
        if mode == "train":
            frame = frame[frame['kfold'] != fold]
        elif mode == "validation":
            frame = frame[frame['kfold'] == fold]
        self.csv = frame

        self.path = frame['path'].to_list()
        # Test items never return a label, so a test CSV without the label
        # column is accepted; every other mode still requires the column.
        if mode == "test" and 'encoded_label' not in frame.columns:
            self.labels = None
        else:
            self.labels = frame['encoded_label'].to_list()
        self.transform = transform
        self.size = size
        self.mode = mode

    def __len__(self):
        """Return the number of rows selected for this mode."""
        return len(self.path)

    def _get_pipeline(self):
        """Return the transform pipeline, building it only when needed.

        The pipeline is cached per ``(transform, size)`` pair so it is not
        re-created for every sample, while still honouring later changes to
        ``self.transform`` or ``self.size``.
        """
        key = (self.transform, self.size)
        cached = getattr(self, "_pipeline_cache", None)
        if cached is None or cached[0] != key:
            cached = (key, self.transform(self.size))
            self._pipeline_cache = cached
        return cached[1]

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'image': ...}`` plus ``'label'`` unless in test mode.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        # Image
        image_path = self.path[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None; fail with the
            # offending path instead of an opaque cvtColor error.
            raise FileNotFoundError(
                f"Image is missing or unreadable: {image_path}"
            )
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self._get_pipeline()(image=image)['image']

        # Only test mode
        if self.mode == "test":
            return {
                'image': image
            }

        # Label
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {
            'image': image,
            'label': label
        }

In [7]:
def create_train_transforms(size=512):
    """Build the albumentations pipeline used for training images.

    Steps, in order: resize to ``size`` x ``size``, apply a random
    brightness/contrast change (both limits within +/-0.1, probability 0.3),
    then the ``ToTensorV2`` conversion.
    """
    jitter = A.RandomBrightnessContrast(
        brightness_limit=(-0.1, 0.1),
        contrast_limit=(-0.1, 0.1),
        p=0.3,
    )
    steps = [A.Resize(size, size), jitter, ToTensorV2()]
    return A.Compose(steps)

In [8]:
def create_validation_transforms(size=128):
    """Build the albumentations pipeline used for validation images.

    Only resizes to ``size`` x ``size`` and applies the ``ToTensorV2``
    conversion; no random augmentation is involved.
    """
    steps = [A.Resize(size, size), ToTensorV2()]
    return A.Compose(steps)

In [9]:
# Datasets: rows of the configured fold are used for validation, all other
# rows for training.
train_dataset = CustomDataset(
    data_path=config["CSV"],
    size=config["SIZE"],
    transform=create_train_transforms,
    fold=config["FOLD"],
    mode="train",
)
validation_dataset = CustomDataset(
    data_path=config["CSV"],
    size=config["SIZE"],
    transform=create_validation_transforms,
    fold=config["FOLD"],
    mode="validation",
)

# Loaders: identical batch size and worker count; only training is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=True,
    num_workers=config["N_WORKERS"],
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=False,
    num_workers=config["N_WORKERS"],
)

In [10]:
# Create Model (Using timm library)
class CustomModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a network built by ``timm.create_model``.

    Args:
        model_name: Architecture name forwarded to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
        num_classes: Forwarded to ``timm.create_model`` as ``num_classes``.
    """

    def __init__(self, model_name = None, pretrained = True, num_classes = 88):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
        )
        # Attribute name is kept as `model` so state_dict keys stay the same.
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped timm network and return its output."""
        outputs = self.model(x)
        return outputs

In [11]:
# Build the configured architecture, then move it onto the selected device.
model = CustomModel(config["MODEL"])
model = model.to(config["DEVICE"])

In [12]:
# Loss, optimiser and the gradient scaler used for mixed-precision training.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=config["LEARNING_RATE"],
)
scaler = torch.cuda.amp.GradScaler()

In [13]:
def score_function(real, pred):
    """Return the macro-averaged F1 score of ``pred`` against ``real``."""
    return f1_score(y_true=real, y_pred=pred, average="macro")

In [14]:
def _run_epoch(loader, epoch, training):
    """Run one full pass over ``loader`` with the notebook's global model.

    Args:
        loader: DataLoader yielding dict batches with 'image' and 'label'
            tensors.
        epoch: Zero-based epoch index; only shown in the progress bar.
        training: If True the model is put in train mode and the optimizer is
            stepped through the AMP gradient scaler. If False the model is put
            in eval mode and gradient tracking is disabled.

    Returns:
        ``(mean_loss, predictions, targets)``: the mean of the per-batch
        losses (0.0 for an empty loader), and flat Python lists holding the
        predicted and true class indices.
    """
    if training:
        loss_key, total_key = "train_loss", "total_train_loss"
    else:
        loss_key, total_key = "val_loss", "total_val_loss"

    model.train(training)  # model.train(False) is equivalent to model.eval()
    n_batches = len(loader)
    loss_sum = 0.0
    seen = 0
    predictions, targets = [], []

    pbar = tqdm(loader, total=n_batches)
    with torch.set_grad_enabled(training):
        for batch in pbar:
            # The collated batch already holds tensors: move/cast them rather
            # than re-wrapping with torch.tensor(), which copies and warns.
            x = batch['image'].to(config["DEVICE"], dtype=torch.float32)
            y = batch['label'].to(config["DEVICE"], dtype=torch.long)

            if training:
                optimizer.zero_grad()

            # Forward pass AND loss belong inside the autocast region (as in
            # the PyTorch AMP recipe), so the loss is not evaluated directly
            # on reduced-precision logits outside of autocast.
            with torch.cuda.amp.autocast():
                pred = model(x)
                loss = criterion(pred, y)

            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            loss_sum += loss.item()
            seen += 1
            predictions += pred.argmax(1).tolist()
            targets += y.tolist()
            pbar.set_postfix({
                "epoch": f"{epoch}/{config['EPOCHS']}",
                # Mean over the batches seen so far.
                loss_key: f"{loss_sum / seen:.5f}",
                # Sum so far divided by the total batch count; equals the
                # mean once the pass is complete.
                total_key: f"{loss_sum / n_batches:.5f}"
            })
    pbar.close()

    mean_loss = loss_sum / seen if seen else 0.0
    return mean_loss, predictions, targets


# `best_pred` is never updated by this loop; it is kept only so the name
# remains defined for any later cell that might reference it.
best_loss, best_pred = float('inf'), 0
for epoch in range(config["EPOCHS"]):
    # ---- Training pass ----
    train_loss, train_pred, train_y = _run_epoch(train_loader, epoch, training=True)
    train_f1 = score_function(train_y, train_pred)
    print(f'TRAIN\tf1 : {train_f1:.5f}')

    # ---- Validation pass ----
    validation_loss, validation_pred, validation_y = _run_epoch(
        validation_loader, epoch, training=False
    )
    val_f1 = score_function(validation_y, validation_pred)
    print(f'VAL\tf1 : {val_f1:.5f}')

    checkpoint = {
        "epoch": epoch,
        "loss": validation_loss,
        "score": val_f1,
        "state_dict": model.state_dict()
    }

    # Update about Loss: the "best" checkpoint is chosen by lowest validation
    # loss, not by F1.
    if best_loss > validation_loss:
        best_loss = validation_loss
        torch.save(checkpoint, f"{config['MODEL_SAVE']}/{config['MODEL_SAVE_PREFIX']}_best.pth")

    # A per-epoch checkpoint is written regardless of the score.
    torch.save(checkpoint, f"{config['MODEL_SAVE']}/{config['MODEL_SAVE_PREFIX']}_{epoch}.pth")

100%|██████████| 401/401 [01:43<00:00,  3.89it/s, epoch=0/30, train_loss=1.34686, total_train_loss=1.34686]


TRAIN	f1 : 0.15419


100%|██████████| 134/134 [00:07<00:00, 17.25it/s, epoch=0/30, val_loss=0.91163, total_val_loss=0.91163]


VAL	f1 : 0.17926


100%|██████████| 401/401 [01:41<00:00,  3.93it/s, epoch=1/30, train_loss=0.80049, total_train_loss=0.80049]


TRAIN	f1 : 0.22162


100%|██████████| 134/134 [00:07<00:00, 17.28it/s, epoch=1/30, val_loss=0.92063, total_val_loss=0.92063]


VAL	f1 : 0.23628


100%|██████████| 401/401 [01:42<00:00,  3.93it/s, epoch=2/30, train_loss=0.68992, total_train_loss=0.68992]


TRAIN	f1 : 0.26333


100%|██████████| 134/134 [00:07<00:00, 17.28it/s, epoch=2/30, val_loss=0.68197, total_val_loss=0.68197]


VAL	f1 : 0.26061


100%|██████████| 401/401 [01:42<00:00,  3.93it/s, epoch=3/30, train_loss=0.60430, total_train_loss=0.60430]


TRAIN	f1 : 0.28867


100%|██████████| 134/134 [00:07<00:00, 17.31it/s, epoch=3/30, val_loss=0.49939, total_val_loss=0.49939]


VAL	f1 : 0.32975


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=4/30, train_loss=0.57763, total_train_loss=0.57763]


TRAIN	f1 : 0.33566


100%|██████████| 134/134 [00:07<00:00, 17.24it/s, epoch=4/30, val_loss=0.65068, total_val_loss=0.65068]


VAL	f1 : 0.29698


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=5/30, train_loss=0.53747, total_train_loss=0.53747]


TRAIN	f1 : 0.37148


100%|██████████| 134/134 [00:07<00:00, 17.19it/s, epoch=5/30, val_loss=0.54519, total_val_loss=0.54519]


VAL	f1 : 0.29062


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=6/30, train_loss=0.48716, total_train_loss=0.48716]


TRAIN	f1 : 0.38775


100%|██████████| 134/134 [00:07<00:00, 17.14it/s, epoch=6/30, val_loss=0.50976, total_val_loss=0.50976]


VAL	f1 : 0.39684


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=7/30, train_loss=0.46864, total_train_loss=0.46864]


TRAIN	f1 : 0.43350


100%|██████████| 134/134 [00:07<00:00, 17.27it/s, epoch=7/30, val_loss=0.66389, total_val_loss=0.66389]


VAL	f1 : 0.31340


100%|██████████| 401/401 [01:42<00:00,  3.93it/s, epoch=8/30, train_loss=0.40514, total_train_loss=0.40514]


TRAIN	f1 : 0.47549


100%|██████████| 134/134 [00:07<00:00, 17.20it/s, epoch=8/30, val_loss=0.48045, total_val_loss=0.48045]


VAL	f1 : 0.40200


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=9/30, train_loss=0.34830, total_train_loss=0.34830]


TRAIN	f1 : 0.53136


100%|██████████| 134/134 [00:07<00:00, 17.19it/s, epoch=9/30, val_loss=0.55695, total_val_loss=0.55695]


VAL	f1 : 0.40268


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=10/30, train_loss=0.39081, total_train_loss=0.39081]


TRAIN	f1 : 0.52352


100%|██████████| 134/134 [00:07<00:00, 17.15it/s, epoch=10/30, val_loss=0.52171, total_val_loss=0.52171]


VAL	f1 : 0.40660


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=11/30, train_loss=0.34400, total_train_loss=0.34400]


TRAIN	f1 : 0.55925


100%|██████████| 134/134 [00:07<00:00, 17.13it/s, epoch=11/30, val_loss=0.41201, total_val_loss=0.41201]


VAL	f1 : 0.46472


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=12/30, train_loss=0.27969, total_train_loss=0.27969]


TRAIN	f1 : 0.60976


100%|██████████| 134/134 [00:07<00:00, 17.16it/s, epoch=12/30, val_loss=0.54466, total_val_loss=0.54466]


VAL	f1 : 0.46711


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=13/30, train_loss=0.27098, total_train_loss=0.27098]


TRAIN	f1 : 0.64939


100%|██████████| 134/134 [00:07<00:00, 17.24it/s, epoch=13/30, val_loss=0.78520, total_val_loss=0.78520]


VAL	f1 : 0.41899


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=14/30, train_loss=0.27039, total_train_loss=0.27039]


TRAIN	f1 : 0.63919


100%|██████████| 134/134 [00:07<00:00, 17.22it/s, epoch=14/30, val_loss=0.42249, total_val_loss=0.42249]


VAL	f1 : 0.52855


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=15/30, train_loss=0.22743, total_train_loss=0.22743]


TRAIN	f1 : 0.68678


100%|██████████| 134/134 [00:07<00:00, 17.33it/s, epoch=15/30, val_loss=0.50054, total_val_loss=0.50054]


VAL	f1 : 0.48673


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=16/30, train_loss=0.20240, total_train_loss=0.20240]


TRAIN	f1 : 0.71934


100%|██████████| 134/134 [00:07<00:00, 17.14it/s, epoch=16/30, val_loss=0.66843, total_val_loss=0.66843]


VAL	f1 : 0.50246


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=17/30, train_loss=0.21321, total_train_loss=0.21321]


TRAIN	f1 : 0.71359


100%|██████████| 134/134 [00:07<00:00, 17.30it/s, epoch=17/30, val_loss=0.54571, total_val_loss=0.54571]


VAL	f1 : 0.49763


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=18/30, train_loss=0.20076, total_train_loss=0.20076]


TRAIN	f1 : 0.72347


100%|██████████| 134/134 [00:07<00:00, 17.18it/s, epoch=18/30, val_loss=0.41422, total_val_loss=0.41422]


VAL	f1 : 0.56469


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=19/30, train_loss=0.16634, total_train_loss=0.16634]


TRAIN	f1 : 0.76667


100%|██████████| 134/134 [00:07<00:00, 17.29it/s, epoch=19/30, val_loss=0.52054, total_val_loss=0.52054]


VAL	f1 : 0.54174


100%|██████████| 401/401 [01:42<00:00,  3.93it/s, epoch=20/30, train_loss=0.17071, total_train_loss=0.17071]


TRAIN	f1 : 0.78543


100%|██████████| 134/134 [00:07<00:00, 17.22it/s, epoch=20/30, val_loss=0.58655, total_val_loss=0.58655]


VAL	f1 : 0.49989


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=21/30, train_loss=0.17990, total_train_loss=0.17990]


TRAIN	f1 : 0.76334


100%|██████████| 134/134 [00:07<00:00, 17.31it/s, epoch=21/30, val_loss=0.47552, total_val_loss=0.47552]


VAL	f1 : 0.56233


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=22/30, train_loss=0.12582, total_train_loss=0.12582]


TRAIN	f1 : 0.84249


100%|██████████| 134/134 [00:07<00:00, 17.30it/s, epoch=22/30, val_loss=0.48031, total_val_loss=0.48031]


VAL	f1 : 0.56541


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=23/30, train_loss=0.11868, total_train_loss=0.11868]


TRAIN	f1 : 0.86765


100%|██████████| 134/134 [00:07<00:00, 17.30it/s, epoch=23/30, val_loss=0.47270, total_val_loss=0.47270]


VAL	f1 : 0.58803


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=24/30, train_loss=0.11772, total_train_loss=0.11772]


TRAIN	f1 : 0.86321


100%|██████████| 134/134 [00:07<00:00, 17.10it/s, epoch=24/30, val_loss=0.49680, total_val_loss=0.49680]


VAL	f1 : 0.53308


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=25/30, train_loss=0.15527, total_train_loss=0.15527]


TRAIN	f1 : 0.81332


100%|██████████| 134/134 [00:07<00:00, 17.23it/s, epoch=25/30, val_loss=0.51503, total_val_loss=0.51503]


VAL	f1 : 0.54169


100%|██████████| 401/401 [01:42<00:00,  3.93it/s, epoch=26/30, train_loss=0.07532, total_train_loss=0.07532]


TRAIN	f1 : 0.91124


100%|██████████| 134/134 [00:07<00:00, 17.29it/s, epoch=26/30, val_loss=0.63885, total_val_loss=0.63885]


VAL	f1 : 0.57397


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=27/30, train_loss=0.05373, total_train_loss=0.05373]


TRAIN	f1 : 0.93240


100%|██████████| 134/134 [00:07<00:00, 17.20it/s, epoch=27/30, val_loss=0.59522, total_val_loss=0.59522]


VAL	f1 : 0.55904


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=28/30, train_loss=0.11536, total_train_loss=0.11536]


TRAIN	f1 : 0.86493


100%|██████████| 134/134 [00:07<00:00, 17.12it/s, epoch=28/30, val_loss=0.74549, total_val_loss=0.74549]


VAL	f1 : 0.44227


100%|██████████| 401/401 [01:42<00:00,  3.92it/s, epoch=29/30, train_loss=0.11954, total_train_loss=0.11954]


TRAIN	f1 : 0.86647


100%|██████████| 134/134 [00:07<00:00, 17.14it/s, epoch=29/30, val_loss=0.66838, total_val_loss=0.66838]


VAL	f1 : 0.51935
