In [None]:
!pip install ../input/bengaliutils2/timm-0.1.18-py3-none-any.whl

In [None]:
from torch import nn
import timm
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
from timm.models.layers.activations import Swish
from torch.nn import Conv2d, BatchNorm2d, Sequential, Linear
from torch.nn.modules.flatten import Flatten
from albumentations import Compose, Normalize
from albumentations.pytorch import ToTensorV2

In [None]:
class MishFunction(torch.autograd.Function):
    """Mish activation, ``x * tanh(softplus(x))``, with a hand-written backward pass.

    Only the input tensor is saved for the backward pass; ``tanh(softplus(x))``
    is recomputed there instead of being stored.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x * tanh(ln(1 + exp(x)))`` and save ``x`` for ``backward``."""
        ctx.save_for_backward(x)
        return x * torch.tanh(F.softplus(x))

    @staticmethod
    def backward(ctx, grad_output):
        """Multiply ``grad_output`` by the analytic derivative of Mish at the saved input.

        With ``sp = softplus(x)`` and ``d sp / dx = sigmoid(x)``::

            d/dx [x * tanh(sp)] = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp) ** 2)
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated.
        (x,) = ctx.saved_tensors
        sigmoid = torch.sigmoid(x)
        tanh_sp = torch.tanh(F.softplus(x))
        return grad_output * (tanh_sp + x * sigmoid * (1 - tanh_sp * tanh_sp))

class Mish(nn.Module):
    """Layer wrapper that applies ``MishFunction`` to its input."""

    def forward(self, x):
        """Return the Mish activation of ``x``."""
        activated = MishFunction.apply(x)
        return activated

def to_Mish(model):
    """Replace every ``Swish`` submodule of ``model`` with a fresh ``Mish``, in place.

    Walks the module tree recursively; returns nothing.
    """
    for attr_name, submodule in model.named_children():
        if not isinstance(submodule, Swish):
            # Not an activation we replace: look for Swish further down the tree.
            to_Mish(submodule)
            continue
        setattr(model, attr_name, Mish())

In [None]:
class EfficientNew(nn.Module):
    """timm backbone (Swish swapped for Mish) feeding three ``AverageHead`` classifiers.

    The backbone is created with ``pretrained=False``; weights are expected to be
    loaded into the module afterwards.

    Args:
        num_classes: indexable with at least three entries, giving the number of
            classes for the grapheme-root, vowel-diacritic and consonant-diacritic
            heads, in that order.
        encoder: timm model name passed to ``timm.create_model``.

    Raises:
        KeyError: if ``encoder`` is not in the channel table and the created
            backbone does not expose ``num_features``.
    """

    # Channel count of the backbone's `forward_features` output, per encoder name.
    N_CHANNELS = {
        'efficientnet_b0': 1280, 'efficientnet_b1': 1280, 'efficientnet_b2': 1408,
        'efficientnet_b3': 1536, 'efficientnet_b4': 1792, 'efficientnet_b5': 2048,
        'efficientnet_b6': 2304, 'efficientnet_b7': 2560, 'seresnext50_32x4d': 2048,
        'tf_efficientnet_b0_ns': 1280, 'tf_efficientnet_b3_ns': 1536,
        'tf_efficientnet_b4_ns': 1792,
    }

    def __init__(self, num_classes, encoder):
        super().__init__()
        self.net = timm.create_model(encoder, pretrained=False)
        to_Mish(self.net)

        out_features = self.N_CHANNELS.get(encoder)
        if out_features is None:
            # Encoder not in the table: use the width the backbone reports itself.
            out_features = getattr(self.net, 'num_features', None)
        if out_features is None:
            raise KeyError(
                f'Unknown feature width for encoder {encoder!r}: it is not in '
                f'EfficientNew.N_CHANNELS and the model has no num_features attribute'
            )

        self.head_grapheme_root = AverageHead(num_classes[0], out_features)
        self.head_vowel_diacritic = AverageHead(num_classes[1], out_features)
        self.head_consonant_diacritic = AverageHead(num_classes[2], out_features)

    def forward(self, x):
        """Run the backbone once and return the logits of the three heads.

        Returns:
            Tuple ``(grapheme_root, vowel_diacritic, consonant_diacritic)``.
        """
        x = self.net.forward_features(x)
        logit_grapheme_root = self.head_grapheme_root(x)
        logit_vowel_diacritic = self.head_vowel_diacritic(x)
        logit_consonant_diacritic = self.head_consonant_diacritic(x)

        return logit_grapheme_root, logit_vowel_diacritic, logit_consonant_diacritic

class AverageHead(nn.Module):
    """Classification head: mean of global average- and max-pooling, then one linear layer.

    Args:
        num_classes: size of the output logits.
        out_features: number of channels in the incoming feature map.
    """

    def __init__(self, num_classes, out_features):
        super().__init__()
        classifier = Linear(out_features, num_classes)
        self.post_layers = Sequential(Flatten(), classifier)
        self._init_weight()

    def _init_weight(self):
        """Re-initialise Conv2d / BatchNorm2d submodules; other layer types are left untouched."""
        for module in self.modules():
            if isinstance(module, BatchNorm2d):
                module.weight.data.fill_(1.0)
                module.bias.data.zero_()
            elif isinstance(module, Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        """Pool ``x`` to 1x1 spatially (average and max, averaged together) and classify."""
        avg_pooled = F.adaptive_avg_pool2d(x, 1)
        max_pooled = F.adaptive_max_pool2d(x, 1)
        pooled = 0.5 * (avg_pooled + max_pooled)
        return self.post_layers(pooled)

In [None]:
# Image dimensions (pixels) that GraphemeDatasetTest uses to reshape each
# flattened parquet row into a 2-D image.
HEIGHT = 137
WIDTH = 236

def valid_aug(image_size=None):
    """Build the inference-time pipeline: ``Normalize`` followed by ``ToTensorV2``.

    ``image_size`` is accepted but not used.
    """
    return Compose([Normalize(), ToTensorV2()], p=1)

class GraphemeDatasetTest(Dataset):
    """Inference dataset over one parquet file of flattened grayscale images.

    Column 0 of the file is returned as the sample name; all remaining columns
    are reshaped into ``HEIGHT`` x ``WIDTH`` uint8 images.
    """

    def __init__(self, fname):
        self.transform = valid_aug()
        self.df = pd.read_parquet(fname)
        pixels = self.df.iloc[:, 1:].values
        self.data = pixels.reshape(-1, HEIGHT, WIDTH).astype(np.uint8)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, name)`` for row ``idx``; the image is converted to RGB and transformed."""
        name = self.df.iloc[idx, 0]
        # The single-channel image is expanded to three channels before the transform.
        rgb = cv2.cvtColor(self.data[idx], cv2.COLOR_GRAY2RGB)
        if self.transform:
            rgb = self.transform(image=rgb)['image']
        return rgb, name
    
