In [1]:
import os, random
import numpy as np
import pandas as pd
import bloscpack as bp
from sklearn.model_selection import StratifiedKFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

from sklearn.metrics import recall_score

import imgaug as ia
import imgaug.augmenters as iaa

import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils import clip_grad_value_

from optim import Over9000

from data import Bengaliai_DS
from models_mg import Simple50GeM
from mixup_pytorch_utils import mixup, mixup_loss
from loss import CenterLoss, AngularPenaltySMLoss
import utils

import cv2
cv2.setNumThreads(1)

---

In [2]:
SEED = 19841202

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Parameters
    ----------
    seed : int
        Value used to seed every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the hash seed of the already-running kernel cannot be changed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not just the current one; this is a
    # no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different kernels from run
    # to run, which defeats the deterministic flag above, so keep it off.
    torch.backends.cudnn.benchmark = False

seed_everything(SEED)

---
### data

#### augmentation

In [3]:
def _build_augmenter():
    """Assemble the training-time imgaug pipeline.

    The outer ``SomeOf((0, 2), ...)`` chooses, in random order, from a
    geometric-distortion group and a directed edge-detection filter. The
    geometric group itself applies one or two of: a rigid warp (affine *or*
    perspective) and a piecewise-affine warp.
    """
    affine = iaa.Affine(
        scale={"x": (0.8, 1.), "y": (0.8, 1.)},
        rotate=(-15, 15),
        shear=(-15, 15),
    )
    perspective = iaa.PerspectiveTransform(scale=.08, keep_size=True)
    rigid_warp = iaa.OneOf([affine, perspective])

    elastic_warp = iaa.PiecewiseAffine(scale=.04)
    geometric = iaa.SomeOf((1, 2), [rigid_warp, elastic_warp], random_order=True)

    edge_filter = iaa.DirectedEdgeDetect(alpha=(.6, .8), direction=(0.0, 1.0))

    return iaa.SomeOf((0, 2), [geometric, edge_filter], random_order=True)


augs = _build_augmenter()

In [4]:
# Label table for the training images.
pdf = pd.read_csv('../input/train.csv')

# Give every distinct grapheme string an integer id (order of first appearance).
unique_grapheme = pdf['grapheme'].unique()
grapheme_code = dict(zip(unique_grapheme, np.arange(unique_grapheme.shape[0])))
pdf['grapheme_code'] = pdf['grapheme'].map(grapheme_code)

# 7-fold split stratified on the full grapheme; only the first fold is used.
# Reuse the notebook-wide SEED instead of repeating the literal.
skf = StratifiedKFold(n_splits=7, shuffle=True, random_state=SEED)
trn_ndx, vld_ndx = next(iter(skf.split(pdf['grapheme_code'], pdf['grapheme_code'])))

# Build a fresh, re-indexed frame rather than mutating a slice of `pdf` in place.
trn_pdf = pdf.iloc[trn_ndx].reset_index(drop=True)

imgs = bp.unpack_ndarray_from_file('../features/train_images_size128_pad3.bloscpack')
lbls = pdf[['grapheme_root', 'vowel_diacritic', 'consonant_diacritic', 'grapheme_code']].values

# The fold indices are positional, so images and labels must line up row for row.
assert len(imgs) == len(lbls), 'image array and train.csv have different lengths'

trn_imgs, trn_lbls = imgs[trn_ndx], lbls[trn_ndx]
vld_imgs, vld_lbls = imgs[vld_ndx], lbls[vld_ndx]

In [5]:
def _seed_worker(worker_id):
    """Give each DataLoader worker its own Python / NumPy / imgaug RNG state.

    Forked workers start from a copy of the parent's RNG state, so without
    this every worker (and every epoch) would replay the same sequence of
    random augmentation parameters. ``torch.initial_seed()`` inside a worker
    is already distinct per worker and per epoch, so it is used as the source.

    NOTE(review): with a ``spawn``/``forkserver`` start method a function
    defined in the notebook may not be picklable; move it into a module if so.
    """
    worker_seed = torch.initial_seed() % 2**32
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    ia.seed(worker_seed)


# Augmentation is applied to the training fold only.
training_set = Bengaliai_DS(trn_imgs, trn_lbls, transform=augs)
validation_set = Bengaliai_DS(vld_imgs, vld_lbls)

