In [None]:
def mean(l):
    """Return the arithmetic mean of the numbers in ``l`` (``sum(l) / len(l)``)."""
    total = sum(l)
    return total / len(l)

In [None]:
import os
from PIL import Image
from skimage.filters import threshold_otsu
from skimage.transform import AffineTransform, SimilarityTransform, warp, resize
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler, TensorDataset
from sklearn.model_selection import train_test_split
from tqdm import notebook
import torchvision
from torchvision import transforms
from torchvision import models
from collections import Counter
from pathlib import Path
from sklearn.metrics import confusion_matrix, recall_score
import seaborn as sn
!pip install torchsummary
import torchsummary
# !pip install torch-lr-finder
# from torch_lr_finder import LRFinder
import copy
import math
import random
from PIL.Image import BICUBIC
import json
!pip install pretrainedmodels
import pretrainedmodels
!pip install iterative-stratification
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
!pip install torchtoolbox
from torchtoolbox.tools import mixup_data, mixup_criterion

In [None]:
# Show the build configuration of the installed PyTorch.
print(torch.__config__.show())

In [None]:
# Use 2 intra-op CPU threads when a GPU is available, otherwise 4.
torch.set_num_threads(2 if torch.cuda.is_available() else 4)

In [None]:
# Seed the three RNGs used in this notebook (torch, numpy, random) with 0,
# in that order.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(0)

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
_device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [None]:
# Competition metadata read from the Kaggle input directory.
df_train = pd.read_csv('/kaggle/input/bengaliai-cv19/train.csv')
df_test = pd.read_csv('/kaggle/input/bengaliai-cv19/test.csv')
df_class = pd.read_csv('/kaggle/input/bengaliai-cv19/class_map.csv')
df_submission = pd.read_csv('/kaggle/input/bengaliai-cv19/sample_submission.csv')
# Auxiliary-task labels: the first row of the file is skipped and the columns
# are renamed to id / matra / up / conj / fg / sg.
df_auxtask_label = pd.read_csv('/kaggle/input/bhgd-aux-tasks/tasks.csv', names=['id', 'matra', 'up', 'conj', 'fg', 'sg'], skiprows=1)

In [None]:
def patch_label_with_corrected_label(df, cl):
    """Overwrite the three label columns of ``df`` in place with corrections.

    Each row of ``cl`` is read positionally: its first column is the row label
    in ``df`` to patch, and the next three columns are the new values for
    ``grapheme_root``, ``vowel_diacritic`` and ``consonant_diacritic``.
    Returns ``None``; ``df`` is mutated.
    """
    for record in cl.itertuples(index=False, name=None):
        target, root, vowel, consonant = record[:4]
        df.at[target, 'grapheme_root'] = root
        df.at[target, 'vowel_diacritic'] = vowel
        df.at[target, 'consonant_diacritic'] = consonant

In [None]:
# Manually corrected labels, read with explicit column names
# (row index in df_train, grapheme root, vowel diacritic, consonant diacritic).
# NOTE(review): no rows are skipped, so the file is presumably header-less — confirm.
df_corrected_labels = pd.read_csv('/kaggle/input/bhgd-corrected-labels/corrected_labels.csv', names=['index', 'g', 'v', 'c'])

In [None]:
len(df_corrected_labels)

In [None]:
df_corrected_labels.head()

In [None]:
# patch_label_with_corrected_label(df_train, df_corrected_labels)

In [None]:
# Number of distinct values of each target column in the training labels.
n_graphemes = len(set(df_train['grapheme_root']))
n_vowel_diacs = len(set(df_train['vowel_diacritic']))
n_consonant_diacs = len(set(df_train['consonant_diacritic']))

In [None]:
def make_tensordataset_from_npys(npy_locs, ids_file, lbl_df=None):
    """Build a ``TensorDataset`` of images (and optionally labels) from .npy shards.

    Args:
        npy_locs: paths of the ``.npy`` shards; they are loaded in order and
            stacked vertically, then reshaped to ``(-1, 1, 137, 236)``.
        ids_file: path of a JSON file holding the image ids; the position of an
            id in that list is used as the row of the image in the stacked array.
        lbl_df: optional frame with ``image_id``, ``grapheme_root``,
            ``vowel_diacritic`` and ``consonant_diacritic`` columns.

    Returns:
        ``TensorDataset(X)`` when ``lbl_df`` is None, otherwise
        ``TensorDataset(X, graphemes, vowel_diacs, consonant_diacs)``.
        Rows of ``lbl_df`` whose ``image_id`` is not in the ids file are
        ignored; images without a label row keep label 0.
    """
    with open(ids_file) as f:
        id_list = json.load(f)
    shards = [np.load(shard_loc) for shard_loc in notebook.tqdm(npy_locs)]
    X = torch.from_numpy(np.vstack(shards).reshape(-1, 1, 137, 236))
    # image id -> row position in X
    row_of = {image_id: position for position, image_id in enumerate(id_list)}
    if lbl_df is None:
        return TensorDataset(X)

    n_images = X.shape[0]
    graphemes = torch.zeros(n_images, dtype=torch.long)
    vowel_diacs = torch.zeros(n_images, dtype=torch.long)
    consonant_diacs = torch.zeros(n_images, dtype=torch.long)
    for record in lbl_df.itertuples():
        if record.image_id in row_of:
            position = row_of[record.image_id]
            graphemes[position] = record.grapheme_root
            vowel_diacs[position] = record.vowel_diacritic
            consonant_diacs[position] = record.consonant_diacritic
    return TensorDataset(X, graphemes, vowel_diacs, consonant_diacs)

In [None]:
# def make_tensordataset_from_dfs(parquet_locs, label_loc=None):
#     ids = []
#     X = []
#     parquet_locs = notebook.tqdm(parquet_locs)
#     for parquet_loc in parquet_locs:
#         df = pd.read_parquet(parquet_loc)
#         ids.extend(df.image_id.tolist())
#         x = df.iloc[:, 1:].to_numpy(dtype=np.uint8)
#         del df
#         X.append(x)
#     X = np.vstack(X)
#     X = X.reshape(-1, 1, 137, 236)
#     X = torch.from_numpy(X)
#     ids = dict((s,i) for (i,s) in enumerate(ids))
#     if label_loc is None:
#         return TensorDataset(X)
#     else:
#         graphemes = torch.zeros(X.shape[0], dtype=torch.long)
#         vowel_diacs = torch.zeros(X.shape[0], dtype=torch.long)
#         consonant_diacs = torch.zeros(X.shape[0], dtype=torch.long)
#         lbl_df = pd.read_csv(label_loc)
#         for row in lbl_df.itertuples():
#             if row.image_id not in ids:
#                 continue
#             idx = ids[row.image_id]
#             graphemes[idx] = row.grapheme_root
#             vowel_diacs[idx] = row.vowel_diacritic
#             consonant_diacs[idx] = row.consonant_diacritic
#         return TensorDataset(X, graphemes, vowel_diacs, consonant_diacs)

In [None]:
# ds = make_tensordataset_from_dfs(
#     ['/kaggle/input/bengaliai-cv19/train_image_data_{}.parquet'.format(i) for i in range(4)], 
#     '/kaggle/input/bengaliai-cv19/train.csv')

