In [None]:
# Check which GPU (if any) is visible to the kernel before loading data.
!nvidia-smi

In [None]:
from PIL import Image
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler, TensorDataset
from sklearn.model_selection import train_test_split
from tqdm import notebook
import torchvision
from torchvision import transforms
from torchvision import models
from collections import Counter
from pathlib import Path
from sklearn.metrics import confusion_matrix, recall_score
import seaborn as sn
!pip install torchsummary 
import torchsummary
!pip install torch-lr-finder
from torch_lr_finder import LRFinder
import copy
import math
import random
from PIL.Image import BICUBIC
import json

In [None]:
# Print the PyTorch build configuration, one entry per line.
print(*torch.__config__.show().split("\n"), sep="\n")

In [None]:
# Current number of intra-op CPU threads (shown as the cell output).
torch.get_num_threads()

In [None]:
# Cap intra-op CPU threads: 2 when CUDA is available, 4 on CPU-only runs.
# NOTE(review): presumably to leave cores for the DataLoader workers — confirm.
torch.set_num_threads(2 if torch.cuda.is_available() else 4)

In [None]:
# Seed torch, numpy and Python's `random` generators.
# NOTE(review): cuDNN determinism flags are not set anywhere in this notebook.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

In [None]:
# First GPU if present, otherwise CPU; used for model and batch placement below.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Competition metadata. df_train supplies the label columns (head sizes), df_class maps
# (label, component_type) to the component glyph for display. df_test and df_submission
# are loaded but not used further in this notebook as shown.
df_train = pd.read_csv('/kaggle/input/bengaliai-cv19/train.csv')
df_test = pd.read_csv('/kaggle/input/bengaliai-cv19/test.csv')
df_class = pd.read_csv('/kaggle/input/bengaliai-cv19/class_map.csv')
df_submission = pd.read_csv('/kaggle/input/bengaliai-cv19/sample_submission.csv')

In [None]:
def make_tensordataset_from_npys(npy_locs, ids_file, label_loc=None):
    """Build a TensorDataset of images from pre-extracted ``.npy`` shards.

    Args:
        npy_locs: paths of ``.npy`` shards. They are loaded in order (with a tqdm
            progress bar), stacked row-wise and reshaped to ``(N, 1, 137, 236)``.
        ids_file: JSON file holding the list of image ids, one per stacked row,
            in the same order as the rows.
        label_loc: optional CSV with ``image_id``, ``grapheme_root``,
            ``vowel_diacritic`` and ``consonant_diacritic`` columns.

    Returns:
        ``TensorDataset(X)`` when ``label_loc`` is None, otherwise
        ``TensorDataset(X, graphemes, vowel_diacs, consonant_diacs)`` with int64
        label tensors. CSV rows whose id is not in ``ids_file`` are ignored.
        Images with no CSV row keep label 0, and a warning reports how many.

    Raises:
        ValueError: if the number of ids differs from the number of stacked images.
    """
    import warnings

    with open(ids_file) as f:
        ids = json.load(f)
    X = np.vstack([np.load(npy_loc) for npy_loc in notebook.tqdm(npy_locs)])
    X = torch.from_numpy(X.reshape(-1, 1, 137, 236))
    if len(ids) != X.shape[0]:
        # Without this check, labels would silently be attached to the wrong rows.
        raise ValueError(
            f'{ids_file} lists {len(ids)} ids but the shards hold {X.shape[0]} images')
    if label_loc is None:
        return TensorDataset(X)

    id_to_row = {image_id: row for row, image_id in enumerate(ids)}
    # If an image_id is duplicated in the CSV, its last row wins.
    lbl_df = pd.read_csv(label_loc).drop_duplicates('image_id', keep='last')
    rows = lbl_df['image_id'].map(id_to_row)  # NaN for ids not present in ids_file
    known = rows.notna()
    idx = torch.from_numpy(rows[known].to_numpy(dtype=np.int64))
    labels = []
    for col in ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic'):
        label = torch.zeros(X.shape[0], dtype=torch.long)
        label[idx] = torch.from_numpy(lbl_df.loc[known, col].to_numpy(dtype=np.int64))
        labels.append(label)
    n_unlabelled = X.shape[0] - idx.unique().numel()
    if n_unlabelled:
        warnings.warn(
            f'{n_unlabelled} images have no row in {label_loc}; their labels default to 0')
    return TensorDataset(X, *labels)