batch_size = 64

training_loader = DataLoader(
    training_set,
    batch_size=batch_size,
    num_workers=6,
    shuffle=True,
    worker_init_fn=_seed_worker,
) # , sampler=sampler
validation_loader = DataLoader(validation_set, batch_size=batch_size, num_workers=6, shuffle=False)

---
### model

In [6]:
# Class counts for the three targets.
# NOTE(review): presumably grapheme root / vowel / consonant diacritic; these
# names are not referenced elsewhere in the visible notebook — confirm usage.
N_GRAPHEME, N_VOWEL, N_CONSONANT = 168, 11, 7

# Upper bound on the number of passes over the training loader.
N_EPOCHS = 160

# Multiplier applied to the feature (angular-penalty) loss in the total loss.
feat_loss_weight = 0.1

# Checkpoint file-name template; the placeholder receives the epoch index.
checkpoint_name = 'seresnext50_purepytorch_cutmix_aug_epoch{:d}.pth'

In [7]:
# Network under training; the move to GPU is left disabled.
classifier = Simple50GeM(output_features=True)#.cuda()

# Auxiliary loss head applied to the feature vector returned by the classifier.
# NOTE(review): 1295 presumably equals the number of distinct graphemes
# (len(grapheme_code)) and 2048 the backbone feature width — confirm.
feat_loser = AngularPenaltySMLoss(out_features=1295, in_features=2048)

# One optimiser per module, both with the same learning rate.
optimizer_classifier, optimizer_featurelsr = (
    Over9000(module.parameters(), lr=0.01) for module in (classifier, feat_loser)
)

In [8]:
# Names of the per-epoch quantities handed to the CSV logger.
logger = utils.csv_logger(
    [
        'training_loss',
        'validation_loss',
        'GRAPHEME_Recall',
        'VOWEL_Recall',
        'CONSONANT_Recall',
        'Final_Recall',
    ]
)

In [9]:
# Make sure the checkpoint directory exists before the first save.
os.makedirs('./outputs/', exist_ok=True)