In [None]:
!ls /kaggle/input

In [None]:
# Full labelled training set: four .npy shards (tr-ds-1 … tr-ds-4), their image
# ids, and the labels from df_train.
ds = make_tensordataset_from_npys(
    ['/kaggle/input/bangla-grapheme-npy/tr-ds-{}.npy'.format(i) for i in range(1, 5)],
    '/kaggle/input/bangla-grapheme-npy/tr-ds-ids.json',
    df_train)

In [None]:
def detensorify(l):
    """Return a list with ``t.item()`` for every element ``t`` of ``l``."""
    return [t.item() for t in l]

In [None]:
# One [grapheme, vowel, consonant] list per sample (everything after the image).
# NOTE(review): the same computation is repeated in the stratified-split cell further down.
labels = [detensorify(ds[i][1:]) for i in range(len(ds))]

In [None]:
# Convert the per-sample label lists into a 2-D array (one row per sample).
labels = np.array(labels)

In [None]:
# class DatasetFromImg(Dataset):
#     def __init__(self, img_dir, df):
#         super(DatasetFromImg, self).__init__()
#         self.img_dir = img_dir
#         self.df = df
        
#     def __getitem__(self, index):
#         img_fname = self.df.iloc[index].image_id + '.png'
#         img = Image.open(self.img_dir + '/' + img_fname)
#         g = self.df.iloc[index].grapheme_root
#         v = self.df.iloc[index].vowel_diacritic
#         c = self.df.iloc[index].consonant_diacritic
#         return img, g, v, c
    
#     def __len__(self):
#         return len(self.df)

In [None]:
# ds = DatasetFromImg('/kaggle/input/grapheme-imgs-128x128', df_train)

In [None]:
class DatasetWithAuxiliaryTasks(Dataset):
    """Wrap a ``(img, g, v, c)`` dataset and append auxiliary labels to each item.

    Args:
        ds: indexable dataset whose items unpack to ``(img, g, v, c)``.
        auxilary_tasks: callables ``f(g, v, c)``; each one contributes one extra
            label, appended after ``c`` in the order given.
    """
    def __init__(self, ds, auxilary_tasks):
        super(DatasetWithAuxiliaryTasks, self).__init__()
        self.auxilary_tasks = auxilary_tasks
        self.ds = ds
        self.ln = len(self.ds)

    def __getitem__(self, index):
        # Read from the wrapped dataset, not from a notebook-level global
        # called `ds`, so the wrapper works with whatever it was given.
        img, g, v, c = self.ds[index]
        aux_labels = [f(g, v, c) for f in self.auxilary_tasks]
        return (img, g, v, c) + tuple(aux_labels)

    def __len__(self):
        return self.ln

In [None]:
# Lookup table: value of the 'id' column -> value of the 'matra' column.
matra_label = df_auxtask_label.set_index('id')['matra'].to_dict()

In [None]:
# Lookup table: value of the 'id' column -> raw 'fg' value.
fg_label = df_auxtask_label.set_index('id')['fg'].to_dict()
# Assign a class id to every distinct 'fg' value. The values are sorted first
# (by their string form, so mixed types cannot raise) because iterating a bare
# set gives an order that can differ between kernel restarts, which would make
# the class ids incompatible with previously saved checkpoints.
g_to_fg = {value: class_id for class_id, value in enumerate(sorted(set(fg_label.values()), key=str))}
# Re-express the lookup table in terms of class ids.
fg_label = {key: g_to_fg[value] for key, value in fg_label.items()}
num_fg = len(g_to_fg)
g_to_fg, num_fg

In [None]:
# Lookup table: value of the 'id' column -> raw 'sg' value.
sg_label = df_auxtask_label.set_index('id')['sg'].to_dict()
# Assign a class id to every distinct 'sg' value. The values are sorted first
# (by their string form, so mixed types cannot raise) because iterating a bare
# set gives an order that can differ between kernel restarts, which would make
# the class ids incompatible with previously saved checkpoints.
g_to_sg = {value: class_id for class_id, value in enumerate(sorted(set(sg_label.values()), key=str))}
# Re-express the lookup table in terms of class ids.
sg_label = {key: g_to_sg[value] for key, value in sg_label.items()}
num_sg = len(g_to_sg)
g_to_sg, num_sg

In [None]:
# Lookup table: value of the 'id' column -> value of the 'conj' column.
conj_label = df_auxtask_label.set_index('id')['conj'].to_dict()

In [None]:
# Lookup table: value of the 'id' column -> value of the 'up' column.
up_label = df_auxtask_label.set_index('id')['up'].to_dict()

In [None]:
def no_diac_task(g, v, c):
    """Return 1 when the grapheme-root label ``g`` is below 13, else 0."""
    if g < 13:
        return 1
    return 0


def matra_task(g, v, c):
    """Look up the 'matra' label of grapheme root ``g`` (``v``/``c`` unused)."""
    root_id = g.item()
    return matra_label[root_id]


def up_task(g, v, c):
    """Look up the 'up' label of grapheme root ``g`` (``v``/``c`` unused)."""
    root_id = g.item()
    return up_label[root_id]


def conj_task(g, v, c):
    """Look up the 'conj' label of grapheme root ``g`` (``v``/``c`` unused)."""
    root_id = g.item()
    return conj_label[root_id]


def fg_task(g, v, c):
    """Look up the 'fg' class id of grapheme root ``g`` (``v``/``c`` unused)."""
    root_id = g.item()
    return fg_label[root_id]


def sg_task(g, v, c):
    """Look up the 'sg' class id of grapheme root ``g`` (``v``/``c`` unused)."""
    root_id = g.item()
    return sg_label[root_id]


# No auxiliary task is enabled; the full list is kept in the trailing comment.
aux_tasks = []#[no_diac_task, matra_task, up_task, conj_task, fg_task, sg_task]
ds_with_aux = DatasetWithAuxiliaryTasks(ds, aux_tasks)

In [None]:
class DatasetWithImageTransforms(Dataset):
    """Dataset view that applies ``transforms`` to the first element of each item.

    The remaining elements of the item (the labels) are passed through
    unchanged. ``nt`` holds the number of elements per item and ``ln`` the
    dataset length, both captured at construction time.
    """

    def __init__(self, ds, transforms):
        super().__init__()
        self.ds, self.tr = ds, transforms
        self.nt = len(self.ds[0])
        self.ln = len(self.ds)

    def __len__(self):
        return self.ln

    def __getitem__(self, index):
        first, *others = self.ds[index]
        transformed = self.tr(first)
        return (transformed, *others)

In [None]:
def detensorify(l):
    """Return a list with ``t.item()`` for every element ``t`` of ``l``."""
    return [t.item() for t in l]


# One (grapheme, vowel, consonant) row per sample, used as stratification targets.
labels = np.array([detensorify(ds[i][1:]) for i in range(len(ds))])

# Single multilabel-stratified shuffle split with a 1% hold-out and a fixed seed.
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.01, random_state=42)

_sample_positions = list(range(len(ds)))
tr_indices, va_indices = next(iter(msss.split(_sample_positions, labels)))