In [None]:
def make_tensordataset_from_dfs(parquet_locs, label_loc=None):
    """Build a TensorDataset of uint8 images from the competition parquet files.

    Args:
        parquet_locs: parquet paths, read in order (with a tqdm progress bar).
            Each file is expected to have an ``image_id`` column first, followed
            by the flattened pixel columns (``iloc[:, 1:]``). Pixels are cast to
            uint8 and reshaped to ``(N, 1, 137, 236)``.
        label_loc: optional CSV with ``image_id``, ``grapheme_root``,
            ``vowel_diacritic`` and ``consonant_diacritic`` columns.

    Returns:
        ``TensorDataset(X)`` when ``label_loc`` is None, otherwise
        ``TensorDataset(X, graphemes, vowel_diacs, consonant_diacs)`` with int64
        label tensors. CSV rows whose id is not among the parquet ids are ignored.
        Images with no CSV row keep label 0, and a warning reports how many.

    Raises:
        ValueError: if the number of image ids differs from the number of images
            produced by the reshape.
    """
    import warnings

    ids = []
    X = []
    for parquet_loc in notebook.tqdm(parquet_locs):
        df = pd.read_parquet(parquet_loc)
        ids.extend(df.image_id.tolist())
        X.append(df.iloc[:, 1:].to_numpy(dtype=np.uint8))
        del df  # drop the frame before reading the next file to limit peak memory
    X = torch.from_numpy(np.vstack(X).reshape(-1, 1, 137, 236))
    if len(ids) != X.shape[0]:
        # Would happen if a file's pixel column count is not 137 * 236.
        raise ValueError(
            f'parquet files list {len(ids)} ids but reshape produced {X.shape[0]} images')
    if label_loc is None:
        return TensorDataset(X)

    id_to_row = {image_id: row for row, image_id in enumerate(ids)}
    # If an image_id is duplicated in the CSV, its last row wins.
    lbl_df = pd.read_csv(label_loc).drop_duplicates('image_id', keep='last')
    rows = lbl_df['image_id'].map(id_to_row)  # NaN for ids not present in the parquets
    known = rows.notna()
    idx = torch.from_numpy(rows[known].to_numpy(dtype=np.int64))
    labels = []
    for col in ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic'):
        label = torch.zeros(X.shape[0], dtype=torch.long)
        label[idx] = torch.from_numpy(lbl_df.loc[known, col].to_numpy(dtype=np.int64))
        labels.append(label)
    n_unlabelled = X.shape[0] - idx.unique().numel()
    if n_unlabelled:
        warnings.warn(
            f'{n_unlabelled} images have no row in {label_loc}; their labels default to 0')
    return TensorDataset(X, *labels)

In [None]:
# Alternative loader: read images straight from the competition parquet files.
# NOTE(review): presumably slower; the .npy shards used below appear to hold the same images.
# ds = make_tensordataset_from_dfs(
#     ['/kaggle/input/bengaliai-cv19/train_image_data_{}.parquet'.format(i) for i in range(4)], 
#     '/kaggle/input/bengaliai-cv19/train.csv')

In [None]:
# Labelled training set: images from pre-extracted .npy shards (+ id list), labels from train.csv.
ds = make_tensordataset_from_npys(
    ['/kaggle/input/bangla-grapheme-npy/tr-ds-{}.npy'.format(i) for i in range(1, 5)],
    '/kaggle/input/bangla-grapheme-npy/tr-ds-ids.json',
    '/kaggle/input/bengaliai-cv19/train.csv')

In [None]:
# Sanity check: show a random image on a 0-255 gray scale, then its three labels,
# its pixel max/min and the index drawn.
idx = random.randrange(len(ds))
plt.imshow(ds[idx][0].permute(1, 2, 0).reshape(137, 236), cmap='gray', vmin=0, vmax=255)
ds[idx][1], ds[idx][2], ds[idx][3], ds[idx][0].max(), ds[idx][0].min(), idx

In [None]:
# Scratch: coordinates of pixels darker than 100 in one image.
# tmp = torch.nonzero(ds[100000][0] < 100)

In [None]:
# Scratch: bounding box (row/column extent) of those dark pixels.
# tmp[:, 1].max(), tmp[:, 1].min(), tmp[:, 2].max(), tmp[:, 2].min()

In [None]:
# Random 90/10 train/validation split of row indices, fixed by random_state=42.
# NOTE(review): stratifying on grapheme_root (commented out) would keep rare roots in both splits.
tr_indices, va_indices = train_test_split(
    list(range(len(ds))), 
    test_size=.1, 
    train_size=.9, 
    random_state=42
)#, stratify=ds.tensors[1])