for i in range(N_EPOCHS):
    logger.new_epoch()

    # ----------------------------------
    # train
    classifier.train()

    epoch_trn_loss = []
    epoch_vld_loss = []
    # Targets / predictions for the whole validation fold (columns: g, v, c).
    epoch_vld_true, epoch_vld_pred = [], []

    for j, (trn_imgs_batch, trn_lbls_batch) in enumerate(training_loader):
        # move to device (disabled: everything stays on the default device)
        trn_imgs_batch_device = trn_imgs_batch#.to(DEVICE)
        trn_lbls_batch_device = trn_lbls_batch#.to(DEVICE)

        # mixup: blended images plus the labels of the shuffled partner batch
        trn_imgs_batch_device_mixup, trn_lbls_batch_device_shfl, gamma = mixup(trn_imgs_batch_device, trn_lbls_batch_device, .8)

        # forward pass
        logits_g, logits_v, logits_c, feats = classifier(trn_imgs_batch_device_mixup)

        # label columns: 0 = grapheme_root, 1 = vowel, 2 = consonant, 3 = grapheme_code
        loss_g = mixup_loss(logits_g, trn_lbls_batch_device[:, 0], trn_lbls_batch_device_shfl[:, 0], gamma).mean()
        loss_v = mixup_loss(logits_v, trn_lbls_batch_device[:, 1], trn_lbls_batch_device_shfl[:, 1], gamma).mean()
        loss_c = mixup_loss(logits_c, trn_lbls_batch_device[:, 2], trn_lbls_batch_device_shfl[:, 2], gamma).mean()
        # feature loss is mixed between both label sets with the same gamma
        loss_feat = (
            gamma*feat_loser(feats, trn_lbls_batch_device[:, 3])
            + (1-gamma)*feat_loser(feats, trn_lbls_batch_device_shfl[:, 3])
        ).mean()

        total_loss = .5*loss_g + .25*loss_v + .25*loss_c + feat_loss_weight*loss_feat

        optimizer_classifier.zero_grad()
        optimizer_featurelsr.zero_grad()

        total_loss.backward()
        clip_grad_value_(classifier.parameters(), 1.0)
        clip_grad_value_(feat_loser.parameters(), 1.0)

        optimizer_classifier.step()
        optimizer_featurelsr.step()

        # record
        epoch_trn_loss.append(total_loss.item())

        utils.display_progress(len(training_loader), j+1, {'training_loss': epoch_trn_loss[-1]})

    # ----------------------------------
    # validation
    classifier.eval()

    with torch.no_grad():
        for k, (vld_imgs_batch, vld_lbls_batch) in enumerate(validation_loader):

            # move to device (disabled); labels stay tensors because
            # F.cross_entropy needs tensor targets
            vld_imgs_batch_device = vld_imgs_batch#.cuda()
            vld_lbls_batch_device = vld_lbls_batch#.cuda()

            # forward pass on the images (not the labels)
            logits_g, logits_v, logits_c, feats = classifier(vld_imgs_batch_device)

            # loss
            loss_g = F.cross_entropy(logits_g, vld_lbls_batch_device[:, 0])
            loss_v = F.cross_entropy(logits_v, vld_lbls_batch_device[:, 1])
            loss_c = F.cross_entropy(logits_c, vld_lbls_batch_device[:, 2])
            # reduce exactly as in training so total_loss is a scalar
            loss_feat = feat_loser(feats, vld_lbls_batch_device[:, 3]).mean()

            total_loss = .5*loss_g + .25*loss_v + .25*loss_c + feat_loss_weight*loss_feat
            # record
            epoch_vld_loss.append(total_loss.item())

            # keep predictions (one head per column) and targets for the
            # fold-level metric computed after the loop
            batch_pred = torch.stack(
                [logits_g.argmax(dim=1), logits_v.argmax(dim=1), logits_c.argmax(dim=1)],
                dim=1,
            )
            epoch_vld_pred.append(batch_pred.cpu().numpy())
            epoch_vld_true.append(vld_lbls_batch_device[:, :3].cpu().numpy())

            # display progress
            utils.display_progress(len(validation_loader), k+1, {'validation_loss': epoch_vld_loss[-1]})

    # Macro recall over the whole validation fold: per-batch macro recall is
    # distorted by classes missing from a 64-sample batch.
    # recall_score expects (y_true, y_pred) in that order.
    vld_true = np.concatenate(epoch_vld_true)
    vld_pred = np.concatenate(epoch_vld_pred)
    epoch_vld_recall_g = recall_score(vld_true[:, 0], vld_pred[:, 0], average='macro')
    epoch_vld_recall_v = recall_score(vld_true[:, 1], vld_pred[:, 1], average='macro')
    epoch_vld_recall_c = recall_score(vld_true[:, 2], vld_pred[:, 2], average='macro')

    entry = {
        'training_loss': np.mean(epoch_trn_loss),
        'validation_loss': np.mean(epoch_vld_loss),
        'GRAPHEME_Recall': epoch_vld_recall_g,
        'VOWEL_Recall': epoch_vld_recall_v,
        'CONSONANT_Recall': epoch_vld_recall_c,
        # grapheme root counts twice as much as each diacritic
        'Final_Recall': np.average([epoch_vld_recall_g, epoch_vld_recall_v, epoch_vld_recall_c], weights=[2, 1, 1]),
    }

    utils.display_progress(N_EPOCHS, i+1, entry, postfix='Epochs', persist=True)

    # ----------------------------------
    # save model when the validation loss beats the best logged value
    # (a NaN minimum, i.e. nothing logged yet, is treated as 100)
    if entry['validation_loss'] < np.nan_to_num(logger.log['validation_loss'].min(), nan=100.):
        print('Saving new best weight.')
        torch.save(
            {
                'epoch': i,
                'model': classifier.state_dict(),
                'optimizer': optimizer_classifier.state_dict(),
                # extra keys so the feature-loss head can be restored too
                'feat_loser': feat_loser.state_dict(),
                'optimizer_featurelsr': optimizer_featurelsr.state_dict(),
            },
            os.path.join('./outputs/', checkpoint_name.format(i)),
        )

    # ----------------------------------
    # log
    logger.enter(entry)


8/2690 batches: training_loss: 7.604

KeyboardInterrupt: 