### Python import

In [None]:
import os
import sys
import numpy as np
import pandas as pd
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm

import torch
import albumentations as A
from torch.utils.data import DataLoader, WeightedRandomSampler
from adabelief_pytorch import AdaBelief

from module.config import config_from_yaml
from module.helper import print_time, AverageMeter
from module.dataset import CLF_Dataset
from module.model import CLF_MODEL
from module.scheduler import CosineAnnealingWarmupRestarts

# Load the run configuration from 'config.yaml'; every CFG.* value read in the
# cells below comes from the object returned here.
# NOTE(review): the path is relative - presumably resolved against the current
# working directory by config_from_yaml; confirm.
CFG = config_from_yaml('config.yaml')

### Transform

In [None]:
# Training-time augmentation pipeline; the transforms are listed in the order
# they are handed to A.Compose.
train_transform = A.Compose([
    # Always resize to a square of side CFG.TRAIN.IMG_SIZE (p=1).
    A.Resize(CFG.TRAIN.IMG_SIZE, CFG.TRAIN.IMG_SIZE, p=1),
    # Very mild compression (quality 99-100), applied with probability 0.5.
    A.ImageCompression(quality_lower=99, quality_upper=100, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.8),
    # Used for scaling only: the shift and rotation limits are both 0.
    # NOTE(review): interpolation=0 / border_mode=0 presumably map to
    # cv2.INTER_NEAREST / cv2.BORDER_CONSTANT, and scale_limit=[-0.5, 0.0]
    # presumably means a zoom-out factor in [0.5, 1.0] - confirm against the
    # installed albumentations version.
    A.ShiftScaleRotate(shift_limit=0, scale_limit=[-0.5, 0.0], rotate_limit=0, interpolation=0, border_mode=0, p=1.0),
])

# Validation pipeline: the same resize as training, with no random augmentation.
valid_transform = A.Compose([
    A.Resize(CFG.TRAIN.IMG_SIZE, CFG.TRAIN.IMG_SIZE, p=1),
])

### Sampler