In [None]:
class TensorWithImageTransforms(Dataset):
    """Wrap a tuple-returning dataset and transform the first element of each sample.

    Each item is ``(img, *rest)``, where ``img`` derives from ``tensor_dataset[index][0]``
    and ``rest`` is passed through unchanged. For each access:

    * with probability ``1 - p`` the image goes through ``transforms`` and the result
      is multiplied by 255.0. In this notebook ``transforms`` ends in ``ToTensor()``,
      which scales to [0, 1], so the output is back on the 0-255 scale;
    * otherwise the image is only cast to float.

    Note that ``p`` is the probability of *skipping* the transforms, so ``p=0.0``
    always transforms.
    """

    def __init__(self, tensor_dataset, transforms, p=0.0):
        super(TensorWithImageTransforms, self).__init__()
        self.ds = tensor_dataset
        self.tr = transforms
        self.nt = len(self.ds[0])  # number of tensors per sample
        self.ln = len(self.ds)
        self.p = p

    def __getitem__(self, index):
        sample = self.ds[index]  # fetch once instead of indexing the wrapped dataset twice
        img = sample[0]
        if random.random() > self.p:
            # Out-of-place multiply. An in-place `*=` writes back into the wrapped
            # dataset's storage whenever the transform returns its input (or a view
            # of it), and it fails outright if the transform returns an integer tensor.
            img = self.tr(img) * 255.0
        else:
            img = img.float()
        return (img,) + tuple(sample[1:])

    def __len__(self):
        return self.ln

In [None]:
# Index-based train/validation views over the same underlying image tensor.
tr_ds = Subset(ds, tr_indices)
va_ds = Subset(ds, va_indices)

In [None]:
# Training augmentation for one uint8 image tensor of shape (1, 137, 236):
# reflect-pad 64 px left/right and 13 px top/bottom, then apply a random rotation/translation/scale
# (bicubic, white 255 fill), crop back to 137x236 and convert to a float tensor in [0, 1].
# NOTE(review): `resample=` / `fillcolor=` are older torchvision argument names
# (newer releases use `interpolation=` / `fill=`) — check against the installed version.
tfms = transforms.Compose([
    transforms.ToPILImage(mode='L'),
    transforms.Pad((64, 13), padding_mode='reflect'),
    transforms.RandomAffine(degrees=10.0, translate=(0.15, 0.05), scale=(0.90, 1.05), resample=BICUBIC, fillcolor=255),
    transforms.CenterCrop((137, 236)),
    transforms.ToTensor(),
])

In [None]:
# Wrap the training subset; p=0.0 means the transforms are never skipped.
tr_ds_tfms = TensorWithImageTransforms(tr_ds, tfms, p=0.0)

In [None]:
idx = random.randrange(len(tr_ds_tfms))

In [None]:
# Visual check of one augmented sample, plus the class_map glyphs for its three labels.
# (Each tr_ds_tfms[idx] access re-runs the random augmentation; only the labels are reused.)
im = tr_ds_tfms[idx][0].permute(1, 2, 0).reshape(137, 236)
plt.imshow(im, cmap='gray', vmin=0., vmax=255.)
g, v, c = map(lambda t: t.item(), tr_ds_tfms[idx][1:])
print(df_class[(df_class['label'] == g) & (df_class['component_type'] == 'grapheme_root')]['component'])
print(df_class[(df_class['label'] == v) & (df_class['component_type'] == 'vowel_diacritic')]['component'])
print(df_class[(df_class['label'] == c) & (df_class['component_type'] == 'consonant_diacritic')]['component'])
im.max(), im.min(), idx

In [None]:
# From here on the training set is the augmented wrapper (float, 0-255). va_ds is not
# augmented and yields uint8 images; both are divided by 255.0 in the batch functions.
tr_ds = tr_ds_tfms

In [None]:
len(ds), len(tr_ds), len(va_ds)

In [None]:
# Output width of each classification head. Use the largest label id + 1 rather than the
# number of distinct labels: CrossEntropyLoss requires 0 <= target < C, and the two agree
# whenever the label ids are contiguous from 0.
n_graphemes, n_vowel_diacs, n_consonant_diacs = (
    int(df_train[col].max()) + 1
    for col in ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic')
)

In [None]:
# Inspect the stock torchvision ResNet-152 (3-channel input, default fc head).
models.resnet152()

