In [2]:
# import dependent libraries
import os
from imp import reload

import cv2
import numpy as np
import pandas as pd
import timm
import torch
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout,
    ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedKFold
from torch import nn
from torch.cuda.amp import GradScaler
from torch.utils.data import Dataset,DataLoader
from tqdm import tqdm

from utils import utils
# Re-import the project-local `utils` module so that edits made to it on disk
# are picked up without restarting the kernel.
reload(utils)
# Single seed shared by utils.seed_everything below and by the StratifiedKFold
# shuffle in the data-building cell.
rand_seed = 666
# NOTE(review): seed_everything lives in the project's utils module (not shown
# here); presumably it seeds the Python / NumPy / PyTorch RNGs — confirm.
utils.seed_everything(rand_seed)

In [1]:
!pip install timm

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting timm
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/72/ed/358a8bc5685c31c0fe7765351b202cf6a8c087893b5d2d64f63c950f8beb/timm-0.6.7-py3-none-any.whl (509 kB)
[K     |████████████████████████████████| 509 kB 67.3 MB/s eta 0:00:01
Installing collected packages: timm
Successfully installed timm-0.6.7


In [3]:
# Where the training images and their label CSV live on this machine.
train_img_path, train_csv_path = (
    '/home/featurize/data/train_images',
    '/home/featurize/data/train.csv',
)

In [4]:
# Echo the label-file location as a quick sanity check of the configured path.
print('{}'.format(train_csv_path))

/home/featurize/data/train.csv


In [5]:
# Training set data augmentation
def get_train_transforms():
    """Build the albumentations pipeline applied to training images.

    The target edge length is read from ``CFG['img_size']`` when the function
    is called, so ``CFG`` must exist by then. The steps run in the order
    listed: random resized crop, geometric flips/transposes, colour jitter,
    normalisation, two dropout-style occlusions, and finally conversion to a
    tensor.

    Returns:
        A ``Compose`` object that is always applied (``p=1.``).
    """
    side = CFG['img_size']
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        # --- geometry ---
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        # --- colour ---
        HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
        # --- normalise, occlude, convert ---
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

# Validation set data augmentation
def get_valid_transforms():
    """Build the deterministic albumentations pipeline for validation images.

    Uses ``CFG['img_size']`` (read at call time) for both the centre crop and
    the subsequent resize, then normalises and converts to a tensor.

    Returns:
        A ``Compose`` object that is always applied (``p=1.``).
    """
    side = CFG['img_size']
    pipeline = [
        CenterCrop(side, side, p=1.),
        Resize(side, side),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(p=1.0),
    ]
    return Compose(pipeline, p=1.)

In [6]:
# model building
class CassvaImgClassifier(nn.Module):
    """Image classifier: a timm backbone with its ``classifier`` head replaced.

    Args:
        model_arch: Architecture name handed to ``timm.create_model``.
        n_class: Number of output units of the new linear head.
        pretrained: Forwarded to ``timm.create_model``.

    The backbone must expose a ``classifier`` attribute with ``in_features``;
    that layer is swapped for a fresh ``nn.Linear(in_features, n_class)``.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Swap the stock head for one sized to this task.
        head_inputs = backbone.classifier.in_features
        backbone.classifier = nn.Linear(head_inputs, n_class)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)

In [7]:
# Fetch the dataset with the given ID via the `featurize` CLI.
# NOTE(review): presumably this is what populates /home/featurize/data
# (train_images/ and train.csv) used by the cells below — confirm.
!featurize dataset download 17bd6643-4e22-423b-95c7-3f82601931bb

100%|██████████████████████████████████████| 6.19G/6.19G [00:21<00:00, 285MiB/s]
🍬  下载完成，正在解压...
🏁  数据集已经成功添加


In [11]:
# build data
CFG = {
    'img_size': 512,                      # edge length used by the train/valid transforms
    'epochs': 10,                         # epochs trained per fold
    'fold_num': 5,                        # number of StratifiedKFold splits
    'device': 'cuda',                     # passed to torch.device in the training cell
    'model_arch': 'tf_efficientnet_b4_ns',  # timm architecture name
    'train_bs': 16,                       # batch size handed to utils.prepare_dataloader
    'valid_bs': 16,                       # NOTE(review): not read by any visible cell
    'num_workers': 0,                     # forwarded to prepare_dataloader as n_job
    'lr': 1e-4,                           # Adam learning rate
    'weight_decay': 1e-6,                 # Adam weight decay
    'T_0': 10,                            # CosineAnnealingWarmRestarts T_0
    'min_lr': 1e-6,                       # CosineAnnealingWarmRestarts eta_min
}

train = pd.read_csv(train_csv_path)

# StratifiedKFold.split() returns a one-shot generator. Materialise it as a
# list so that `folds` can be iterated by more than one cell (the sanity-check
# loop and the training loop); otherwise whichever cell runs first exhausts it
# and the other silently sees zero folds.
folds = list(
    StratifiedKFold(n_splits=CFG['fold_num'],
                    shuffle=True,
                    random_state=rand_seed).split(
                        np.arange(train.shape[0]), train.label.values)
)

# Built once here (after CFG exists, since both read CFG['img_size']) and
# reused for every fold.
trn_transform = get_train_transforms()
val_transform = get_valid_transforms()



In [9]:
# Sanity check: print the index of every fold.
# A StratifiedKFold.split() generator can be consumed only once, so turn
# `folds` into a list first (harmless if it already is one). Without this,
# iterating here would use up the splits and the training cell that loops over
# `folds` afterwards would run zero folds.
folds = list(folds)
for fold, (trn_idx, val_idx) in enumerate(folds):
    print(fold)

0
1
2
3
4


In [13]:
fold_num = 0  # NOTE(review): not read anywhere in this cell; kept in case a later cell uses it

# One checkpoint is written per (fold, epoch). Create the target directory up
# front so torch.save() cannot fail on a missing directory after an epoch of
# training has already been spent.
ckpt_dir = '/home/featurize/work/model'
os.makedirs(ckpt_dir, exist_ok=True)

# Loop-invariant objects: the target device and the two loss modules
# (one handed to the training step, one to the validation step).
device = torch.device(CFG['device'])
loss_tr = nn.CrossEntropyLoss().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)

for fold, (trn_idx, val_idx) in enumerate(folds):
    print('Training with {} started'.format(fold))
    print('Train : {}, Val : {}'.format(len(trn_idx), len(val_idx)))

    # NOTE(review): only CFG['train_bs'] is passed; confirm in utils whether
    # prepare_dataloader uses `bs` for the validation loader as well.
    train_loader, val_loader = utils.prepare_dataloader(train,
                                                        trn_idx,
                                                        val_idx,
                                                        data_root=train_img_path,
                                                        trn_transform=trn_transform,
                                                        val_transform=val_transform,
                                                        bs=CFG['train_bs'],
                                                        n_job=CFG['num_workers'])

    # Every fold starts from a fresh pretrained model with its own optimizer,
    # AMP gradient scaler and LR schedule.
    model = CassvaImgClassifier(CFG['model_arch'],
                                train.label.nunique(),
                                pretrained=True).to(device)
    scaler = GradScaler()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=CFG['lr'],
                                 weight_decay=CFG['weight_decay'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=CFG['T_0'],
        T_mult=1,
        eta_min=CFG['min_lr'],
        last_epoch=-1)

    for epoch in range(CFG['epochs']):
        # NOTE(review): schd_batch_update=False is forwarded to the helper;
        # presumably the scheduler is then stepped per epoch rather than per
        # batch — confirm in utils.train_one_epoch.
        utils.train_one_epoch(epoch,
                              model,
                              loss_tr,
                              optimizer,
                              train_loader,
                              device,
                              scaler,
                              scheduler=scheduler,
                              schd_batch_update=False)

        # No gradients are needed while scoring the held-out fold.
        with torch.no_grad():
            utils.valid_one_epoch(epoch,
                                  model,
                                  loss_fn,
                                  val_loader,
                                  device)

        # Checkpoint name: <model_arch>_fold_<fold>_<epoch>
        torch.save(
            model.state_dict(),
            os.path.join(ckpt_dir,
                         '{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch)))

    # Drop this fold's objects and release cached GPU memory before the next
    # fold builds a new model.
    del model, optimizer, train_loader, val_loader, scaler, scheduler
    torch.cuda.empty_cache()


Training with 0 started
Train : 17117, Val : 4280


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth" to /home/featurize/.cache/torch/hub/checkpoints/tf_efficientnet_b4_ns-d6313a46.pth
epoch 0 loss: 0.4952: 100%|██████████| 1070/1070 [14:09<00:00,  1.26it/s]
epoch 0 loss: 0.3786: 100%|██████████| 268/268 [01:35<00:00,  2.81it/s]


validation multi-class accuracy = 0.8680


epoch 1 loss: 0.4104: 100%|██████████| 1070/1070 [14:09<00:00,  1.26it/s]
epoch 1 loss: 0.3321: 100%|██████████| 268/268 [01:34<00:00,  2.85it/s]


validation multi-class accuracy = 0.8871


epoch 2 loss: 0.3973: 100%|██████████| 1070/1070 [14:02<00:00,  1.27it/s]
epoch 2 loss: 0.3360: 100%|██████████| 268/268 [01:33<00:00,  2.86it/s]


validation multi-class accuracy = 0.8890


epoch 3 loss: 0.3529: 100%|██████████| 1070/1070 [13:51<00:00,  1.29it/s]
epoch 3 loss: 0.3343: 100%|██████████| 268/268 [01:34<00:00,  2.85it/s]


validation multi-class accuracy = 0.8864


epoch 4 loss: 0.3277: 100%|██████████| 1070/1070 [13:57<00:00,  1.28it/s]
epoch 4 loss: 0.3193: 100%|██████████| 268/268 [01:33<00:00,  2.87it/s]


validation multi-class accuracy = 0.8930


epoch 5 loss: 0.3154: 100%|██████████| 1070/1070 [14:00<00:00,  1.27it/s]
epoch 5 loss: 0.3249: 100%|██████████| 268/268 [01:34<00:00,  2.85it/s]


validation multi-class accuracy = 0.8867


epoch 6 loss: 0.2908: 100%|██████████| 1070/1070 [13:56<00:00,  1.28it/s]
epoch 6 loss: 0.3241: 100%|██████████| 268/268 [01:33<00:00,  2.88it/s]


validation multi-class accuracy = 0.8900


epoch 7 loss: 0.2794: 100%|██████████| 1070/1070 [13:56<00:00,  1.28it/s]
epoch 7 loss: 0.3294: 100%|██████████| 268/268 [01:33<00:00,  2.85it/s]


validation multi-class accuracy = 0.8857


epoch 8 loss: 0.2775: 100%|██████████| 1070/1070 [13:49<00:00,  1.29it/s]
epoch 8 loss: 0.3351: 100%|██████████| 268/268 [01:33<00:00,  2.86it/s]


validation multi-class accuracy = 0.8857


epoch 9 loss: 0.2682: 100%|██████████| 1070/1070 [13:53<00:00,  1.28it/s]
epoch 9 loss: 0.3356: 100%|██████████| 268/268 [01:34<00:00,  2.84it/s]


validation multi-class accuracy = 0.8848
Training with 1 started
Train : 17118, Val : 4279


epoch 0 loss: 0.4601: 100%|██████████| 1070/1070 [13:23<00:00,  1.33it/s]
epoch 0 loss: 0.3737: 100%|██████████| 268/268 [01:28<00:00,  3.02it/s]


validation multi-class accuracy = 0.8789


epoch 1 loss: 0.3999: 100%|██████████| 1070/1070 [13:18<00:00,  1.34it/s]
epoch 1 loss: 0.3680: 100%|██████████| 268/268 [01:26<00:00,  3.09it/s]


validation multi-class accuracy = 0.8768


epoch 2 loss: 0.3840: 100%|██████████| 1070/1070 [13:13<00:00,  1.35it/s]
epoch 2 loss: 0.3299: 100%|██████████| 268/268 [01:27<00:00,  3.05it/s]


validation multi-class accuracy = 0.8883


epoch 3 loss: 0.3590: 100%|██████████| 1070/1070 [13:26<00:00,  1.33it/s]
epoch 3 loss: 0.3293: 100%|██████████| 268/268 [01:26<00:00,  3.08it/s]


validation multi-class accuracy = 0.8934


epoch 4 loss: 0.3467: 100%|██████████| 1070/1070 [13:20<00:00,  1.34it/s]
epoch 4 loss: 0.3339: 100%|██████████| 268/268 [01:27<00:00,  3.07it/s]


validation multi-class accuracy = 0.8913


epoch 5 loss: 0.2928: 100%|██████████| 1070/1070 [13:13<00:00,  1.35it/s]
epoch 5 loss: 0.3197: 100%|██████████| 268/268 [01:27<00:00,  3.07it/s]


validation multi-class accuracy = 0.8958


epoch 6 loss: 0.3013: 100%|██████████| 1070/1070 [13:18<00:00,  1.34it/s]
epoch 6 loss: 0.3241: 100%|██████████| 268/268 [01:31<00:00,  2.93it/s]


validation multi-class accuracy = 0.8939


epoch 7 loss: 0.2802: 100%|██████████| 1070/1070 [13:11<00:00,  1.35it/s]
epoch 7 loss: 0.3323: 100%|██████████| 268/268 [01:26<00:00,  3.08it/s]


validation multi-class accuracy = 0.8946


epoch 8 loss: 0.3290: 100%|██████████| 268/268 [01:27<00:00,  3.06it/s]]


validation multi-class accuracy = 0.8960


epoch 9 loss: 0.2651: 100%|██████████| 1070/1070 [13:22<00:00,  1.33it/s]
epoch 9 loss: 0.3362: 100%|██████████| 268/268 [01:27<00:00,  3.06it/s]


validation multi-class accuracy = 0.8932
Training with 2 started
Train : 17118, Val : 4279


epoch 0 loss: 0.4903: 100%|██████████| 1070/1070 [13:58<00:00,  1.28it/s]
epoch 0 loss: 0.3718: 100%|██████████| 268/268 [01:33<00:00,  2.87it/s]


validation multi-class accuracy = 0.8687


epoch 1 loss: 0.3923: 100%|██████████| 1070/1070 [13:49<00:00,  1.29it/s]
epoch 1 loss: 0.3565: 100%|██████████| 268/268 [01:35<00:00,  2.82it/s]


validation multi-class accuracy = 0.8759


epoch 2 loss: 0.3549: 100%|██████████| 1070/1070 [14:06<00:00,  1.26it/s]
epoch 2 loss: 0.3313: 100%|██████████| 268/268 [01:36<00:00,  2.78it/s]


validation multi-class accuracy = 0.8850


epoch 3 loss: 0.3668: 100%|██████████| 1070/1070 [13:46<00:00,  1.30it/s]
epoch 3 loss: 0.3349: 100%|██████████| 268/268 [01:35<00:00,  2.82it/s]


validation multi-class accuracy = 0.8832


epoch 4 loss: 0.3378: 100%|██████████| 1070/1070 [14:08<00:00,  1.26it/s]
epoch 4 loss: 0.3266: 100%|██████████| 268/268 [01:35<00:00,  2.82it/s]


validation multi-class accuracy = 0.8927


epoch 5 loss: 0.3292: 100%|██████████| 1070/1070 [13:52<00:00,  1.29it/s]
epoch 5 loss: 0.3281: 100%|██████████| 268/268 [01:34<00:00,  2.85it/s]


validation multi-class accuracy = 0.8846


epoch 6 loss: 0.2935: 100%|██████████| 1070/1070 [14:06<00:00,  1.26it/s]
epoch 6 loss: 0.3248: 100%|██████████| 268/268 [01:35<00:00,  2.82it/s]


validation multi-class accuracy = 0.8862


epoch 7 loss: 0.2693: 100%|██████████| 1070/1070 [14:10<00:00,  1.26it/s]
epoch 7 loss: 0.3296: 100%|██████████| 268/268 [01:33<00:00,  2.86it/s]


validation multi-class accuracy = 0.8892


epoch 8 loss: 0.2628: 100%|██████████| 1070/1070 [13:52<00:00,  1.28it/s]
epoch 8 loss: 0.3310: 100%|██████████| 268/268 [01:34<00:00,  2.83it/s]


validation multi-class accuracy = 0.8890


epoch 9 loss: 0.2667: 100%|██████████| 1070/1070 [13:54<00:00,  1.28it/s]
epoch 9 loss: 0.3291: 100%|██████████| 268/268 [01:33<00:00,  2.86it/s]


validation multi-class accuracy = 0.8878
Training with 3 started
Train : 17118, Val : 4279


epoch 0 loss: 0.4726: 100%|██████████| 1070/1070 [13:49<00:00,  1.29it/s]
epoch 0 loss: 0.4085: 100%|██████████| 268/268 [01:34<00:00,  2.84it/s]


validation multi-class accuracy = 0.8645


epoch 1 loss: 0.4116: 100%|██████████| 1070/1070 [13:47<00:00,  1.29it/s]
epoch 1 loss: 0.3637: 100%|██████████| 268/268 [01:33<00:00,  2.87it/s]


validation multi-class accuracy = 0.8787


epoch 2 loss: 0.3842: 100%|██████████| 1070/1070 [13:58<00:00,  1.28it/s]
epoch 2 loss: 0.3593: 100%|██████████| 268/268 [01:33<00:00,  2.87it/s]


validation multi-class accuracy = 0.8822


epoch 3 loss: 0.3573: 100%|██████████| 1070/1070 [13:52<00:00,  1.29it/s]
epoch 3 loss: 0.3573: 100%|██████████| 268/268 [01:35<00:00,  2.81it/s]


validation multi-class accuracy = 0.8822


epoch 4 loss: 0.3460: 100%|██████████| 1070/1070 [13:53<00:00,  1.28it/s]
epoch 4 loss: 0.3608: 100%|██████████| 268/268 [01:34<00:00,  2.82it/s]


validation multi-class accuracy = 0.8834


epoch 5 loss: 0.3103: 100%|██████████| 1070/1070 [13:59<00:00,  1.27it/s]
epoch 5 loss: 0.3394: 100%|██████████| 268/268 [01:34<00:00,  2.84it/s]


validation multi-class accuracy = 0.8930


epoch 6 loss: 0.3013: 100%|██████████| 1070/1070 [13:59<00:00,  1.27it/s]
epoch 6 loss: 0.3528: 100%|██████████| 268/268 [01:34<00:00,  2.83it/s]


validation multi-class accuracy = 0.8890


epoch 7 loss: 0.2578: 100%|██████████| 1070/1070 [13:56<00:00,  1.28it/s]
epoch 7 loss: 0.3587: 100%|██████████| 268/268 [01:33<00:00,  2.87it/s]


validation multi-class accuracy = 0.8846


epoch 8 loss: 0.2732: 100%|██████████| 1070/1070 [13:56<00:00,  1.28it/s]
epoch 8 loss: 0.3577: 100%|██████████| 268/268 [01:35<00:00,  2.80it/s]


validation multi-class accuracy = 0.8841


epoch 9 loss: 0.2688: 100%|██████████| 1070/1070 [13:56<00:00,  1.28it/s]
epoch 9 loss: 0.3611: 100%|██████████| 268/268 [01:35<00:00,  2.80it/s]


validation multi-class accuracy = 0.8806