In [None]:
# tr_indices, va_indices = torch.load('/kaggle/input/bangla-handwritten-grapheme/tr_indices'), torch.load('/kaggle/input/bangla-handwritten-grapheme/va_indices')

In [None]:
# Persist the split to the working directory so it can be reloaded later
# (see the commented-out torch.load cell above).
torch.save(tr_indices, 'tr_indices')
torch.save(va_indices, 'va_indices')

In [None]:
# tr_indices, va_indices = train_test_split(
#     list(range(len(ds))), 
#     test_size=0.10, 
#     train_size=0.90, 
#     random_state=42,
#     stratify=labels[:,0]
# )

In [None]:
# Train / validation views over the auxiliary-task dataset.
# NOTE(review): tr_ds / va_ds are rebound to the transformed datasets in a later cell.
tr_ds = Subset(ds_with_aux, tr_indices)
va_ds = Subset(ds_with_aux, va_indices)

In [None]:
# Returns binary image
def thresh(img):
    """Binarise ``img``: True where a pixel exceeds the (integer-truncated) Otsu threshold."""
    cutoff = int(threshold_otsu(img))
    return img > cutoff

# For binary image
def bounding_box(img):
    """Return ``(top, bottom, left, right)`` of the character in ``img``.

    The image is binarised with ``thresh``; zeros are treated as character
    pixels. All four values are inclusive indices of the first/last row and
    column that contain at least one zero pixel.
    """
    binary = thresh(img).astype(np.uint8)
    # A row/column holds part of the character when its minimum is 0, i.e. when
    # ``1 - min`` is non-zero (argwhere returns the non-zero positions).
    rows_with_ink = np.argwhere(1 - binary.min(axis=1)).flatten()
    cols_with_ink = np.argwhere(1 - binary.min(axis=0)).flatten()
    return rows_with_ink.min(), rows_with_ink.max(), cols_with_ink.min(), cols_with_ink.max()

def scale_to_bb(img):
    """Crop ``img`` to its bounding box plus a 10-pixel margin and resize to 256x256.

    The margin is clipped to the image borders. Resizing uses bicubic
    interpolation (``order=3``) and keeps the input value range.
    """
    height, width = img.shape[0], img.shape[1]
    top, bottom, left, right = bounding_box(img)
    top = max(0, top - 10)
    left = max(0, left - 10)
    bottom = min(height, bottom + 10)
    right = min(width, right + 10)
    crop = img[top:bottom, left:right]
    return resize(crop, output_shape=(256, 256), preserve_range=True, order=3, cval=1.0)

def random_scale(img):
    """Warp ``img`` with a random uniform scale drawn from a bounding-box-derived range.

    The upper bound is the smaller of the width/height ratios
    ``(box + margin) / box``; the lower bound is ``min(1, 0.25 * max(image / box))``.
    Pixels mapped from outside the image are filled with 1.0.
    """
    height, width = img.shape[0], img.shape[1]
    top, bottom, left, right = bounding_box(img)
    box_w = right - left
    box_h = bottom - top
    widest = (box_w + min(left, width - right)) / box_w
    tallest = (box_h + min(top, height - bottom)) / box_h
    upper = min(widest, tallest)
    lower = min(1.0, 0.25 * max((height / box_h), (width / box_w)))
    factor = random.uniform(lower, upper)
    similarity = SimilarityTransform(scale=(factor, factor))
    return warp(img, similarity.inverse, cval=1.0, order=3)

def random_translate(img):
    """Shift ``img`` by a random offset bounded by its bounding box and the image size.

    The vertical offset is drawn from ``[-top, height - bottom]`` first, then
    the horizontal one from ``[-left, width - right]``. Pixels mapped from
    outside the image are filled with 1.0.
    """
    height, width = img.shape[0], img.shape[1]
    top, bottom, left, right = bounding_box(img)
    shift_rows = random.uniform(-top, height - bottom)
    shift_cols = random.uniform(-left, width - right)
    shift = SimilarityTransform(translation=(shift_cols, shift_rows))
    return warp(img, shift.inverse, cval=1.0, order=3)

def random_rotate_and_shear(img):
    """Apply a random rotation (±pi/16) and shear (±pi/8) to ``img``.

    The rotation angle is drawn before the shear angle. Pixels mapped from
    outside the image are filled with 1.0.
    """
    rotation_limit = math.pi / 16
    angle = random.uniform(-rotation_limit, rotation_limit)
    shear_limit = math.pi / 8
    shear_angle = random.uniform(-shear_limit, shear_limit)
    affine = AffineTransform(rotation=angle, shear=shear_angle)
    return warp(img, affine.inverse, cval=1.0, order=3)

def invert_color(t):
    """Replace ``t`` in place with ``255 - t`` and return the same tensor."""
    return t.mul_(-1).add_(255)

def affine_transforms(img):
    """Binarise a 137x236 image tensor, then randomly scale and translate it.

    Returns a float32 tensor of shape ``(1, 137, 236)``.
    """
    plane = img.reshape(137, 236).numpy()
    binary = thresh(plane).astype(np.float32)
    # Rotation/shear variant kept for reference:
    #     random_translate(random_scale(random_rotate_and_shear(binary)))
    scaled = random_scale(binary)
    moved = random_translate(scaled)
    return torch.from_numpy(moved.reshape(1, 137, 236))

def tfms(img):
    """Binarise a 137x236 image tensor and rescale its bounding box to 256x256.

    Returns a tensor of shape ``(1, 256, 256)``.
    """
    plane = img.reshape(137, 236).numpy()
    binary = thresh(plane).astype(np.float32)
    boxed = scale_to_bb(binary)
    return torch.from_numpy(boxed.reshape(1, 256, 256))

def va_tfms(img):
    """Binarise a 137x236 image tensor without any geometric change.

    Returns a float32 tensor of shape ``(1, 137, 236)``.
    """
    plane = img.reshape(137, 236).numpy()
    binary = thresh(plane).astype(np.float32)
    return torch.from_numpy(binary.reshape(1, 137, 236))

# Wrap the plain functions as torchvision transforms. Each name is rebound from
# the function defined above to its Lambda wrapper.
affine_transforms = transforms.Lambda(affine_transforms)
mult = transforms.Lambda(lambda img: img * 255)
to_float = transforms.Lambda(lambda img: img.float())
invert_color = transforms.Lambda(invert_color)
# Training pipeline: the `tfms` function followed by a cast to float. The
# commented lines are a disabled RandomApply / `mult` variant.
tfms = transforms.Compose([
#     transforms.RandomApply([
        tfms,
#         mult,
#     ], p=0.90),
    to_float,
])

# Validation pipeline.
# NOTE(review): this composes `tfms` (by now the training Compose above), not
# the `va_tfms` function, which is shadowed here and never called — confirm
# that this is intended.
va_tfms = transforms.Compose([
    tfms,
#     mult,
    to_float,
])

# Apply the pipelines on top of the train / validation subsets.
tr_ds_tfms = DatasetWithImageTransforms(tr_ds, tfms)
va_ds_tfms = DatasetWithImageTransforms(va_ds, va_tfms)