In [None]:
# Per-layer output shapes and parameter counts for a 3x137x236 input at batch size 32.
torchsummary.summary(models.resnet152().to(device), input_size=(3,137,236), batch_size=32)

In [None]:
def make_linear_block(in_size, out_size):
    """Return ``nn.Sequential(Linear, ReLU, BatchNorm1d)`` mapping ``in_size`` -> ``out_size``.

    The Linear layer (index 0 of the returned Sequential) gets Xavier-normal weights
    and a zero bias.
    """
    linear = nn.Linear(in_size, out_size)
    nn.init.xavier_normal_(linear.weight.data)
    nn.init.zeros_(linear.bias.data)
    activation = nn.ReLU()
    norm = nn.BatchNorm1d(num_features=out_size)
    return nn.Sequential(linear, activation, norm)

def make_ff_predictor(in_size, intermediate_size, out_size, layer_count, plain_output=False):
    """Feed-forward head built from ``make_linear_block``s.

    Layout: ``in_size -> intermediate_size``, then ``layer_count`` blocks of
    ``intermediate_size -> intermediate_size``, then ``intermediate_size -> out_size``.
    That is ``layer_count + 2`` blocks in total.

    By default the output block is also Linear -> ReLU -> BatchNorm1d. The values fed to
    the loss therefore pass through a ReLU, so all negative pre-activations collapse to
    the same value. Pass ``plain_output=True`` to end with a bare ``nn.Linear``
    (Xavier-normal weights, zero bias) that emits unconstrained logits. The default is
    kept because changing the output layout changes the ``state_dict`` keys, and
    checkpoints saved with the default layout would no longer load.
    """
    layers = [make_linear_block(in_size, intermediate_size)]
    layers.extend(
        make_linear_block(intermediate_size, intermediate_size) for _ in range(layer_count)
    )
    if plain_output:
        head = nn.Linear(intermediate_size, out_size)
        nn.init.xavier_normal_(head.weight)
        nn.init.zeros_(head.bias)
    else:
        head = make_linear_block(intermediate_size, out_size)
    layers.append(head)
    return nn.Sequential(*layers)

def make_squeeze_predictor(in_size, out_size):
    """1x1-conv classifier head: Dropout(0.5) -> 1x1 Conv2d -> ReLU -> global avg pool -> flatten.

    Maps an ``(N, in_size, H, W)`` feature map to an ``(N, out_size)`` tensor.
    """
    stages = [
        nn.Dropout(p=0.5),
        nn.Conv2d(in_size, out_size, kernel_size=(1, 1)),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        nn.Flatten(),
    ]
    return nn.Sequential(*stages)

class BanglaHandwrittenGraphemeNN(nn.Module):
    """ResNet-152 backbone with three feed-forward classification heads.

    ``forward(x)`` returns ``(g_pred, v_pred, c_pred)`` for grapheme root, vowel
    diacritic and consonant diacritic. Head widths come from the module-level
    ``n_graphemes``, ``n_vowel_diacs`` and ``n_consonant_diacs``. The attribute names
    (``base``, ``grapheme_predictor``, ``vowel_diac_predictor``, ``consonant_diacs``)
    become the ``state_dict`` key prefixes, so renaming them breaks loading of
    previously saved checkpoints.
    """

    def __init__(self):
        super(BanglaHandwrittenGraphemeNN, self).__init__()
        # Other backbones tried: squeezenet1_0(...).features, densenet121 (replace
        # `.classifier` instead of `.fc`).
        base = models.resnet152(pretrained=True)
        # Read the pooled-feature width from the classifier before dropping it,
        # instead of hard-coding 2048, so the heads follow the backbone.
        feature_size = base.fc.in_features
        base.fc = nn.Identity()  # the backbone now returns its pooled features
        self.base = base
        self.grapheme_predictor = make_ff_predictor(feature_size, 512, n_graphemes, 2)
        self.vowel_diac_predictor = make_ff_predictor(feature_size, 512, n_vowel_diacs, 1)
        self.consonant_diacs = make_ff_predictor(feature_size, 512, n_consonant_diacs, 1)

    def convert_to_grayscale(self):
        """Replace the backbone's first conv with a single-input-channel equivalent.

        The new kernel is the old kernel summed over the input-channel axis. Because
        convolution is linear, a 1-channel image gives the same response as the old
        conv applied to that image replicated across all input channels.

        Kernel size, stride, padding, dilation and bias presence are copied from the
        old conv. The new weight stays on the old weight's device and dtype, so this
        can run after ``.to(device)``. Calling it again leaves the weights unchanged
        (summing over a size-1 axis).
        """
        old = self.base.conv1
        with torch.no_grad():
            new = nn.Conv2d(
                1, old.out_channels,
                kernel_size=old.kernel_size, stride=old.stride,
                padding=old.padding, dilation=old.dilation,
                bias=old.bias is not None,
            )
            new.weight = nn.Parameter(old.weight.sum(dim=1, keepdim=True))
            if old.bias is not None:
                new.bias = nn.Parameter(old.bias.clone())
        self.base.conv1 = new

    def freeze(self):
        """Stop gradient updates to every backbone parameter (the heads stay trainable)."""
        for p in self.base.parameters():
            p.requires_grad = False

    def unfreeze(self):
        """Re-enable gradient updates for every backbone parameter."""
        for p in self.base.parameters():
            p.requires_grad = True

    def forward(self, x):
        features = self.base(x)
        g_pred = self.grapheme_predictor(features)
        v_pred = self.vowel_diac_predictor(features)
        c_pred = self.consonant_diacs(features)
        return g_pred, v_pred, c_pred