class Predictor:
    """Holds a three-output model in eval mode on one device and runs inference with it."""

    def __init__(self, model):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')
        self.model = model.to(self.device, dtype=torch.float32)
        self.model.eval()
        print(f'Model prepared. Device is {self.device}')

    def predict(self, inputs, softmax_after=True):
        """Run the model on ``inputs`` without tracking gradients.

        Args:
            inputs: batch tensor; moved to this predictor's device as float32.
            softmax_after: when true the three raw model outputs are returned
                as-is; when false each output is passed through a softmax over
                dim 1 before being returned.

        Returns:
            Tuple of three tensors, in the order the model returns them.
        """
        batch = inputs.to(self.device, dtype=torch.float32)
        with torch.no_grad():
            out_gr, out_vd, out_cd = self.model(batch)
        if softmax_after:
            return out_gr, out_vd, out_cd
        return tuple(torch.softmax(out, dim=1) for out in (out_gr, out_vd, out_cd))

    def load(self, path):
        """Load the ``'model_state_dict'`` entry of the checkpoint at ``path`` into the model."""
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dict = torch.load(path, map_location=self.device)['model_state_dict']
        self.model.load_state_dict(state_dict)

In [None]:
# Checkpoints that are ensembled at inference time; every one is loaded into
# the same tf_efficientnet_b3_ns architecture.
# NOTE(review): the f0..f4 in the file names presumably denote CV folds — confirm.
WEIGHTS_DIR = '/kaggle/input/bengaliutils2'
CHECKPOINT_NAMES = [
    'b3_9863_f0_ft1.pth',
    'b3_9855_f1_ft1.pth',
    'b3_9865_f2_ft1.pth',
    'b3_9854_f3_ft1.pth',
    'b3_9857_f4_ft1.pth',
]
ENCODER_NAME = 'tf_efficientnet_b3_ns'
# Head sizes in the order EfficientNew expects them:
# grapheme root, vowel diacritic, consonant diacritic.
NUM_CLASSES = [168, 11, 7]