In [None]:
def plot_from_ds(ds, idx, img_is_tensor=False):
    """Show item ``idx`` of ``ds`` and print the components its labels map to.

    When ``img_is_tensor`` is true the image shape is printed, its first two
    dimensions are flattened together, and the new shape is printed. The
    components are looked up in ``df_class`` for the grapheme root, vowel
    diacritic and consonant diacritic, in that order.
    """
    img, g, v, c, *rest = ds[idx]
    if img_is_tensor:
        print(img.shape)
        img = img.flatten(end_dim=1)
        print(img.shape)
    plt.imshow(img, cmap='gray', vmin=0., vmax=1.0)
    lookups = (
        (g.item(), 'grapheme_root'),
        (v.item(), 'vowel_diacritic'),
        (c.item(), 'consonant_diacritic'),
    )
    for label, component_type in lookups:
        matches = (df_class['label'] == label) & (df_class['component_type'] == component_type)
        print(df_class[matches]['component'])

In [None]:
idx = random.randrange(len(ds_with_aux))
plot_from_ds(ds_with_aux, idx, img_is_tensor=True)

In [None]:
idx = random.randrange(len(tr_ds_tfms))
idx

In [None]:
plot_from_ds(tr_ds, idx, img_is_tensor=True)

In [None]:
plot_from_ds(tr_ds_tfms, idx, img_is_tensor=True)

In [None]:
# From here on tr_ds / va_ds refer to the transformed datasets.
# NOTE(review): this rebinds names created by the Subset cell; re-running the
# transform cell after this one would wrap the already-transformed datasets.
tr_ds = tr_ds_tfms
va_ds = va_ds_tfms

In [None]:
len(ds), len(tr_ds), len(va_ds)

In [None]:
print(pretrainedmodels.model_names)

In [None]:
!ls /kaggle/input/pretrained-model-weights-pytorch

In [None]:
# torchsummary.summary(model, input_size=(3,128,128))

In [None]:
class MultiTaskNN(nn.Module):
    """Shared feature extractor with one prediction head per task.

    ``forward`` returns a list with the output of every head, in the order the
    heads were given, all computed from the same ``base(x)`` features.
    """

    def __init__(self, base, task_predictors):
        super().__init__()
        self.base = base
        self.task_predictors = nn.ModuleList(task_predictors)

    def _set_base_trainable(self, trainable):
        # Only the shared base is touched; the heads keep their current setting.
        for param in self.base.parameters():
            param.requires_grad = trainable

    def freeze(self):
        """Stop gradient computation for every parameter of the base."""
        self._set_base_trainable(False)

    def unfreeze(self):
        """Re-enable gradient computation for every parameter of the base."""
        self._set_base_trainable(True)

    def forward(self, x):
        shared = self.base(x)
        outputs = []
        for head in self.task_predictors:
            outputs.append(head(shared))
        return outputs

In [None]:
def make_linear_block(in_size, out_size):
    """Return ``Sequential(Linear, ReLU, BatchNorm1d)`` mapping ``in_size`` to ``out_size``.

    The linear weight is re-initialised with Xavier-normal and its bias with zeros.
    """
    linear = nn.Linear(in_size, out_size)
    nn.init.xavier_normal_(linear.weight.data)
    nn.init.zeros_(linear.bias.data)
    return nn.Sequential(
        linear,
        nn.ReLU(),
        nn.BatchNorm1d(num_features=out_size),
    )