In [None]:
class FocalLoss(nn.Module):
    """Focal loss: mean over elements of ``alpha * (1 - p_t)**gamma * CE``.

    ``inputs`` are raw logits with the class axis at dim 1 (``(N, C)`` or
    ``(N, C, d1, ...)``), or an unbatched ``(C,)`` vector. ``targets`` are integer
    class indices, as for ``nn.functional.nll_loss``. ``p_t`` is the softmax
    probability of the true class. With ``gamma=0`` and ``alpha=1`` the result
    equals mean cross-entropy.
    """

    def __init__(self, alpha=1., gamma=1.):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        # Normalise over the class axis explicitly. Without `dim`, log_softmax picks
        # dim 0 for 3-D inputs, i.e. it normalises across the batch for (N, C, L) logits.
        class_dim = 1 if inputs.dim() > 1 else 0
        log_probs = nn.functional.log_softmax(inputs, dim=class_dim)
        ce_loss = nn.functional.nll_loss(log_probs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()

In [None]:
class MultiTaskLoss(nn.Module):
    """Learned weighting of per-task losses.

    ``w`` holds one learnable scalar per task. ``forward(losses)`` returns
    ``sum(exp(-2 * w) * losses) + sum(w)``. Increasing ``w_i`` shrinks task *i*'s
    weight but adds a linear penalty, which keeps the weights from all being driven to 0.

    Args:
        num_tasks: number of task losses passed to ``forward``.
        init_weight: optional initial values for ``w``, of length ``num_tasks``.
            Integer values are converted to the default float dtype. If omitted,
            ``w`` starts at zeros, i.e. every task starts with weight 1.

    Raises:
        ValueError: if ``init_weight`` does not have shape ``(num_tasks,)``.
    """

    def __init__(self, num_tasks, init_weight=None):
        super(MultiTaskLoss, self).__init__()
        self.n = num_tasks
        if init_weight is None:
            w = torch.zeros(self.n)
        else:
            w = torch.tensor(init_weight)
            if not w.is_floating_point():
                # e.g. [0, 0, 0]: only floating-point tensors can require gradients.
                w = w.to(torch.get_default_dtype())
            if w.shape != (self.n,):
                # A mismatched (e.g. scalar) init would silently broadcast in forward().
                raise ValueError(
                    f'init_weight must have shape ({self.n},), got {tuple(w.shape)}')
        self.w = nn.Parameter(w)

    def freeze(self):
        """Keep the task weights fixed (no gradient updates)."""
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, losses):
        # `losses` is a 1-D tensor with one entry per task (stacked in the batch functions).
        return torch.sum(torch.exp(-2.0 * self.w) * losses) + torch.sum(self.w)

In [None]:
def train_multi_task_batch(model, optimizer, criterions, mtl_criterion, device, batch):
    """Run one optimisation step on an ``(images, grapheme, vowel, consonant)`` batch.

    Images are moved to ``device`` and divided by 255.0 before the forward pass.
    Each head's loss comes from its own criterion; ``mtl_criterion`` combines the
    stacked per-task losses into the loss that is back-propagated.

    Returns:
        ``(total, grapheme, vowel, consonant)`` losses as Python floats, measured
        before the optimizer step.
    """
    images, g_true, v_true, c_true = (t.to(device) for t in batch)
    g_crit, v_crit, c_crit = criterions

    optimizer.zero_grad()
    g_out, v_out, c_out = model(images / 255.0)
    task_losses = (g_crit(g_out, g_true), v_crit(v_out, v_true), c_crit(c_out, c_true))
    total = mtl_criterion(torch.stack(task_losses))
    total.backward()
    optimizer.step()
    return (total.item(),) + tuple(part.item() for part in task_losses)