In [None]:
def make_sampler(dataset):
    """Build a class-balancing ``WeightedRandomSampler`` for ``dataset``.

    Each sample is weighted by the inverse of how often its label occurs in
    ``dataset.labels``, and the sampler draws ``len(dataset.labels)`` samples
    per pass.

    Args:
        dataset: object exposing ``labels``, a sequence of non-negative
            integer class ids (one per sample).

    Returns:
        A ``WeightedRandomSampler`` over the per-sample weights.
    """
    labels = dataset.labels
    # Occurrences of every class id from 0 up to the largest label present.
    class_counts = torch.tensor(np.bincount(labels).tolist(), dtype=torch.float)
    # Look up the inverse class frequency for each individual sample.
    per_sample_weights = (1. / class_counts)[labels]
    return WeightedRandomSampler(weights=per_sample_weights,
                                 num_samples=len(per_sample_weights))

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    For example ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``. Nothing is stacked into a tensor.
    """
    fields = zip(*batch)
    return tuple(fields)

### Data

In [None]:
# All Data
# Training uses every unique (img_name, defect) row. The validation subset
# below is sampled from those same rows, so the validation score is an
# in-sample sanity check rather than a held-out estimate.
np.random.seed(CFG.BASE.SEED)

df        = pd.read_csv(CFG.DATA.TRAIN_DATA_PATH)
df_tmp    = df[['img_name', 'defect']].drop_duplicates()
img_names = df_tmp['img_name'].values
defects   = df_tmp['defect'].values

train_img_names = img_names
train_defects   = defects

# Draw up to 4000 distinct indices. The size is capped at the population size
# because np.random.choice(..., replace=False) raises ValueError when asked
# for more samples than exist.
valid_index = np.sort(np.random.choice(range(len(train_img_names)),
                                       size=min(4000, len(train_img_names)),
                                       replace=False))
valid_img_names = img_names[valid_index]
valid_defects = defects[valid_index]

train_dataset    = CLF_Dataset(train_img_names, train_defects, transforms=train_transform)
valid_dataset    = CLF_Dataset(valid_img_names, valid_defects, transforms=valid_transform)

# shuffle stays False for training because the sampler decides the order;
# DataLoader does not allow shuffle=True together with a sampler.
train_dataloader = DataLoader(train_dataset,
                              batch_size  = CFG.TRAIN.BATCH_SIZE,
                              shuffle     = False,
                              num_workers = CFG.TRAIN.WORKERS,
                              pin_memory  = CFG.TRAIN.PIN_MEMORY,
                              sampler     = make_sampler(train_dataset),
                              collate_fn  = collate_fn
                             )
valid_dataloader = DataLoader(valid_dataset,
                              batch_size  = CFG.TRAIN.BATCH_SIZE,
                              shuffle     = False,
                              num_workers = CFG.TRAIN.WORKERS,
                              pin_memory  = CFG.TRAIN.PIN_MEMORY,
                              collate_fn  = collate_fn)

### Train

In [None]:
# Build the classifier from the configuration. image_mean / image_std are
# lists that repeat the single configured value once per input channel.
model = CLF_MODEL(name        = CFG.TRAIN.CLF.NAME,
                  num_channel = CFG.DATA.N_CHANNEL, 
                  num_class   = CFG.DATA.N_CLASS,
                  image_mean  = CFG.DATA.N_CHANNEL * [CFG.TRAIN.IMG_MEAN],
                  image_std   = CFG.DATA.N_CHANNEL * [CFG.TRAIN.IMG_STD],
                  smoothing   = CFG.TRAIN.CLF.SMOOTHING,
                  pretrained  = CFG.TRAIN.CLF.PRETRAINED)
# Optional warm start from an earlier checkpoint (currently disabled):
# model.load_state_dict(torch.load('clf_model/clf012_0.9836.pth', map_location=CFG.TRAIN.DEVICE))
# Move the model to the target device before the optimizer is created from
# its parameters.
model.to(CFG.TRAIN.DEVICE)

# AdaBelief optimizer over all model parameters, configured from CFG.OPTIMIZER.CLF.
optimizer = AdaBelief(model.parameters(), 
                      lr               = CFG.OPTIMIZER.CLF.LR,
                      eps              = CFG.OPTIMIZER.CLF.EPSILON,
                      weight_decay     = CFG.OPTIMIZER.CLF.WEIGHT_DECAY,
                      weight_decouple  = CFG.OPTIMIZER.CLF.WEIGHT_DECOUPLE,
                      rectify          = CFG.OPTIMIZER.CLF.RECTIFY,
                      print_change_log = False)

# Learning-rate schedule; its max_lr is the same value the optimizer starts with.
scheduler = CosineAnnealingWarmupRestarts(optimizer, 
                                          first_cycle_steps = CFG.SCHEDULER.CLF.FIRST_CYCLE_STEPS,
                                          warmup_steps      = CFG.SCHEDULER.CLF.WARMUP_STEPS,
                                          max_lr            = CFG.OPTIMIZER.CLF.LR,
                                          min_lr            = CFG.SCHEDULER.CLF.MIN_LR)

# Module-level gradient scaler; train_one_epoch reads it as a global for its
# mixed-precision backward/step.
scaler = torch.cuda.amp.GradScaler()

In [None]:
def train_one_epoch(epoch, dataloader, model, optimizer, scheduler, device):
    """Run one mixed-precision training pass over ``dataloader``.

    Before each batch the scheduler (when one is given) is stepped with a
    fractional epoch, ``epoch + batches_done / len(dataloader)``. The running
    averages of the loss and accuracy reported by the model are shown on the
    progress bar together with the current learning rate.

    Relies on the module-level ``scaler`` for gradient scaling. Returns None.

    Args:
        epoch: index of the current epoch (used for the scheduler and the
            progress-bar label).
        dataloader: yields ``(images, targets)`` pairs, where ``targets`` is a
            sequence of dicts whose values support ``.to(device)``.
        model: called as ``model(images, targets)``; must return a dict with
            ``'loss'`` and ``'acc'`` entries.
        optimizer: optimizer stepped through the gradient scaler.
        scheduler: learning-rate scheduler, or None to skip scheduling.
        device: device the batch is moved to.
    """
    model.train()
    running_loss = AverageMeter()
    running_acc  = AverageMeter()
    progress = tqdm(dataloader, desc=f'Epoch [{epoch}]', leave=True)
    for batch_images, batch_targets in progress:
        optimizer.zero_grad()
        if scheduler is not None:
            # running_loss.count grows by one per batch (n=1 below), so it
            # doubles as the number of batches already processed.
            scheduler.step(epoch + running_loss.count / len(dataloader))

        batch_images = [img.to(device) for img in batch_images]
        batch_targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in batch_targets
        ]
        with torch.cuda.amp.autocast():
            outputs = model(batch_images, batch_targets)
            loss, acc = outputs['loss'], outputs['acc']
            running_loss.update(loss.item(), n=1)
            running_acc.update(acc.item(), n=1)

        # Scaled backward pass, optimizer step and scale update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix({
            'Lr': optimizer.param_groups[0]['lr'],
            'Loss': running_loss.avg,
            'Acc': running_acc.avg,
        })
 
@torch.no_grad()
def valid_one_epoch(dataloader, model, device, epoch=0):
    """Evaluate ``model`` on ``dataloader`` with gradient tracking disabled.

    The loss and accuracy reported by the model are averaged with equal
    weight per batch, logged through ``print_time``, and the averaged
    accuracy is returned.

    Args:
        dataloader: yields ``(images, targets)`` pairs, where ``targets`` is a
            sequence of dicts whose values support ``.to(device)``.
        model: called as ``model(images, targets)``; must return a dict with
            ``'loss'`` and ``'acc'`` entries.
        device: device the batch is moved to.
        epoch: epoch index, used only in the logged message.

    Returns:
        The average of the per-batch accuracies.
    """
    model.eval()
    running_loss, running_acc = AverageMeter(), AverageMeter()
    for batch_images, batch_targets in dataloader:
        batch_images = [img.to(device) for img in batch_images]
        batch_targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in batch_targets
        ]
        outputs = model(batch_images, batch_targets)
        loss, acc = outputs['loss'], outputs['acc']
        running_loss.update(loss.item(), n=1)
        running_acc.update(acc.item(), n=1)

    summary = f'Epoch: {epoch}, Loss: {running_loss.avg:.4f}, ACC: {running_acc.avg:.4f}'
    print_time(message=summary, new_line=False)
    return running_acc.avg

In [None]:
from contextlib import redirect_stdout

# Make Log File: the name is 'CLF_<YYYY-MM-DD-HH-MM-SS>.txt', built from the
# current local time.
log_name = str(datetime.now())[:19]
log_name = log_name.replace(' ', '-').replace(':', '-')
log_name = 'CLF_' + log_name + '.txt'

# Send stdout to the log file only for the duration of the run. The context
# managers close the file and restore the previous stdout even when training
# raises, instead of leaving sys.stdout pointing at a never-closed handle.
with open(log_name, 'w') as log_file, redirect_stdout(log_file):
    print_time(message='START: CLF', new_line=False)
    print(CFG)

    # Make Dir
    os.makedirs(CFG.MODEL.CLF_FOLDER, exist_ok=True)
    best_score = 0

    # Loop
    for epoch in range(0, CFG.TRAIN.CLF.EPOCHS):
        train_one_epoch(epoch, train_dataloader, model, optimizer, scheduler, CFG.TRAIN.DEVICE)
        score = valid_one_epoch(valid_dataloader, model, CFG.TRAIN.DEVICE, epoch=epoch)

        # Per-epoch checkpoint; the file name carries the epoch and its score.
        model_path = os.path.join(CFG.MODEL.CLF_FOLDER, f'clf{epoch:03d}_{score:.4f}.pth')
        torch.save(model.state_dict(), model_path)

        # Best File Save: overwrite best.pth whenever the score improves.
        if score > best_score:
            best_score = score
            model_path = os.path.join(CFG.MODEL.CLF_FOLDER, 'best.pth')
            torch.save(model.state_dict(), model_path)

        # Flush once per epoch so the log can be followed while training runs
        # and is not lost in a buffer if the process is killed.
        log_file.flush()

    print_time(message='END: CLF')