class ResBlock(nn.Module):
    """Residual wrapper: ``forward(x)`` returns ``x + layer(x)``."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        branch = self.layer(x)
        return x + branch

def make_ff_predictor(in_size, intermediate_size, out_size, layer_count, res_block=False):
    """Build a feed-forward prediction head.

    The head is one entry block, ``layer_count`` further blocks of width
    ``intermediate_size``, and a final ``nn.Linear(intermediate_size, out_size)``.

    Args:
        in_size: width of the incoming feature vector.
        intermediate_size: width of every hidden block.
        out_size: number of outputs of the final linear layer.
        layer_count: number of hidden blocks after the entry block.
        res_block: wrap hidden blocks in ``ResBlock`` (``x + block(x)``).

    Returns:
        An ``nn.Sequential`` with the layers described above.
    """
    # A residual block needs matching input and output widths. When the
    # features are not already `intermediate_size` wide, the entry block has to
    # be a plain projection; otherwise the first Linear would be built for the
    # wrong input width and fail on the first forward pass.
    if res_block and in_size == intermediate_size:
        layers = [ResBlock(make_linear_block(intermediate_size, intermediate_size))]
    else:
        layers = [make_linear_block(in_size, intermediate_size)]
    for _ in range(layer_count):
        hidden = make_linear_block(intermediate_size, intermediate_size)
        layers.append(ResBlock(hidden) if res_block else hidden)
    layers.append(nn.Linear(intermediate_size, out_size))
    return nn.Sequential(*layers)

def make_squeeze_predictor(in_size, out_size):
    """Build a convolutional head: dropout, 1x1 conv, ReLU, global average pool, flatten."""
    stages = [
        nn.Dropout(p=0.5),
        nn.Conv2d(in_size, out_size, kernel_size=(1, 1)),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        nn.Flatten(),
    ]
    return nn.Sequential(*stages)

# class MultiTaskNN(nn.Module):
#     def __init__(self, n_classes_tasks, depth_tasks):
#         super(MultiTaskNN, self).__init__()
# #         base = models.squeezenet1_0(pretrained=True).features
#         base = pretrainedmodels.__dict__['se_resnext101_32x4d']()
#         # base.load_state_dict(torch.load('./pnasnet5large-bf079911.pth'))
#         base.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
#         # base.dropout = nn.Identity()
#         base.last_linear = nn.Identity()
#         feature_size = 2048
#         # base, feature_size = models.resnet18(pretrained=False), 512
#         # base.load_state_dict(torch.load('./resnet18-5c106cde.pth'))
# #         base = models.wide_resnet101_2(pretrained=True)
#         # base.fc = nn.Identity()
# #         base = models.densenet121(pretrained=True)
# #         base.classifier = nn.Identity()
#         self.base = base
#         self.task_predictors = nn.ModuleList([
#             make_ff_predictor(feature_size, 512, n_classes, depth) 
#             for n_classes, depth in zip(n_classes_tasks, depth_tasks)
#         ])
#         self.n_classes_tasks = n_classes_tasks
#         self.depth_tasks = depth_tasks

#     def convert_to_grayscale(self):
#         with torch.no_grad():
#             # conv1 = nn.Conv2d(3, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
#             # conv1.weight.data = torch.sum(self.base.conv_0.conv.weight.data, dim=1, keepdim=True)
#             # self.base.conv_0.conv = conv1
#             conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#             conv1.weight.data = torch.sum(self.base.layer0.conv1.weight.data, dim=1, keepdim=True)
#             # self.base.conv1 = conv1
#             self.base.layer0.conv1 = conv1
# #             conv1.weight.data = torch.sum(self.base.features.conv0.weight.data, dim=1, keepdim=True)
# #             self.base.features.conv0 = conv1
            
#     def freeze(self):
#         for p in self.base.parameters():
#             p.requires_grad = False
#         # # unfreeze the bns
#         # for param in self.named_parameters():
#         #     if 'bn' in param[0]:
#         #         param[1].requires_grad = True
#         #     if 'downsample.1' in param[0]:
#         #         param[1].requires_grad = True
        
#     def unfreeze(self):
#         for p in self.base.parameters():
#             p.requires_grad = True
        
#     def forward(self, x):
#         features = self.base(x)
#         preds = [predictor(features) for predictor in self.task_predictors]
#         return preds

In [None]:
def train_multi_task_batch(model, optimizer, criterions, mtl_criterion, device, batch, weight_update=True, with_mixup=False):
    """Run one forward/backward pass on ``batch`` and optionally step the optimizer.

    Args:
        model: callable returning one prediction tensor per task.
        optimizer: stepped and zeroed only when ``weight_update`` is true, so
            gradients accumulate across calls otherwise.
        criterions: one loss per task, paired positionally with the predictions
            and with the labels in ``batch``.
        mtl_criterion: combines the stacked per-task losses into the scalar that
            is back-propagated.
        device: device the image and labels are moved to.
        batch: ``(img, label_0, label_1, ...)``.
        weight_update: whether to apply the accumulated gradients.
        with_mixup: mix the images with ``mixup_data`` (alpha 0.1) and use
            ``mixup_criterion`` for every task loss.

    Returns:
        Tuple of floats: the combined loss followed by each task loss.
    """
    img, *labels = batch
    img = img.to(device)
    labels = [label.to(device) for label in labels]
    if not with_mixup:
        preds = model(img)
        losses = tuple(
            criterion(pred, label)
            for criterion, pred, label in zip(criterions, preds, labels)
        )
    else:
        alpha = 0.1
        # Sample positions are mixed instead of labels, so the two returned
        # index tensors can be used to gather the labels of every task.
        positions = torch.from_numpy(np.arange(len(labels[0])))
        img, picks_a, picks_b, lam = mixup_data(img, positions, alpha)
        labels_a = [label[picks_a] for label in labels]
        labels_b = [label[picks_b] for label in labels]
        preds = model(img)
        losses = tuple(
            mixup_criterion(criterion, pred, l_a, l_b, lam)
            for criterion, pred, l_a, l_b in zip(criterions, preds, labels_a, labels_b)
        )
    mtl_loss = mtl_criterion(torch.stack(losses))
    mtl_loss.backward()
    if weight_update:
        optimizer.step()
        optimizer.zero_grad()
    return (mtl_loss.item(),) + tuple(loss.item() for loss in losses)

In [None]:
def validate_multi_task_batch(model, criterions, mtl_criterion, device, batch, collapse=True, pred_collapse=True):
    """Evaluate ``model`` on ``batch`` without tracking gradients.

    Args:
        model: callable returning one prediction tensor per task.
        criterions: one loss per task, paired positionally with predictions and labels.
        mtl_criterion: combines the stacked per-task losses.
        device: device the image and labels are moved to.
        batch: ``(img, label_0, label_1, ...)``.
        collapse: when true every loss is converted with ``.item()`` (scalar
            losses); otherwise with ``.tolist()``.
        pred_collapse: when true predictions are reduced to ``argmax(1)``
            class indices; otherwise the raw outputs are returned as lists.

    Returns:
        ``(losses, preds, trues)`` — the combined loss followed by each task
        loss, the per-task predictions, and the per-task labels as lists.
    """
    with torch.no_grad():
        img, *labels = batch
        preds = model(img.to(device))
        labels = [label.to(device) for label in labels]
        task_losses = tuple(
            criterion(pred, label)
            for criterion, pred, label in zip(criterions, preds, labels)
        )
        mtl_loss = mtl_criterion(torch.stack(task_losses))
        if collapse:
            losses = (mtl_loss.item(),) + tuple(loss.item() for loss in task_losses)
        else:
            losses = (mtl_loss.tolist(),) + tuple(loss.tolist() for loss in task_losses)
        if pred_collapse:
            preds = tuple(pred.argmax(1).tolist() for pred in preds)
        else:
            preds = tuple(pred.tolist() for pred in preds)
        trues = tuple(label.tolist() for label in labels)
        return losses, preds, trues

In [None]:
class FocalLoss(nn.Module):
    """Focal loss on logits: mean of ``alpha * (1 - p_t) ** gamma * CE``.

    ``p_t`` is recovered from the per-sample cross entropy as ``exp(-CE)``.
    """

    def __init__(self, alpha=1., gamma=1.):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets, **kwargs):
        per_sample_ce = nn.functional.cross_entropy(inputs, targets, reduction='none')
        p_true = torch.exp(-per_sample_ce)
        modulator = (1 - p_true) ** self.gamma
        focal = self.alpha * modulator * per_sample_ce
        return focal.mean()

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross entropy against a smoothed target distribution.

    The target puts probability ``confidence`` on the correct class and spreads
    the remaining ``1 - confidence`` evenly over the other ``n - 1`` classes.

    Args:
        confidence: probability assigned to the correct class.
        n: number of classes (must be at least 2).
        reduction: ``'mean'``, ``'sum'`` or ``'none'``; applied identically to
            both terms of the loss.
    """
    def __init__(self, confidence, n, reduction='mean'):
        super().__init__()
        # gamma is the probability of each incorrect class
        self.gamma = (1 - confidence) / (n - 1)
        self.n = n
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss(reduction=reduction)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the smoothed loss for logits ``inputs`` and class indices ``targets``."""
        logp = self.logsoftmax(inputs)
        ce_loss = self.nllloss(logp, targets)
        # Per-sample sum of -log p over all classes. It must be summed over the
        # class axis for every reduction: averaging over classes as well (a
        # plain mean over the whole matrix) makes the smoothing term n times
        # too small.
        reg_loss = -torch.sum(logp, dim=1)
        if self.reduction == 'mean':
            reg_loss = reg_loss.mean()
        elif self.reduction == 'sum':
            reg_loss = reg_loss.sum()
        # -sum_k q_k log p_k with q_true = confidence and q_other = gamma, using
        # confidence - gamma == 1 - n * gamma.
        return (1 - self.n * self.gamma) * ce_loss + self.gamma * reg_loss

In [None]:
class MultiTaskLoss(nn.Module):
    """Learnable task weighting: ``sum(exp(-2 w) * losses) + sum(w)``.

    ``w`` holds one learnable value per task, initialised to zeros unless
    ``init_weight`` is given.
    """

    def __init__(self, num_tasks, init_weight=None):
        super().__init__()
        self.n = num_tasks
        initial = torch.zeros(self.n) if init_weight is None else torch.tensor(init_weight)
        self.w = nn.Parameter(initial)

    def freeze(self):
        """Stop gradient computation for the task weights."""
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, losses):
        weighted = torch.exp(-2.0 * self.w) * losses
        return weighted.sum() + self.w.sum()

In [None]:
class MultiTaskSumLoss(nn.Module):
    """Fixed (non-learned) weighted sum of per-task losses.

    Args:
        num_tasks: number of tasks.
        init_weight: optional per-task weights (list, numpy array or tensor);
            defaults to all ones. Always stored as float32.
        collapse: when true ``forward`` takes one scalar loss per task and
            returns a scalar. When false it takes a ``(num_tasks, batch)``
            matrix of per-sample losses and returns one value per sample.
    """
    def __init__(self, num_tasks, init_weight=None, collapse=True):
        super(MultiTaskSumLoss, self).__init__()
        self.n = num_tasks
        if init_weight is None:
            weights = torch.ones(self.n)
        else:
            # Cast explicitly: a float64 numpy array would otherwise promote
            # every combined loss to float64, and integer weights cannot be
            # turned into a Parameter at all.
            weights = torch.as_tensor(init_weight, dtype=torch.float32).clone()
        self.collapse = collapse
        if not self.collapse:
            # Column vector so it broadcasts over the batch axis of `losses`.
            weights = weights.reshape(-1, 1)
        # Setting `requires_grad` on the module is a no-op; the flag has to be
        # set on the parameter for the weights to be truly fixed.
        self.w = nn.Parameter(weights, requires_grad=False)
        # Attribute kept for backward compatibility with code that reads it.
        self.requires_grad = False

    def forward(self, losses):
        """Return the weighted sum of ``losses`` (per sample when not collapsed)."""
        if not self.collapse:
            return torch.sum(losses * self.w, dim=0)
        else:
            return torch.sum(losses * self.w)# - torch.sum(torch.sum(torch.log(self.w)))

In [None]:
# LRFinder
def lr_finder(model, optimizer, criterions, mtl_criterion, device, dl, num_iter=10, start_lr=1e-4, end_lr=1.0, gradient_accumulation_step=1, with_mixup=False):
    """Learning-rate range test: train briefly while raising the LR exponentially.

    The model, optimizer and multi-task criterion states are snapshotted first
    and restored before returning (also when an exception interrupts the run),
    so the search leaves no trace on training.

    Args:
        model, optimizer, criterions, mtl_criterion, device: forwarded to
            ``train_multi_task_batch``.
        dl: iterable of batches to draw from; it is re-iterated when exhausted.
        num_iter: number of optimizer updates to perform.
        start_lr: learning rate of the first update.
        end_lr: learning rate reached after ``num_iter`` updates.
        gradient_accumulation_step: batches accumulated per optimizer update.
        with_mixup: forwarded to ``train_multi_task_batch``.

    Returns:
        ``(lrf_losses, lrs)`` — for every processed batch, the loss tuple
        returned by ``train_multi_task_batch`` and the list of learning rates
        of all parameter groups after that batch.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    model_state = copy.deepcopy(model.state_dict())
    optim_state = copy.deepcopy(optimizer.state_dict())
    mtl_criterion_state = copy.deepcopy(mtl_criterion.state_dict())
    for param in optimizer.param_groups:
        param['lr'] = start_lr
    # The scheduler is stepped once per optimizer update, i.e. `num_iter`
    # times, so the exponent must not include the accumulation factor or
    # `end_lr` is never reached.
    gamma = (end_lr / start_lr) ** (1 / num_iter) if num_iter > 0 else 1.0
    print(gamma, start_lr, end_lr)
    lrf_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
    total_batches = num_iter * gradient_accumulation_step
    count = 0
    lrf_losses = []
    lrs = []
    try:
        while count < total_batches:
            batches_seen = 0
            # Iterate the loader that was passed in, not a notebook global.
            for batch in notebook.tqdm(dl):
                batches_seen += 1
                count += 1
                # Update after every full group of accumulated batches, so the
                # final batch always ends on an update.
                weight_update = (count % gradient_accumulation_step) == 0
                losses = train_multi_task_batch(model, optimizer, criterions, mtl_criterion, device, batch, weight_update, with_mixup)
                print(losses)
                lrf_losses.append(losses)
                if weight_update:
                    lrf_sched.step()
                lrs.append([pg['lr'] for pg in optimizer.param_groups])
                if count == total_batches:
                    break
            if batches_seen == 0:
                raise ValueError('dl yielded no batches')
    finally:
        # Drop any partially accumulated gradients and undo every change.
        optimizer.zero_grad()
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optim_state)
        mtl_criterion.load_state_dict(mtl_criterion_state)
    return lrf_losses, lrs