In [None]:
def validate_multi_task_batch(model, criterions, mtl_criterion, device, batch):
    """Evaluate one ``(images, grapheme, vowel, consonant)`` batch without gradients.

    Images are moved to ``device`` and divided by 255.0. The model's train/eval
    mode is left to the caller.

    Returns:
        losses: ``(total, grapheme, vowel, consonant)`` as Python floats.
        preds:  three lists of argmax class indices (grapheme, vowel, consonant).
        trues:  three lists of target labels in the same order.
    """
    with torch.no_grad():
        images, g_true, v_true, c_true = (t.to(device) for t in batch)
        g_out, v_out, c_out = model(images / 255.0)
        g_crit, v_crit, c_crit = criterions
        task_losses = (g_crit(g_out, g_true), v_crit(v_out, v_true), c_crit(c_out, c_true))
        total = mtl_criterion(torch.stack(task_losses))
        losses = (total.item(),) + tuple(part.item() for part in task_losses)
        preds = tuple(out.argmax(1).tolist() for out in (g_out, v_out, c_out))
        trues = tuple(target.tolist() for target in (g_true, v_true, c_true))
    return losses, preds, trues

In [None]:
def lr_finder(model, optimizer, criterions, mtl_criterion, device, dl, num_iter=10, start_lr=1e-4, end_lr=1.0):
    """Exponential learning-rate range test ("LR finder").

    Trains on ``num_iter`` batches drawn from ``dl``, cycling through it if it is
    shorter. The learning rate of every param group grows geometrically from
    ``start_lr`` towards ``end_lr``. The state of the model, the optimizer and
    ``mtl_criterion``, and the model's train/eval mode, are restored afterwards,
    even if training raises.

    Returns:
        lrf_losses: ``(total, grapheme, vowel, consonant)`` losses, one tuple per batch.
        lrs: per-param-group learning rates. ``lrs[i]`` is the rate that produced
            ``lrf_losses[i]``.

    Raises:
        ValueError: if ``dl`` yields no batches, since the loop could never finish.
    """
    model_state = copy.deepcopy(model.state_dict())
    optim_state = copy.deepcopy(optimizer.state_dict())
    mtl_criterion_state = copy.deepcopy(mtl_criterion.state_dict())
    was_training = model.training
    for group in optimizer.param_groups:
        group['lr'] = start_lr
    gamma = (end_lr / start_lr) ** (1 / num_iter)
    print(gamma, start_lr, end_lr)
    lrf_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
    lrf_losses = []
    lrs = []
    min_loss = math.inf
    model.train()
    try:
        while len(lrf_losses) < num_iter:
            saw_batch = False
            for batch in notebook.tqdm(dl):
                saw_batch = True
                # Record the rate *before* stepping so that lrs[i] pairs with lrf_losses[i].
                lrs.append([group['lr'] for group in optimizer.param_groups])
                losses = train_multi_task_batch(model, optimizer, criterions, mtl_criterion, device, batch)
                print(losses)
                min_loss = min(min_loss, losses[0])
                lrf_losses.append(losses)
                lrf_sched.step()
                if len(lrf_losses) == num_iter:
                    print(min_loss, losses[0] / 10.0)
                    break
            if not saw_batch:
                raise ValueError('lr_finder: the data loader produced no batches')
    finally:
        # Undo every side effect of the trial run, including the LR changes.
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optim_state)
        mtl_criterion.load_state_dict(mtl_criterion_state)
        model.train(was_training)
    return lrf_losses, lrs

In [None]:
model = BanglaHandwrittenGraphemeNN().to(device)

In [None]:
# Switch conv1 to a single input channel *before* loading the checkpoint.
# NOTE(review): the saved weights are presumably from an already-converted model,
# so conv1's shape must match them — confirm.
model.convert_to_grayscale()

In [None]:
# Resume from the previous run's weights.
model.load_state_dict(torch.load('/kaggle/input/bangla-handwritten-grapheme-uncertainty-weighted/model.pth', map_location=device))

In [None]:
# Freeze the ResNet backbone, so only the three heads (and the MTL weights) get gradients.
# NOTE(review): under model.train() the backbone's BatchNorm running stats still update.
model.freeze()