predictors = []

for checkpoint_name in CHECKPOINT_NAMES:
    predictor = Predictor(EfficientNew(NUM_CLASSES, ENCODER_NAME))
    predictor.load(f'{WEIGHTS_DIR}/{checkpoint_name}')
    predictors.append(predictor)

predictors_count = len(predictors)

def predict_to_numpy(predict):
    """Return, as a NumPy array, the index of the largest value in each row (axis 1) of ``predict``."""
    scores = predict.data.cpu().numpy()
    return np.argmax(scores, axis=1)

def predict_to_numpy_softmax(predict):
    """Softmax ``predict`` over dim 1, then return the per-row argmax as a NumPy array."""
    probabilities = torch.softmax(predict, dim=1)
    as_numpy = probabilities.data.cpu().numpy()
    return np.argmax(as_numpy, axis=1)

def make_prediction(images, softmax_after=True):
    """Average the outputs of every predictor in ``predictors`` and pick a class per head.

    Args:
        images: batch passed unchanged to each predictor's ``predict``.
        softmax_after: when true, the predictors return raw outputs, which are
            averaged and then passed through softmax before the argmax; when
            false, each predictor applies softmax itself and the averaged
            probabilities go straight to the argmax.

    Returns:
        Tuple ``(roots, vowels, consonants)`` of NumPy arrays of class indices.

    Raises:
        ValueError: if ``predictors`` is empty.

    Side effects:
        On the first call for which the module-level flag ``was`` is unset or
        false, prints the averaged vowel output and the averaging mode, then
        sets ``was = True``. Reset ``was`` to ``False`` to see the dump again.
    """
    global was
    if not predictors:
        raise ValueError('make_prediction needs at least one loaded predictor')

    outputs_gr = 0
    outputs_vd = 0
    outputs_cd = 0
    for predictor in predictors:
        gr, vd, cd = predictor.predict(images, softmax_after)
        outputs_gr += gr
        outputs_vd += vd
        outputs_cd += cd

    # Divide by the live list length so the mean cannot drift from a stale count.
    n_predictors = len(predictors)
    outputs_gr /= n_predictors
    outputs_vd /= n_predictors
    outputs_cd /= n_predictors

    # `was` may not have been defined yet if this runs before the cell that sets it.
    first_call = not globals().get('was', False)
    if first_call:
        print(outputs_vd)
        was = True

    if softmax_after:
        if first_call:
            # Reported once rather than once per batch, to keep the cell output readable.
            print("Averaging raw then softmax!")
        roots = predict_to_numpy_softmax(outputs_gr)
        vowels = predict_to_numpy_softmax(outputs_vd)
        consonants = predict_to_numpy_softmax(outputs_cd)
    else:
        if first_call:
            print("Averaging softmax!")
        roots = predict_to_numpy(outputs_gr)
        vowels = predict_to_numpy(outputs_vd)
        consonants = predict_to_numpy(outputs_cd)

    return roots, vowels, consonants

In [None]:
%%time
import tqdm
# Accumulators for the submission: three rows (root, vowel, consonant) per image.
target = []
row_id = []
# One-shot flag that make_prediction reads and sets after its first debug print.
was = False
for i in range(4):
    dataset = GraphemeDatasetTest(f'../input/bengaliai-cv19/test_image_data_{i}.parquet')
    data_loader = DataLoader(dataset, batch_size=256, num_workers=4, shuffle=False)
    for images, images_id in data_loader:
        p1, p2, p3 = make_prediction(images, True)
        # Walk the batch in lockstep: image name with its three predicted class indices.
        for name, root, vowel, consonant in zip(images_id, p1, p2, p3):
            row_id.extend([
                f'{name}_grapheme_root',
                f'{name}_vowel_diacritic',
                f'{name}_consonant_diacritic',
            ])
            target.extend([root.item(), vowel.item(), consonant.item()])

In [None]:
# Assemble the two-column submission table from the accumulated ids and predictions.
df_submission = pd.DataFrame({'row_id': row_id, 'target': target}, columns=['row_id', 'target'])

# Write the file without the index column.
df_submission.to_csv('submission.csv', index=False)

# Preview the first rows as the cell output.
df_submission.head(10)