In [None]:
# Head depth and class count of each auxiliary task. Both lists are empty, so
# no auxiliary head is built; the previous values are kept in the comments.
aux_depths = []#[3, 3, 3, 3, 4, 4]
aux_n = []#[2, 3, 2, 2, num_fg, num_sg]

In [None]:
# ImageNet-pretrained ResNet-18 used as a 512-d feature extractor (classifier removed).
base, feature_size = models.resnet18(pretrained=True), 512
base.fc = nn.Identity()
# Replace the RGB stem with a single-channel one whose kernel is the sum of the
# pretrained kernel over its three input channels. The layer is declared with
# the channel count, kernel size and stride of the pretrained stem so that its
# attributes (and repr / summaries) match the weights it actually carries.
_rgb_stem = base.conv1
conv1 = nn.Conv2d(
    1,
    _rgb_stem.out_channels,
    kernel_size=_rgb_stem.kernel_size,
    stride=_rgb_stem.stride,
    # NOTE(review): padding stays at 0 (not the stock stem's padding) so the
    # layer computes exactly what it did before and stays compatible with the
    # checkpoint loaded below — confirm before changing.
    padding=0,
    bias=False,
)
with torch.no_grad():
    conv1.weight.copy_(_rgb_stem.weight.sum(dim=1, keepdim=True))
base.conv1 = conv1

In [None]:
# Number of classes of every task: the three main targets followed by the
# auxiliary tasks.
n_classes_tasks = [n_graphemes, n_vowel_diacs, n_consonant_diacs] + aux_n
# Disabled alternative: deeper feed-forward heads built with make_ff_predictor.
# depth_tasks = [2, 1, 1] + aux_depths
# task_predictors = [
#     make_ff_predictor(feature_size, 512, n_classes, depth) 
#     for n_classes, depth in zip(n_classes_tasks, depth_tasks)
# ]
# One bias-free linear head per task on top of the shared features.
task_predictors = [nn.Linear(feature_size, n_classes, bias=False) for n_classes in n_classes_tasks]

In [None]:
# Assemble the multi-task network and move it to the selected device.
model = MultiTaskNN(base, task_predictors).to(device)

In [None]:
# Resume from a saved checkpoint; strict=True requires every key to match the
# model built above.
model.load_state_dict(torch.load('/kaggle/input/bangla-handwritten-grapheme/model.pth', map_location=device), strict=True)

In [None]:
# model.unfreeze()

In [None]:
# Training batches are shuffled and the last incomplete batch is dropped;
# validation keeps the dataset order and every sample.
tr_dl = DataLoader(tr_ds, batch_size=256, num_workers=2, pin_memory=True, shuffle=True, drop_last=True)
va_dl = DataLoader(va_ds, batch_size=256, num_workers=2, pin_memory=True)

In [None]:
# One cross-entropy loss per task (FocalLoss alternative kept in the comments).
# Only the g / v / c criterions are placed in `criterions` below; the rest
# belong to the currently disabled auxiliary tasks.
g_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=2.0)
v_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=2.0)
c_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=2.0)
nd_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=2.0)
matra_criterion = nn.CrossEntropyLoss()#FocalLoss(gamma=2.0)
up_criterion = nn.CrossEntropyLoss()
conj_criterion = nn.CrossEntropyLoss()
fg_criterion = nn.CrossEntropyLoss()
sg_criterion = nn.CrossEntropyLoss()