In [None]:
# Plain cross-entropy per head; FocalLoss is the commented-out alternative.
g_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=1.0)
v_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=1.0)
c_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=1.0)

In [None]:
# Order must match the model outputs: (grapheme, vowel, consonant).
criterions = (g_criterion, v_criterion, c_criterion)

In [None]:
# Learned per-task loss weighting for the three tasks.
mtl_criterion = MultiTaskLoss(3).to(device)

In [None]:
# Restore the task weights from the same run as the model checkpoint.
mtl_criterion.load_state_dict(torch.load('/kaggle/input/bangla-handwritten-grapheme-uncertainty-weighted/mtlc.pth', map_location=device))

In [None]:
# mtl_criterion.freeze()

In [None]:
# One Adam optimizer with two param groups: the network, then the MTL weights.
# Frozen backbone params never receive gradients, so Adam skips them.
# The effective learning rate comes from the optimizer state loaded in the next cell,
# not from Adam's default.
optimizer = optim.Adam([{'params': model.parameters()}, {'params': mtl_criterion.parameters()}])
# optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# NOTE(review): the saved state must have the same two-group layout as above.
optimizer.load_state_dict(torch.load('/kaggle/input/bangla-handwritten-grapheme-uncertainty-weighted/optim.pth', map_location=device))

In [None]:
def mean(l):
    """Arithmetic mean of a non-empty sequence (raises ZeroDivisionError when empty)."""
    total = sum(l)
    return total / len(l)

In [None]:
# Shuffled training loader that drops the incomplete last batch; the validation loader
# keeps every sample, in order.
tr_dl = DataLoader(tr_ds, batch_size=32, num_workers=2, pin_memory=True, shuffle=True, drop_last=True)
va_dl = DataLoader(va_ds, batch_size=32, num_workers=2, pin_memory=True)

In [None]:
# Optional LR range test (disabled); the cells below plot its output.
# lrf_losses, lrs = lr_finder(model, optimizer, criterions, mtl_criterion, device, tr_dl, num_iter=100, start_lr=1e-5, end_lr=1e0)

In [None]:
# !nvidia-smi

In [None]:
# skip_first = 1
# skip_last = 10

In [None]:
# plt.plot([t[0] for t in lrs[skip_first:-skip_last]])

In [None]:
# plt.plot([t[0] for t in lrs[skip_first:-skip_last]], [t[0] for t in lrf_losses][skip_first:-skip_last])
# plt.plot([t[0] for t in lrs[skip_first:-skip_last]], [t[1] for t in lrf_losses][skip_first:-skip_last])
# plt.plot([t[0] for t in lrs[skip_first:-skip_last]], [t[2] for t in lrf_losses][skip_first:-skip_last])
# plt.plot([t[0] for t in lrs[skip_first:-skip_last]], [t[3] for t in lrf_losses][skip_first:-skip_last])
# plt.xscale('log')

In [None]:
# lr = lrs[32][0]
# lr

In [None]:
# NOTE(review): `lr` is only consumed by the commented-out cells (param-group override,
# OneCycleLR). The rate actually used is whatever the loaded optimizer state holds.
lr = 1e-4

In [None]:
# for param in optimizer.param_groups:
#         param['lr'] = lr

In [None]:
# Run-long histories: per-batch train losses and MTL weights,
# per-epoch validation losses and scores.
tr_losses = []
va_losses = []
va_scores = []
mtl_weights = []

In [None]:
num_epochs = 5
steps_per_epoch = len(tr_dl)

In [None]:
# scheduler = OneCycleLR(optimizer, lr, epochs=num_epochs, steps_per_epoch=steps_per_epoch)

In [None]:
epochs = range(num_epochs)
# epochs = notebook.tqdm(range(num_epochs))
for epoch in epochs:
    # ---- training pass ----
    model.train()
    # tr_losses keeps growing across epochs (the plots afterwards use the whole history),
    # so remember where this epoch starts. The summary printed below is then the mean
    # over THIS epoch only, not a running mean over every batch seen so far.
    epoch_start = len(tr_losses)
    batches = tr_dl
#     batches = notebook.tqdm(tr_dl)
    for batch in batches:
        losses = train_multi_task_batch(model, optimizer, criterions, mtl_criterion, device, batch)
