In [1]:
# import dependent libraries
from torch.utils.data import Dataset,DataLoader
from sklearn.model_selection import StratifiedKFold
from torch.cuda.amp import GradScaler
from torch import nn
from tqdm import tqdm
import torch
import timm
import cv2
import pandas as pd
import numpy as np
from utils import utils
from imp import reload
from albumentations.pytorch import ToTensorV2
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout,
    ShiftScaleRotate, CenterCrop, Resize
)
# Seed value shared by the RNG helper below and by the fold splitter later on.
rand_seed = 666

# Re-import the local helper module so edits made to it on disk are picked up
# without restarting the kernel.
reload(utils)

# Project helper; presumably seeds the Python / NumPy / torch RNGs — confirm in utils.
utils.seed_everything(rand_seed)

In [1]:
!pip install timm

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [8]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Locations of the training images and of the label CSV on the training machine.
train_img_path, train_csv_path = (
    '/home/featurize/data/train_images',
    '/home/featurize/data/train.csv',
)

In [3]:
# Show which label CSV the notebook is configured to read.
print(train_csv_path)

/home/featurize/data/train.csv


In [4]:
# Training set data augmentation
def get_train_transforms(img_size=None):
    """Build the albumentations pipeline applied to training images.

    Args:
        img_size: Side length passed (as both height and width) to
            ``RandomResizedCrop``. When ``None`` (the default) the value is
            read from the notebook-level ``CFG['img_size']`` at call time,
            which is exactly what the zero-argument call has always done.

    Returns:
        A ``Compose`` of: random resized crop, transpose / flips,
        shift-scale-rotate, colour jitter, normalisation, two dropout-style
        occlusions, and conversion to a tensor.
    """
    if img_size is None:
        # CFG lives in a later cell; resolving it lazily keeps existing
        # callers working while letting new callers avoid the global.
        img_size = CFG['img_size']

    return Compose([
        RandomResizedCrop(img_size, img_size),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        # NOTE(review): the three limits are fractional (0.2). Check the
        # albumentations docs for the expected units — if they are integer
        # shifts this jitter may be far weaker than intended.
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        # Per-channel mean/std (these look like the usual ImageNet statistics — confirm).
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        # Both occlusion transforms run after normalisation, each with p=0.5.
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ], p=1.)