In [None]:
# Active losses, in the same order as the model's prediction heads.
criterions = (g_criterion, v_criterion, c_criterion)#, nd_criterion, matra_criterion, up_criterion, conj_criterion, fg_criterion, sg_criterion)

In [None]:
num_tasks = len(criterions)

In [None]:
gradient_accumulation_step = 1

In [None]:
# Equal weight for every active task, sized from `num_tasks` so that enabling
# auxiliary tasks does not require editing this line. The weights are divided
# by the accumulation factor so the gradients accumulated over that many
# batches add up to one batch-equivalent.
mtl_criterion = MultiTaskSumLoss(num_tasks, np.ones(num_tasks) / gradient_accumulation_step).to(device)

In [None]:
# optimizer = optim.Adam([{'params': model.parameters()}, {'params': mtl_criterion.parameters()}])
# optimizer = optim.SGD([{'params': model.parameters()}, {'params': mtl_criterion.parameters()}], lr=1e-2, momentum=0.9)
# optimizer = optim.Adam(model.parameters(), lr=1e-5)
# optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
# One parameter group for the shared base and one per prediction head, so each
# can be given its own learning rate. Every head is included (not only the
# first three), so heads of auxiliary tasks are trained when they are enabled.
optimizer = optim.Adam([
    {'params': module.parameters(), 'lr': 1e-4}
    for module in [model.base, *model.task_predictors]
])

In [None]:
# optimizer.load_state_dict(torch.load('/kaggle/input/bangla-handwritten-grapheme/optim.pth', map_location=device))

In [None]:
# optimizer.param_groups[0]['lr'] = 3e-4

In [None]:
# lrf_losses, lrs = lr_finder(model, optimizer, criterions, mtl_criterion, device, tr_dl, 
#                             num_iter=10, start_lr=1e-5, end_lr=1e-2, gradient_accumulation_step=gradient_accumulation_step, with_mixup=True)

In [None]:
# skip_first = 0
# skip_last = 1

In [None]:
# plt.plot([t[0] for t in lrs[skip_first:-skip_last]])

In [None]:
# plt.plot([t[0] for t in lrs[skip_first:-skip_last]], [t[1] for t in lrf_losses][skip_first:-skip_last])
# plt.xscale('log')

In [None]:
# lr = 1e-6

In [None]:
# lr = 1e-4
# for param in optimizer.param_groups:param['lr'] = lr

In [None]:
# Fresh history containers: one list per loss (combined loss first, then one
# per task), plus validation scores and the logged task weights.
# NOTE(review): the next cell overwrites tr_losses / va_losses / va_scores with
# histories loaded from disk.
tr_losses = [[] for i in range(num_tasks + 1)]
va_losses = [[] for i in range(num_tasks + 1)]
va_scores = []
mtl_weights = []

In [None]:
# Continue the loss / score histories of a previous run; this replaces the
# empty containers created in the cell above.
tr_losses = torch.load('/kaggle/input/bangla-handwritten-grapheme/tr_losses')
va_losses = torch.load('/kaggle/input/bangla-handwritten-grapheme/va_losses')
va_scores = torch.load('/kaggle/input/bangla-handwritten-grapheme/va_scores')

In [None]:
num_epochs = 5
# Optimizer updates per epoch: one for every `gradient_accumulation_step` batches.
steps_per_epoch = len(tr_dl) // gradient_accumulation_step
steps_per_epoch

In [None]:
# One-cycle schedule with `lr` as the maximum learning rate, sized for
# num_epochs * steps_per_epoch scheduler steps.
# NOTE(review): the schedule sets the learning rate of every parameter group,
# so the per-group values given to the optimizer presumably no longer apply — verify.
lr = 1e-3
scheduler = OneCycleLR(optimizer, lr, epochs=num_epochs, steps_per_epoch=steps_per_epoch)

In [None]:
# Main loop: one training pass (mixup enabled) and one validation pass per epoch.
epochs = range(num_epochs)
for epoch in epochs:
    # ---- training ----
    model.train()
    batch_num = 0
    for batch in tr_dl:
        # Apply the accumulated gradients on the last batch of every group.
        weight_update = (batch_num % gradient_accumulation_step) == (gradient_accumulation_step - 1)
        losses = train_multi_task_batch(model, optimizer, criterions, mtl_criterion, device, batch, weight_update, with_mixup=True)
        if weight_update:
            scheduler.step()
        # The task weights are divided by the accumulation factor, so scale the
        # combined loss back for logging.
        losses = (gradient_accumulation_step * losses[0],) + losses[1:]
        for tr_loss, loss in zip(tr_losses, losses):
            tr_loss.append(loss)
        mtl_weights.append(mtl_criterion.w.tolist())
        batch_num += 1

    # Mean of each loss over the batches of this epoch. One entry is logged per
    # batch, so the window is the number of batches, not of optimizer updates.
    print(tuple(map(mean, [l[-batch_num:] for l in tr_losses])))

    # ---- validation ----
    model.eval()
    va_batch_losses = [[] for i in range(num_tasks + 1)]
    va_preds = [[] for i in range(num_tasks)]
    va_trues = [[] for i in range(num_tasks)]

    for batch in va_dl:
        losses, preds, trues = validate_multi_task_batch(model, criterions, mtl_criterion, device, batch)
        losses = (gradient_accumulation_step * losses[0],) + losses[1:]
        for va_loss, loss in zip(va_batch_losses, losses):
            va_loss.append(loss)
        for va_pred, pred in zip(va_preds, preds):
            va_pred.extend(pred)
        for va_true, true in zip(va_trues, trues):
            va_true.extend(true)

    avg_loss = tuple(map(mean, va_batch_losses))
    for va_loss, loss in zip(va_losses, avg_loss):
        va_loss.append(loss)
    # recall_score expects (y_true, y_pred). Passing the predictions first
    # scores each class against the predicted labels and does not give recall.
    # NOTE(review): va_scores entries loaded from earlier runs were computed
    # with the arguments swapped.
    recalls = tuple(
        recall_score(true, pred, average='macro')
        for true, pred in zip(va_trues, va_preds)
    )
    g_rec, v_rec, c_rec, *rest = recalls
    # Weighted score: grapheme root counts double.
    score = 0.5 * g_rec + 0.25 * v_rec + 0.25 * c_rec
    va_scores.append((score,) + recalls)
    print([loss[-1] for loss in va_losses])
    print(va_scores[-1])
    for i in range(1, num_tasks):
        print(confusion_matrix(va_trues[i], va_preds[i]))
    # Log-scaled confusion heatmaps for the first and the last two tasks.
    for task_idx in (0, -2, -1):
        plt.figure(figsize = (20, 20))
        sn.heatmap(np.log1p(confusion_matrix(va_trues[task_idx], va_preds[task_idx])))
        plt.show()

In [None]:
!nvidia-smi

In [None]:
# One figure per logged training loss (combined loss first, then each task).
for losses in tr_losses:
    plt.plot(losses)
    plt.show()

