In [1]:
import numpy as np
import pandas as pd
import bloscpack as bp

from sklearn.model_selection import StratifiedKFold

import imgaug as ia
import imgaug.augmenters as iaa

from torch.utils.data.dataloader import DataLoader

import fastai
from fastai.vision import *

from optim import Over9000
from data import Bengaliai_DS, Bengaliai_DS_LIT
from model import *
from model_utils import *

---

In [2]:
SEED = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

---
### data

In [3]:
pdf = pd.read_csv('../input/train.csv')
unique_grapheme = pdf['grapheme'].unique()
grapheme_code = dict([(g, c) for g, c in zip(unique_grapheme, np.arange(unique_grapheme.shape[0]))])
pdf['grapheme_code'] = [grapheme_code[g] for g in pdf['grapheme']]

skf = StratifiedKFold(n_splits=7, shuffle=True, random_state=42)
for trn_ndx, vld_ndx in skf.split(pdf['grapheme_code'], pdf['grapheme_code']):
    break
    
imgs = bp.unpack_ndarray_from_file('../features/train_images_size160_pad8.bloscpack')
lbls = pd.read_csv('../input/train.csv').iloc[:, 1:4].values

trn_imgs = imgs[trn_ndx]
trn_lbls = lbls[trn_ndx]
vld_imgs = imgs[vld_ndx]
vld_lbls = lbls[vld_ndx]

In [4]:
augs = iaa.SomeOf(
    (0, 2),
    [
        iaa.SomeOf(
            (1, 2),
            [
                iaa.OneOf(
                    [
                        iaa.Affine(scale={"x": (0.8, 1.), "y": (0.8, 1.)}, rotate=(-15, 15), shear=(-15, 15)),
                        iaa.PerspectiveTransform(scale=.08, keep_size=True),
                    ]
                ),
                iaa.PiecewiseAffine(scale=(0.02, 0.03)),
            ],
            random_order=True
        ),
        iaa.DirectedEdgeDetect(alpha=(.6, .8), direction=(0.0, 1.0)),
    ],
    random_order=True
)


In [5]:
batch_size = 64 # 64 is important as the fit_one_cycle arguments are probably tuned for this batch size

training_set = Bengaliai_DS(trn_imgs, trn_lbls, transform=augs)
validation_set = Bengaliai_DS(vld_imgs, vld_lbls)

training_loader = DataLoader(training_set, batch_size=batch_size, num_workers=6, shuffle=True) # , sampler=sampler , shuffle=True
validation_loader = DataLoader(validation_set, batch_size=batch_size, num_workers=6, shuffle=False)

data_bunch = DataBunch(train_dl=training_loader, valid_dl=validation_loader)

---
### model

In [6]:
device = 'cuda:0'
n_grapheme = 168
n_vowel = 11
n_consonant = 7
n_total = n_grapheme + n_vowel + n_consonant

In [7]:
predictor = PretrainedCNN(out_dim=n_total)
classifier = BengaliClassifier(predictor)

In [8]:
learn = Learner(
    data_bunch,
    classifier,
    loss_func=Loss_combine_weighted(),
    opt_func=Over9000,
    metrics=[Metric_grapheme(), Metric_vowel(), Metric_consonant(), Metric_tot()]
)

logger = CSVLogger(learn, 'Seresnext50_Size160_Augs_Mixup_64epochs')

learn.clip_grad = 1.0
learn.split([classifier.predictor.lin_layers])
# learn.split([classifier.head1])
learn.unfreeze()

In [None]:
learn.fit_one_cycle(
    64,
    max_lr=slice(0.2e-2, 1e-2),
    wd=[1e-3, 0.1e-1],
    pct_start=0.0,
    div_factor=100,
    callbacks=[logger, SaveModelCallback(learn, monitor='metric_tot', mode='max', name='Seresnext50_Size160_Augs_Mixup_64epochs'), MixUpCallback(learn)] # 
)

epoch,train_loss,valid_loss,metric_idx,metric_idx.1,metric_idx.2,metric_tot,time
0,1.331951,0.306744,0.894299,0.954005,0.903045,0.911412,12:01
1,1.178014,0.274872,0.907146,0.956241,0.945279,0.928953,12:03
2,1.035775,0.179457,0.937529,0.975922,0.967125,0.954526,11:58
3,0.95871,0.184811,0.941107,0.974545,0.965831,0.955647,12:01
4,0.910576,0.161141,0.946655,0.976314,0.965746,0.958843,12:05
5,0.847438,0.144779,0.95099,0.977922,0.976147,0.964012,12:00
6,0.811921,0.139046,0.953087,0.981497,0.970203,0.964468,12:03
7,0.791004,0.143073,0.951706,0.979025,0.969766,0.963051,11:59
8,0.780517,0.133677,0.957952,0.982383,0.968751,0.966759,12:01
9,0.721739,0.131792,0.95527,0.981888,0.97477,0.966799,12:02