# Validation set data augmentation
def get_valid_transforms(img_size=None):
    """Build the deterministic albumentations pipeline for validation images.

    Args:
        img_size: Side length used for both the centre crop and the resize.
            When ``None`` (the default) it is read from the notebook-level
            ``CFG['img_size']`` at call time, matching the original
            zero-argument behaviour.

    Returns:
        A ``Compose`` of centre crop, resize, normalisation and conversion
        to a tensor.
    """
    if img_size is None:
        # CFG lives in a later cell; resolving it lazily keeps existing
        # callers working while letting new callers avoid the global.
        img_size = CFG['img_size']

    return Compose([
        CenterCrop(img_size, img_size, p=1.),
        # Resize is given the same target size as the crop above.
        Resize(img_size, img_size),
        # Same mean/std as the training pipeline.
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.)

In [5]:
# model building
class CassvaImgClassifier(nn.Module):
    """timm model whose ``classifier`` layer is replaced by a fresh linear head.

    Args:
        model_arch: Architecture name handed to ``timm.create_model``.
        n_class: Number of output units of the replacement head.
        pretrained: Forwarded unchanged to ``timm.create_model``.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Swap the stock head for one sized to n_class while keeping its input
        # width. This relies on the created model exposing `.classifier`.
        head_width = backbone.classifier.in_features
        backbone.classifier = nn.Linear(head_width, n_class)
        self.model = backbone

    def forward(self, x):
        """Pass ``x`` through the wrapped model and return its output."""
        return self.model(x)

In [6]:
# build data
# Run configuration read by the cells below.
CFG = {
    'img_size' : 512,                       # side length used by the train/valid transforms
    'epochs': 10,                           # epochs trained per fold
    'fold_num': 5,                          # number of stratified folds
    'device': 'cuda',                       # passed to torch.device
    'model_arch': 'tf_efficientnet_b5_ns',  # timm architecture name
    'train_bs' : 16,                        # batch size handed to the dataloader helper
    'valid_bs' : 16,                        # NOTE(review): not referenced by the visible cells
    'num_workers' : 0,                      # forwarded to the dataloader helper as n_job
    'lr': 1e-4,                             # Adam learning rate
    'weight_decay': 1e-6,                   # Adam weight decay
    'T_0': 10,                              # CosineAnnealingWarmRestarts T_0
    'min_lr': 1e-6,                         # CosineAnnealingWarmRestarts eta_min
}
train = pd.read_csv(train_csv_path)

# StratifiedKFold.split() hands back a one-shot generator. Materialise it so
# the training cell can be re-run (or resumed) without silently iterating
# over zero folds once the generator has been consumed.
folds = list(
    StratifiedKFold(n_splits=CFG['fold_num'],
                    shuffle=True,
                    random_state=rand_seed).split(
                        np.arange(train.shape[0]), train.label.values))

# The transform pipelines are built once and shared by every fold.
trn_transform = get_train_transforms()
val_transform = get_valid_transforms()



In [None]:
import os

fold_num = 0  # NOTE(review): not read in the visible cells; kept in case a later cell uses it.

# Every epoch of every fold writes a checkpoint into this directory. Create it
# up front so a missing folder cannot crash the first torch.save after a full
# (roughly 18-minute) training epoch.
ckpt_dir = '/home/featurize/work/model_b5'
os.makedirs(ckpt_dir, exist_ok=True)

# The target device is the same for every fold.
device = torch.device(CFG['device'])

for fold, (trn_idx, val_idx) in enumerate(folds):
    print('Training with {} started'.format(fold))
    print('Train : {}, Val : {}'.format(len(trn_idx), len(val_idx)))

    # NOTE(review): only CFG['train_bs'] is passed; whether the helper applies
    # `bs` to the validation loader too depends on utils.prepare_dataloader.
    train_loader, val_loader = utils.prepare_dataloader(train,
                                                        trn_idx,
                                                        val_idx,
                                                        data_root = train_img_path,
                                                        trn_transform = trn_transform,
                                                        val_transform = val_transform,
                                                        bs = CFG['train_bs'],
                                                        n_job = CFG['num_workers'])

    # A fresh model, grad scaler, optimizer and scheduler are built per fold.
    model = CassvaImgClassifier(CFG['model_arch'],
                                train.label.nunique(),
                                pretrained=True).to(device)
    scaler = GradScaler()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=CFG['lr'],
                                 weight_decay=CFG['weight_decay'])

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=CFG['T_0'],
        T_mult=1,
        eta_min=CFG['min_lr'],
        last_epoch=-1)

    # Separate criterion objects for training and validation (both plain CE).
    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    for epoch in range(CFG['epochs']):
        utils.train_one_epoch(epoch,
                              model,
                              loss_tr,
                              optimizer,
                              train_loader,
                              device,
                              scaler,
                              scheduler=scheduler,
                              schd_batch_update=False)

        # Validation runs with autograd disabled.
        with torch.no_grad():
            utils.valid_one_epoch(epoch,
                                  model,
                                  loss_fn,
                                  val_loader,
                                  device)

        # One checkpoint per (fold, epoch): <arch>_fold_<fold>_<epoch>.
        torch.save(
            model.state_dict(),
            os.path.join(ckpt_dir,
                         '{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch)))

    # Drop the per-fold objects and release cached GPU memory before the next fold.
    del model, optimizer, train_loader, val_loader, scaler, scheduler
    torch.cuda.empty_cache()


Training with 0 started
Train : 17117, Val : 4280


epoch 0 loss: 0.4390: 100%|██████████| 1070/1070 [17:56<00:00,  1.01s/it]
epoch 0 loss: 0.4059: 100%|██████████| 268/268 [01:55<00:00,  2.32it/s]


validation multi-class accuracy = 0.8561


epoch 1 loss: 0.3794: 100%|██████████| 1070/1070 [17:59<00:00,  1.01s/it]
epoch 1 loss: 0.3521: 100%|██████████| 268/268 [01:50<00:00,  2.42it/s]


validation multi-class accuracy = 0.8853


epoch 2 loss: 0.3587: 100%|██████████| 1070/1070 [18:09<00:00,  1.02s/it]
epoch 2 loss: 0.3729: 100%|██████████| 268/268 [01:49<00:00,  2.45it/s]


validation multi-class accuracy = 0.8682


epoch 3 loss: 0.3605: 100%|██████████| 1070/1070 [18:24<00:00,  1.03s/it]
epoch 3 loss: 0.3351: 100%|██████████| 268/268 [01:48<00:00,  2.47it/s]


validation multi-class accuracy = 0.8829


epoch 4 loss: 0.3025: 100%|██████████| 1070/1070 [17:54<00:00,  1.00s/it]
epoch 4 loss: 0.3435: 100%|██████████| 268/268 [01:46<00:00,  2.51it/s]


validation multi-class accuracy = 0.8879


epoch 5 loss: 0.3029: 100%|██████████| 1070/1070 [18:06<00:00,  1.02s/it]
epoch 5 loss: 0.3391: 100%|██████████| 268/268 [01:46<00:00,  2.51it/s]


validation multi-class accuracy = 0.8850


epoch 6 loss: 0.2642: 100%|██████████| 1070/1070 [18:01<00:00,  1.01s/it]
epoch 6 loss: 0.3432: 100%|██████████| 268/268 [01:47<00:00,  2.50it/s]


validation multi-class accuracy = 0.8921


epoch 7 loss: 0.2498: 100%|██████████| 1070/1070 [17:51<00:00,  1.00s/it]
epoch 7 loss: 0.3565: 100%|██████████| 268/268 [01:45<00:00,  2.54it/s]


validation multi-class accuracy = 0.8897


epoch 8 loss: 0.2315: 100%|██████████| 1070/1070 [18:02<00:00,  1.01s/it]
epoch 8 loss: 0.3732: 100%|██████████| 268/268 [01:46<00:00,  2.51it/s]


validation multi-class accuracy = 0.8818


epoch 9 loss: 0.2257: 100%|██████████| 1070/1070 [17:59<00:00,  1.01s/it]
epoch 9 loss: 0.3781: 100%|██████████| 268/268 [01:45<00:00,  2.55it/s]


validation multi-class accuracy = 0.8808
Training with 1 started
Train : 17117, Val : 4280


epoch 0 loss: 0.4551: 100%|██████████| 1070/1070 [17:18<00:00,  1.03it/s]
epoch 0 loss: 0.3728: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8752


epoch 1 loss: 0.4157: 100%|██████████| 1070/1070 [17:03<00:00,  1.05it/s]
epoch 1 loss: 0.3367: 100%|██████████| 268/268 [01:43<00:00,  2.60it/s]


validation multi-class accuracy = 0.8862


epoch 2 loss: 0.3590: 100%|██████████| 1070/1070 [17:02<00:00,  1.05it/s]
epoch 2 loss: 0.3252: 100%|██████████| 268/268 [01:43<00:00,  2.59it/s]


validation multi-class accuracy = 0.8890


epoch 3 loss: 0.3267: 100%|██████████| 1070/1070 [17:11<00:00,  1.04it/s]
epoch 3 loss: 0.3188: 100%|██████████| 268/268 [01:42<00:00,  2.61it/s]


validation multi-class accuracy = 0.8886


epoch 4 loss: 0.3340: 100%|██████████| 1070/1070 [17:03<00:00,  1.05it/s]
epoch 4 loss: 0.3184: 100%|██████████| 268/268 [01:43<00:00,  2.60it/s]


validation multi-class accuracy = 0.8886


epoch 5 loss: 0.3111: 100%|██████████| 1070/1070 [17:04<00:00,  1.04it/s]
epoch 5 loss: 0.3243: 100%|██████████| 268/268 [01:44<00:00,  2.57it/s]


validation multi-class accuracy = 0.8890


epoch 6 loss: 0.2966: 100%|██████████| 1070/1070 [17:17<00:00,  1.03it/s]
epoch 6 loss: 0.3147: 100%|██████████| 268/268 [01:43<00:00,  2.59it/s]


validation multi-class accuracy = 0.8923


epoch 7 loss: 0.2599: 100%|██████████| 1070/1070 [17:11<00:00,  1.04it/s]
epoch 7 loss: 0.3350: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8914


epoch 8 loss: 0.2454: 100%|██████████| 1070/1070 [17:08<00:00,  1.04it/s]
epoch 8 loss: 0.3386: 100%|██████████| 268/268 [01:43<00:00,  2.59it/s]


validation multi-class accuracy = 0.8904


epoch 9 loss: 0.2227: 100%|██████████| 1070/1070 [17:03<00:00,  1.05it/s]
epoch 9 loss: 0.3382: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8932
Training with 2 started
Train : 17118, Val : 4279


epoch 0 loss: 0.4573: 100%|██████████| 1070/1070 [17:13<00:00,  1.04it/s]
epoch 0 loss: 0.3687: 100%|██████████| 268/268 [01:44<00:00,  2.57it/s]


validation multi-class accuracy = 0.8803


epoch 1 loss: 0.4013: 100%|██████████| 1070/1070 [17:10<00:00,  1.04it/s]
epoch 1 loss: 0.3642: 100%|██████████| 268/268 [01:44<00:00,  2.58it/s]


validation multi-class accuracy = 0.8738


epoch 2 loss: 0.3472: 100%|██████████| 1070/1070 [17:08<00:00,  1.04it/s]
epoch 2 loss: 0.3231: 100%|██████████| 268/268 [01:44<00:00,  2.57it/s]


validation multi-class accuracy = 0.8881


epoch 3 loss: 0.3261: 100%|██████████| 1070/1070 [17:09<00:00,  1.04it/s]
epoch 3 loss: 0.3140: 100%|██████████| 268/268 [01:44<00:00,  2.57it/s]


validation multi-class accuracy = 0.8979


epoch 4 loss: 0.3277: 100%|██████████| 1070/1070 [16:54<00:00,  1.05it/s]
epoch 4 loss: 0.3201: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8925


epoch 5 loss: 0.3012: 100%|██████████| 1070/1070 [17:01<00:00,  1.05it/s]
epoch 5 loss: 0.3571: 100%|██████████| 268/268 [01:43<00:00,  2.59it/s]


validation multi-class accuracy = 0.8778


epoch 6 loss: 0.2895: 100%|██████████| 1070/1070 [17:04<00:00,  1.04it/s]
epoch 6 loss: 0.3282: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8930


epoch 7 loss: 0.2331: 100%|██████████| 1070/1070 [16:49<00:00,  1.06it/s]
epoch 7 loss: 0.3376: 100%|██████████| 268/268 [01:43<00:00,  2.60it/s]


validation multi-class accuracy = 0.8885


epoch 8 loss: 0.2337: 100%|██████████| 1070/1070 [17:09<00:00,  1.04it/s]
epoch 8 loss: 0.3541: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8839


epoch 9 loss: 0.2123: 100%|██████████| 1070/1070 [17:07<00:00,  1.04it/s]
epoch 9 loss: 0.3444: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8878
Training with 3 started
Train : 17118, Val : 4279


epoch 0 loss: 0.4265: 100%|██████████| 1070/1070 [17:02<00:00,  1.05it/s]
epoch 0 loss: 0.3568: 100%|██████████| 268/268 [01:44<00:00,  2.57it/s]


validation multi-class accuracy = 0.8745


epoch 1 loss: 0.4127: 100%|██████████| 1070/1070 [17:10<00:00,  1.04it/s]
epoch 1 loss: 0.3516: 100%|██████████| 268/268 [01:43<00:00,  2.58it/s]


validation multi-class accuracy = 0.8832


epoch 2 loss: 0.3600: 100%|██████████| 1070/1070 [17:02<00:00,  1.05it/s]
epoch 2 loss: 0.3377: 100%|██████████| 268/268 [01:43<00:00,  2.59it/s]


validation multi-class accuracy = 0.8839


epoch 3 loss: 0.3519:  84%|████████▎ | 895/1070 [14:23<02:47,  1.04it/s]