#         batches.set_description("{:0.4f} {:0.4f} {:0.4f} {:0.4f}".format(*losses))
#         scheduler.step()
        tr_losses.append(losses)
        mtl_weights.append(mtl_criterion.w.tolist())
    epoch_tr_losses = tr_losses[epoch_start:]
    # Mean training loss for this epoch: total, grapheme, vowel, consonant.
    print(*(mean([t[k] for t in epoch_tr_losses]) for k in range(4)))

    # ---- validation pass ----
    model.eval()
    va_batch_losses = []

    va_g_preds = []
    va_v_preds = []
    va_c_preds = []

    va_g_trues = []
    va_v_trues = []
    va_c_trues = []
    batches = va_dl
#     batches = notebook.tqdm(va_dl)
    for batch in batches:
        losses, preds, trues = validate_multi_task_batch(model, criterions, mtl_criterion, device, batch)
        va_batch_losses.append(losses)

        g_pred, v_pred, c_pred = preds
        g_true, v_true, c_true = trues

        va_g_trues.extend(g_true)
        va_v_trues.extend(v_true)
        va_c_trues.extend(c_true)

        va_g_preds.extend(g_pred)
        va_v_preds.extend(v_pred)
        va_c_preds.extend(c_pred)

    # Mean of the per-batch validation losses (the last batch may be smaller).
    avg_loss = mean([t[0] for t in va_batch_losses])
    avg_g_loss = mean([t[1] for t in va_batch_losses])
    avg_v_loss = mean([t[2] for t in va_batch_losses])
    avg_c_loss = mean([t[3] for t in va_batch_losses])
    va_losses.append((avg_loss, avg_g_loss, avg_v_loss, avg_c_loss))
    # Score = weighted average of per-task macro recalls, with the grapheme root counted double.
    g_rec = recall_score(va_g_trues, va_g_preds, average='macro')
    v_rec = recall_score(va_v_trues, va_v_preds, average='macro')
    c_rec = recall_score(va_c_trues, va_c_preds, average='macro')
    score = 0.5 * g_rec + 0.25 * v_rec + 0.25 * c_rec
    va_scores.append((score, g_rec, v_rec, c_rec))
    print(va_losses[-1])
    print(va_scores[-1])
    print(confusion_matrix(va_v_trues, va_v_preds))
    print(confusion_matrix(va_c_trues, va_c_preds))
    # Grapheme-root confusion matrix on a log1p colour scale, so rare confusions stay visible.
    plt.figure(figsize = (20, 20))
    sn.heatmap(np.log1p(confusion_matrix(va_g_trues, va_g_preds)))
    plt.title('log1p grapheme-root confusion matrix, epoch {}'.format(epoch))
    plt.show()

In [None]:
# Check GPU status/memory after training.
!nvidia-smi

In [None]:
# Per-batch training curves over the whole run (all epochs concatenated): total MTL loss.
plt.plot([t[0] for t in tr_losses])

In [None]:
# Grapheme-root loss per batch.
plt.plot([t[1] for t in tr_losses])

In [None]:
# Vowel-diacritic loss per batch.
plt.plot([t[2] for t in tr_losses])

In [None]:
# Consonant-diacritic loss per batch.
plt.plot([t[3] for t in tr_losses])

In [None]:
# Latest learned MTL parameters w (one per task; each task loss is scaled by exp(-2w)).
mtl_weights[-1]

In [None]:
# Evolution of all three w values per batch.
plt.plot(mtl_weights)

In [None]:
# w for the grapheme-root task.
plt.plot([t[0] for t in mtl_weights])

In [None]:
# w for the vowel-diacritic task.
plt.plot([t[1] for t in mtl_weights])

In [None]:
# w for the consonant-diacritic task.
plt.plot([t[2] for t in mtl_weights])

In [None]:
# Per-epoch validation (score, grapheme, vowel, consonant) macro recalls.
plt.plot(va_scores)

In [None]:
# Per-epoch validation loss: total.
plt.plot([t[0] for t in va_losses])

In [None]:
# Per-epoch validation loss: grapheme root.
plt.plot([t[1] for t in va_losses])

In [None]:
# Per-epoch validation loss: vowel diacritic.
plt.plot([t[2] for t in va_losses])

In [None]:
# Per-epoch validation loss: consonant diacritic.
plt.plot([t[3] for t in va_losses])

In [None]:
# Save model, optimizer and MTL-weight state to the working directory, using the same
# file names as the checkpoints loaded earlier so a later run can resume from them.
torch.save(model.state_dict(), 'model.pth')
torch.save(optimizer.state_dict(), 'optim.pth')
torch.save(mtl_criterion.state_dict(), 'mtlc.pth')