In [1]:
import numpy as np
import pandas as pd
import bloscpack as bp
from sklearn.model_selection import StratifiedKFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

from sklearn.metrics import recall_score

import imgaug as ia
import imgaug.augmenters as iaa

from torch.utils.data.dataloader import DataLoader
from torch.nn.utils import clip_grad_value_

from optim import Over9000

from data import Bengaliai_DS, Balanced_Sampler
from model import *
from model_utils import *
from mixup_utils import *
import utils

import cv2
cv2.setNumThreads(1)

---

In [2]:
SEED = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

---
### data

#### stratification

In [3]:
pdf = pd.read_csv('../input/train.csv')

unique_grapheme = pdf['grapheme'].unique()
grapheme_code = dict([(g, c) for g, c in zip(unique_grapheme, np.arange(unique_grapheme.shape[0]))])
pdf['grapheme_code'] = [grapheme_code[g] for g in pdf['grapheme']]

# skf = MultilabelStratifiedKFold(n_splits=4, shuffle=True, random_state=42)
# for fold, (trn_ndx, vld_ndx) in enumerate(skf.split(pdf['image_id'].values.reshape(-1, 1), pdf.loc[:, ['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']].values.reshape(-1, 3))):
#     if fold == 0:
#         break

skf = StratifiedKFold(n_splits=8, shuffle=True, random_state=42)
for trn_ndx, vld_ndx in skf.split(pdf['grapheme_code'], pdf['grapheme_code']):
    break
    
trn_pdf = pdf.iloc[trn_ndx, :]
trn_pdf.reset_index(inplace=True, drop=True)

vld_pdf = pdf.iloc[vld_ndx, :]
vld_pdf.reset_index(inplace=True, drop=True)

#### augmentation

In [4]:
augs = iaa.SomeOf(
    (0, 2),
    [
        iaa.SomeOf(
            (1, 2),
            [
                iaa.OneOf(
                    [
                        iaa.Affine(scale={"x": (0.8, 1.), "y": (0.8, 1.)}, rotate=(-15, 15), shear=(-15, 15)),
                        iaa.PerspectiveTransform(scale=.08, keep_size=True),
                    ]
                ),
                iaa.PiecewiseAffine(scale=(0.02, 0.03)),
            ],
            random_order=True
        ),
        iaa.OneOf(
            [
                iaa.DirectedEdgeDetect(alpha=(.6, .8), direction=(0.0, 1.0)),
                iaa.Emboss(alpha=(.5, 1.), strength=(.1, 4)),
            ]
        ),
    ],
    random_order=True
)

#### dataset, dataloader with custom sampler

In [5]:
# sampler = Balanced_Sampler(trn_pdf, count_column='image_id', primary_group='grapheme_root', secondary_group=['vowel_diacritic', 'consonant_diacritic'], size=trn_imgs.shape[0])

training_set = Bengaliai_DS(pdf=trn_pdf, transform=augs, split_label=True)
validation_set = Bengaliai_DS(pdf=vld_pdf, split_label=True)

batch_size = 128

training_loader = DataLoader(training_set, batch_size=batch_size, num_workers=6, shuffle=True) # sampler=sampler  shuffle=True
validation_loader = DataLoader(validation_set, batch_size=batch_size, num_workers=6, shuffle=False)

---
### model

In [6]:
DEVICE = 'cuda:0' # 'cuda:0' 'cpu' 
N_GRAPHEME = 168
N_VOWEL = 11
N_CONSONANT = 7
N_TOTAL = N_GRAPHEME + N_VOWEL + N_CONSONANT

N_EPOCHS = 32

checkpoint_name = 'seresnext50_purepytorch_cutmix_aug_epoch{:d}.pth'

In [7]:
predictor = PretrainedCNN(out_dim=N_TOTAL)
classifier = BengaliClassifier(predictor)
classifier = classifier.to(DEVICE)

optimizer = Over9000(classifier.parameters())

In [8]:
logger = utils.csv_logger(['training_loss', 'validation_loss', 'GRAPHEME_Recall', 'VOWEL_Recall', 'CONSONANT_Recall', 'Final_Recall'])