In [None]:
# One figure per logged validation loss (combined loss first, then each task).
for losses in va_losses:
    plt.plot(losses[:])
    plt.show()

In [None]:
plt.plot(va_scores[:])

In [None]:
# Persist the histories to the working directory so a later run can resume them.
torch.save(tr_losses, 'tr_losses')
torch.save(va_losses, 'va_losses')
torch.save(va_scores, 'va_scores')

In [None]:
# model.eval()
# va_batch_losses = [[] for i in range(num_tasks + 1)]
# va_preds = [[] for i in range(num_tasks)]
# va_trues = [[] for i in range(num_tasks)]

# batches = va_dl
# batches = notebook.tqdm(va_dl)
# for batch in batches:
#     losses, preds, trues = validate_multi_task_batch(model, criterions, mtl_criterion, device, batch)
#     for va_loss, loss in zip(va_batch_losses, losses):
#         va_loss.append(loss)
#     for va_pred, pred in zip(va_preds, preds):
#         va_pred.extend(pred)
#     for va_true, true in zip(va_trues, trues):
#         va_true.extend(true)
        
# recalls = tuple(map(lambda true, pred: recall_score(true, pred, average='macro'), va_preds, va_trues))

In [None]:
recalls

In [None]:
# Per-class recall of every task on the last validation pass. recall_score
# expects (y_true, y_pred), so the true labels must be passed first.
detailed_recalls = tuple(
    recall_score(true, pred, average=None)
    for true, pred in zip(va_trues, va_preds)
)

In [None]:
idx = 0
plt.hist(detailed_recalls[idx])
list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
idx = 1
plt.hist(detailed_recalls[idx])
list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
idx = 2
plt.hist(detailed_recalls[idx])
list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
# idx = 3
# plt.hist(detailed_recalls[idx])
# list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
# idx = 4
# plt.hist(detailed_recalls[idx])
# list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
# idx = 5
# plt.hist(detailed_recalls[idx])
# list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
# idx = 6
# plt.hist(detailed_recalls[idx])
# list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
# idx = 7
# plt.hist(detailed_recalls[idx])
# list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
# idx = 8
# plt.hist(detailed_recalls[idx])
# list(enumerate(detailed_recalls[idx])), min(detailed_recalls[idx])

In [None]:
!nvidia-smi 

In [None]:
def get_preds_losses(model, num_taks, device, dl):
    """Run `model` over `dl` in eval mode and gather losses, predictions and targets.

    Args:
        model: network passed to `validate_multi_task_batch`; it is switched
            to eval mode here.
        num_taks: number of tasks. The misspelled name is kept so existing
            callers keep working. This argument used to be ignored in favour
            of a notebook-level `num_tasks`; it is now the value actually used.
        device: device the multi-task loss module is moved to and that is
            forwarded to `validate_multi_task_batch`.
        dl: iterable of batches (e.g. a DataLoader).

    Returns:
        Tuple `(losses, preds, trues)`. `losses` holds `num_taks + 1` lists,
        `preds` and `trues` hold `num_taks` lists each. Every list is extended
        batch by batch with what `validate_multi_task_batch(..., collapse=False)`
        returns. The extra loss list is presumably the combined multi-task
        loss -- TODO confirm which position it occupies.
    """
    # Bind the argument to the name used below so no global is consulted.
    num_tasks = num_taks
    model.eval()
    # reduction='none' keeps one loss value per sample instead of a batch mean.
    criterions = [nn.CrossEntropyLoss(reduction='none') for _ in range(num_tasks)]
    mtl_criterion = MultiTaskSumLoss(num_tasks, collapse=False).to(device)
    losses = [[] for _ in range(num_tasks + 1)]
    preds = [[] for _ in range(num_tasks)]
    trues = [[] for _ in range(num_tasks)]
    for batch in notebook.tqdm(dl):
        batch_losses, batch_preds, batch_trues = validate_multi_task_batch(model, criterions, mtl_criterion, device, batch, collapse=False)
        for loss, batch_loss in zip(losses, batch_losses):
            loss.extend(batch_loss)
        for pred, batch_pred in zip(preds, batch_preds):
            pred.extend(batch_pred)
        for true, batch_true in zip(trues, batch_trues):
            true.extend(batch_true)
    return losses, preds, trues

In [None]:
# losses, preds, trues = get_preds_losses(model, num_tasks, device, DataLoader(DatasetWithImageTransforms(ds, va_tfms), batch_size=64, num_workers=0, pin_memory=True))

In [None]:
# loss_sorted_ids = sorted(list(range(len(losses[0]))), key=lambda i: losses[0][i], reverse=True)

In [None]:
def get_component(component_id, component_type):
    """Return the `component` value of the first `df_class` row matching the given label and type.

    Args:
        component_id: value compared against the `label` column.
        component_type: value compared against the `component_type` column.

    Raises:
        IndexError: if no row of `df_class` matches both conditions.
    """
    label_matches = df_class['label'] == component_id
    type_matches = df_class['component_type'] == component_type
    candidates = df_class.loc[label_matches & type_matches, 'component']
    return candidates.values[0]

In [None]:
def plot_image_from_ds(ds, i, preds, losses):
    """Show sample `i` of `ds` and print its true and predicted components.

    Args:
        ds: indexable dataset; `ds[i]` must unpack into an image followed by
            at least three label objects exposing `.item()`.
        i: sample index, also used to index into `preds` and `losses`.
        preds: per-task sequences of predicted label ids; the first three
            are read as grapheme root, vowel diacritic, consonant diacritic.
        losses: sequences of per-sample losses; element `i` of each is printed.
    """
    print(i)
    image, root, vowel, consonant, *_ = ds[i]
    true_ids = [label.item() for label in (root, vowel, consonant)]
    pred_ids = [preds[task][i] for task in range(3)]

    # Merge the first two dimensions so imshow gets a 2-D array
    # (assumes a leading singleton channel dimension -- TODO confirm).
    image = image.flatten(end_dim=1)
    plt.imshow(image, cmap='gray', vmin=0., vmax=1.)
    plt.show()

    # Ground-truth triple first, then the predicted triple, in the same order.
    component_types = ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic')
    for label_id, kind in zip(true_ids + pred_ids, component_types * 2):
        print(label_id, get_component(label_id, kind))
    print([l[i] for l in losses])

In [None]:
# ds = DatasetWithImageTransforms(ds, va_tfms)
# count = 0
# for i in loss_sorted_ids:
#     print(count, i in va_indices)
#     plot_image_from_ds(ds, i, preds, losses)
#     count += 1
#     if count == 2000 or losses[0][i] < 0.9:
#         break

In [None]:
# plt.plot([losses[0][i] for i in loss_sorted_ids[:]])

In [None]:
# mean([losses[0][i] for i in loss_sorted_ids[:]])

In [None]:
# mean([losses[0][i] for i in loss_sorted_ids[2000:]])

In [None]:
# func(va_ds, loss_sorted_ids[50], preds, losses)

In [None]:
# plt.plot(mtl_weights)

In [None]:
# Write the model and optimizer state dicts to files in the current working
# directory (relative paths) so they can be reloaded later.
torch.save(model.state_dict(), 'model.pth')
torch.save(optimizer.state_dict(), 'optim.pth')