In [9]:
for i in range(N_EPOCHS):
    logger.new_epoch()
    # train
    classifier.train()
    
    epoch_trn_loss = []
    epoch_vld_loss = []
    epoch_vld_metrics = []
    
    for j, (trn_imgs_batch, trn_lbls_batch) in enumerate(training_loader):
        
        optimizer.zero_grad()
        
        # move to device
        trn_imgs_batch_device = trn_imgs_batch.to(DEVICE)
        trn_lbls_batch_device = [l.to(DEVICE) for l in trn_lbls_batch]
        
        # one hot encoding label then perform mixup
        trn_lbls_onehot = [to_onehot(l, c) for l, c in zip(trn_lbls_batch_device, (N_GRAPHEME, N_VOWEL, N_CONSONANT))]
        with torch.no_grad():
            #trn_imgs_batch_mixup, trn_lbls_onehot_mixup = mixup(trn_imgs_batch_device, trn_lbls_onehot)
            trn_imgs_batch_mixup, trn_lbls_onehot_mixup = cutmix(trn_imgs_batch_device, *trn_lbls_onehot, alpha=.4)
        
        # forward pass
        logits = classifier(trn_imgs_batch_mixup)
        # probabilities = logit_to_probability(logits)
        
        # loss and gradient and optmize
        #loss = criterion(logits, trn_lbls_onehot_mixup)
        #total_loss = .7 * loss[0] + .1 * loss[1] + .2 * loss[2]
        #total_loss = loss[0] + loss[1] + loss[2]
        total_loss = cutmix_criterion(*logits, trn_lbls_onehot_mixup)
        total_loss.backward()
        clip_grad_value_(classifier.parameters(), 1.0)
        optimizer.step()
        
        # record
        epoch_trn_loss.append(total_loss.item())
        
        utils.display_progress(len(training_loader), j+1, {'training_loss': epoch_trn_loss[-1]})
    
    # validation
    classifier.eval()
    
    with torch.no_grad():
        for k, (vld_imgs_batch, vld_lbls_batch) in enumerate(validation_loader):
            
            # move to device
            vld_imgs_batch_device = vld_imgs_batch.to(DEVICE)
            vld_lbls_batch_device = [l.to(DEVICE) for l in vld_lbls_batch]
            
            # one hot encoding label then perform mixup
            #vld_lbls_onehot = [to_onehot(l, c) for l, c in zip(vld_lbls_batch_device, (N_GRAPHEME, N_VOWEL, N_CONSONANT))]
            #vld_imgs_batch_mixup, vld_lbls_onehot_mixup = mixup(vld_imgs_batch_device, vld_lbls_onehot)
            
            # forward pass
            #logits = classifier(vld_imgs_batch_mixup)
            logits = classifier(vld_imgs_batch_device)
            probabilities = logit_to_probability(logits)
            batch_vld_results = [p.argmax(axis=1) for p in probabilities]
            batch_vld_metric = [recall_score(l.detach().cpu().numpy(), p.detach().cpu().numpy(), average='macro') for l, p in zip(vld_lbls_batch, batch_vld_results)]
            epoch_vld_metrics.append(batch_vld_metric)
            
            # loss
            #loss = criterion(logits, vld_lbls_onehot_mixup)
            loss = [F.cross_entropy(l0, l1) for l0, l1 in zip(logits, vld_lbls_batch_device)]
            #total_loss = .7 * loss[0] + .1 * loss[1] + .2 * loss[2]
            total_loss = loss[0] + loss[1] + loss[2]
            
            # record
            epoch_vld_loss.append(total_loss.item())
            
            utils.display_progress(len(validation_loader), k+1, {'validation_loss': epoch_vld_loss[-1]})
    #break
    Recalls = np.array(epoch_vld_metrics)
    entry = {
        'training_loss': np.mean(epoch_trn_loss),
        'validation_loss': np.mean(epoch_vld_loss),
        'GRAPHEME_Recall': Recalls[:, 0].mean(),
        'VOWEL_Recall': Recalls[:, 1].mean(),
        'CONSONANT_Recall': Recalls[:, 2].mean(),
        'Final_Recall': np.average([Recalls[:, 0].mean(), Recalls[:, 1].mean(), Recalls[:, 2].mean()], weights=[2, 1, 1]),
    }
    utils.display_progress(N_EPOCHS, i+1, entry, postfix='Epochs', persist=True)
    # ----------------------------------
    # save model
    if entry['validation_loss'] < np.nan_to_num(logger.log['validation_loss'].min(), nan=100.):
        print('Saving new best weight.')
        torch.save(
            {
                'epoch': i,
                'model': classifier.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, 
            os.path.join('./outputs/', checkpoint_name.format(i)),
        )
    
    # ----------------------------------
    # log
    logger.enter(entry)


4/1373 batches: validation_loss: 1.9750

  _warn_prf(average, modifier, msg_start, len(result))


1373/1373 batches: validation_loss: 2.370

1/32 Epochs: training_loss: 2.987; validation_loss: 2.117; GRAPHEME_Recall: 0.816; VOWEL_Recall: 0.859; CONSONANT_Recall: 0.859; Final_Recall: 0.837
Saving new best weight.
15/1373 batches: training_loss: 1.252

Traceback (most recent call last):
  File "/home/yuan/miniconda3/envs/ML/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
Traceback (most recent call last):
  File "/home/yuan/miniconda3/envs/ML/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/yuan/miniconda3/envs/ML/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/yuan/miniconda3/envs/ML/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/yuan/miniconda3/envs/ML/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
  File "/home/yuan/miniconda3/envs/ML/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/yuan/miniconda3/envs/ML/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
 

KeyboardInterrupt: 