In [None]:
!pip install /kaggle/input/timm-model/timm-0.3.2-py3-none-any.whl

In [None]:
import pandas as pd
import cv2
import numpy as np
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
import time
import pyarrow.parquet as pq

In [None]:
IMG_SIZE=128
N_CHANNELS=1
def resize(df, size=128, need_progress_bar=True):
    """Crop every flattened 137x236 image in ``df`` to its content and resize it.

    Each row of ``df`` is reshaped to 137x236, binarised with an inverted
    Otsu threshold, and cropped to the union of the bounding boxes of all
    contours found in the binarised image. The crop is then resized to
    ``size`` x ``size`` with area interpolation and flattened again.

    Args:
        df: DataFrame whose rows each hold 137 * 236 pixel values.
            Presumably uint8 grayscale, as OpenCV's Otsu threshold expects
            8-bit input -- TODO confirm against the parquet schema.
        size: Side length of the square output image. Defaults to 128.
        need_progress_bar: Kept for backward compatibility. Rows are processed
            regardless of its value; no progress bar is drawn.

    Returns:
        DataFrame with one row per input row (same index labels) and
        ``size * size`` columns holding the flattened resized crop.
    """
    raw_height, raw_width = 137, 236
    resized = {}
    for i in range(df.shape[0]):
        # Positional access keeps this correct even if index labels repeat.
        image = df.iloc[i].values.reshape(raw_height, raw_width)
        _, thresh = cv2.threshold(image, 30, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        # [-2:] copes with OpenCV versions that return 2 or 3 values here.
        contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2:]

        boxes = [cv2.boundingRect(cnt) for cnt in contours]
        if boxes:
            xmin = min(x for x, y, w, h in boxes)
            ymin = min(y for x, y, w, h in boxes)
            xmax = max(x + w for x, y, w, h in boxes)
            ymax = max(y + h for x, y, w, h in boxes)
            roi = image[ymin:ymax, xmin:xmax]
        else:
            # Nothing detected (e.g. a blank image): keep the full frame
            # instead of failing on min()/max() of an empty list.
            roi = image

        resized_roi = cv2.resize(roi, (size, size), interpolation=cv2.INTER_AREA)
        resized[df.index[i]] = resized_roi.reshape(-1)
    return pd.DataFrame(resized).T

In [None]:
import torch
import numpy as np
import tqdm
import csv
import os
import cv2
import pandas as pd
from timm.models import create_model
import torch.autograd.profiler as profiler
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
from torch.nn import Parameter
from torch.autograd import Variable
from timm.models import create_model

class MultiHeadSimpleModel(nn.Module):
  """timm backbone feeding three linear heads plus a margin-based embedding head.

  The backbone is created with ``num_classes=0``; its output is passed to
  the root / consonant / vowel linear heads (each 168 logits wide) and, in
  parallel, through BatchNorm -> Dropout -> Linear(512) -> BatchNorm, then
  L2-normalised and given to an ``Arcface`` head with 1295 classes.

  Submodule attribute names and creation order are significant: they define
  the ``state_dict`` keys that saved checkpoints are loaded against.
  """

  def __init__(self, input_size, backbone, pretrained=True, dropout=0.2):
    """Build the backbone and all heads.

    Args:
      input_size: Stored on the instance; not otherwise used here.
      backbone: Model name forwarded to timm's ``create_model``.
      pretrained: Forwarded to ``create_model``.
      dropout: Drop probability used in the embedding branch.
    """
    super().__init__()
    self.input_size = input_size
    self.pretrained = pretrained
    self.dropout = dropout

    self.backbone = create_model(backbone, pretrained, num_classes=0)

    # The feature width is read from the backbone's top-level ``bn2`` layer,
    # so this only works for backbones that expose a parameter by that name.
    feature_size = self.backbone.state_dict()['bn2.weight'].shape[0]

    self.root_head = nn.Linear(feature_size, 168, bias=True)
    self.consonant_head = nn.Linear(feature_size, 168, bias=True)
    self.vowel_head = nn.Linear(feature_size, 168, bias=True)

    self.intermid_unique = nn.ModuleList([
      nn.BatchNorm1d(feature_size),
      nn.Dropout(dropout),
      nn.Linear(feature_size, 512),
      nn.BatchNorm1d(512),
    ])
    # Index 2 is the Linear projection to the 512-d embedding.
    torch.nn.init.kaiming_normal_(self.intermid_unique[2].weight)
    self.arc_face_unique = Arcface(512, 1295)

  def multi_head(self, input, unique):
    """Run the backbone once and evaluate every head on its features.

    Args:
      input: Batch passed straight to the backbone.
      unique: Labels forwarded to the margin head together with the embedding.

    Returns:
      Tuple ``(root, consonant, vowel, unique_p)`` of head outputs.
    """
    features = self.backbone(input)

    root = self.root_head(features)
    consonant = self.consonant_head(features)
    vowel = self.vowel_head(features)

    embedding = features
    for layer in self.intermid_unique:
      embedding = layer(embedding)
    embedding = F.normalize(embedding, p=2, dim=1)
    unique_p = self.arc_face_unique(embedding, unique)

    return root, consonant, vowel, unique_p

  def forward(self, input, unique):
    """Alias for :meth:`multi_head`."""
    return self.multi_head(input, unique)

def l2_norm(input,axis=1):
    """Divide ``input`` by its Euclidean norm taken along ``axis``.

    The norm is computed with ``keepdim=True`` so the division broadcasts
    back over the original shape. No epsilon is added, so an all-zero slice
    divides by zero.
    """
    return input / torch.norm(input, p=2, dim=axis, keepdim=True)

class Arcface(nn.Module):
    """Margin head producing scaled cosine logits with an angular margin on the target.

    For every sample the logit of its labelled class is replaced by
    ``cos(theta + m)``, where ``theta`` is the angle between the embedding and
    that class's weight column. When ``cos(theta) <= cos(pi - m)`` the value
    ``cos(theta) - m * sin(m)`` is used instead. All logits are finally
    multiplied by ``s``.
    """

    def __init__(self, embedding_size=512, classnum=51332,  s=30., m=0.5):
        """Create the class-weight matrix and cache the margin constants.

        Args:
            embedding_size: Length of the incoming embedding vectors.
            classnum: Number of classes (columns of the weight matrix).
            s: Scale applied to the output logits (default 30 here); see
                NormFace, https://arxiv.org/abs/1704.06369.
            m: Angular margin in radians (default 0.5).
        """
        super().__init__()
        self.classnum = classnum
        self.weight = Parameter(torch.Tensor(embedding_size, classnum))
        # Uniform init, then renorm + rescale so each column has ~unit L2 norm.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # offset subtracted in the fallback branch
        self.threshold = math.cos(math.pi - m)

    def forward(self, embbedings, label):
        """Return scaled logits with the margin applied to each row's label column.

        Args:
            embbedings: 2-D tensor with one embedding per row; it is multiplied
                directly with the normalised weights, so rows are expected to be
                L2-normalised already -- TODO confirm at other call sites.
            label: Integer class index per row, used to pick the column that
                receives the margin.

        Returns:
            Tensor with one row per embedding and one column per class.
        """
        batch_size = len(embbedings)
        unit_kernel = l2_norm(self.weight, axis=0)

        # Clamp guards against values drifting just outside [-1, 1].
        cos_theta = torch.mm(embbedings, unit_kernel).clamp(-1, 1)
        sin_theta = torch.sqrt(1 - torch.pow(cos_theta, 2))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        cos_theta_m = cos_theta * self.cos_m - sin_theta * self.sin_m

        # theta + m must stay within [0, pi]; where cos(theta) <= cos(pi - m)
        # fall back to subtracting a fixed offset from cos(theta).
        use_fallback = (cos_theta - self.threshold) <= 0
        cos_theta_m = torch.where(use_fallback, cos_theta - self.mm, cos_theta_m)

        # Work on a copy so cos_theta itself is never modified in place.
        logits = cos_theta.clone()
        rows = torch.arange(0, batch_size, dtype=torch.long)
        logits[rows, label] = cos_theta_m[rows, label]
        logits *= self.s
        return logits


class BengaliDataset(Dataset):
  """Dataset serving cropped and resized grapheme images from one parquet file.

  Pixel rows come from ``parquet_path`` (its ``image_id`` column is dropped);
  image ids and, in training mode, labels come from ``label_csv``. Every item
  is ``(image_tensor, root, consonant, vowel, unique)``; in test mode the four
  label slots are the constant ``0``.
  """

  def __init__(self, label_csv, unique_csv, train_folder, parquet_path, transforms, cache=True, test=True):
    """Read the CSVs and, if ``cache`` is true, load the parquet immediately.

    Args:
      label_csv: CSV with an ``image_id`` column. If it has a ``component``
        column, only rows with ``component == 'grapheme_root'`` are kept,
        which leaves one row per image.
      unique_csv: CSV with a ``grapheme`` column; its distinct values define
        the ids returned in the ``unique`` slot.
      train_folder: Stored on the instance; not otherwise used here.
      parquet_path: Parquet file with an ``image_id`` column plus pixel columns.
      transforms: Callable applied to each uint8 image array in ``load_image``.
      cache: Load the parquet now (True) or lazily on first access (False).
      test: When true, ``__getitem__`` returns zeros instead of labels.
    """
    self.label_csv = label_csv
    self.unique_csv = unique_csv
    self.train_folder = train_folder
    self.parquet_path = parquet_path

    label = pd.read_csv(self.label_csv)
    # Only filter when the column exists, so a label CSV without a
    # ``component`` column no longer raises KeyError.
    if 'component' in label.columns:
      label = label[label['component'] == 'grapheme_root']
    self.label = label.reset_index(drop=True)

    unique_df = pd.read_csv(self.unique_csv)

    self.names = self.label['image_id'].values
    self.uniques = unique_df.grapheme.unique()
    self.transforms = transforms
    # Kept for compatibility with code that inspects it; never populated here.
    self.img = [None] * self.label.shape[0]
    self.test = test
    # Plain instance attribute (callers may ``del dataset.p`` to free memory).
    self.p = None

    if cache:
      self.cache_images()

  def cache_images(self):
    """Load the pixel table from ``parquet_path`` into ``self.p``."""
    self.p = pd.read_parquet(self.parquet_path).drop(['image_id'], axis=1)

  def _pixels(self):
    """Return the pixel table, loading it first if it is not in memory."""
    if getattr(self, 'p', None) is None:
      self.cache_images()
    return self.p

  def load_image(self, idx):
    """Crop/resize row ``idx`` to IMG_SIZE x IMG_SIZE uint8 and apply ``transforms``."""
    row = self._pixels().iloc[[idx]]
    img = resize(row, size=IMG_SIZE).values.reshape(IMG_SIZE, IMG_SIZE).astype(np.uint8)
    return self.transforms(img)

  def __getitem__(self, idx):
    # ``transforms`` here is the module-level torchvision import, not the
    # callable stored on the instance.
    img = transforms.ToTensor()(self.load_image(idx))
    if self.test:
      return img, 0, 0, 0, 0

    row = self.label.loc[idx]
    root = row['grapheme_root']
    consonant = row['consonant_diacritic']
    vowel = row['vowel_diacritic']
    # Position of this grapheme within the distinct graphemes of unique_csv.
    unique = np.where(self.uniques == row['grapheme'])[0][0]
    return img, root, consonant, vowel, unique

  def __len__(self):
    return self._pixels().shape[0]

def postprocess(preds, num_classes, EXP = -1.2):
    """Re-pick the arg-max class after reweighting columns by predicted frequency.

    A first arg-max over ``preds`` counts how often each class is chosen.
    Classes that were never chosen receive the smallest observed count. Each
    column of ``preds`` is then multiplied by ``count ** EXP`` and the arg-max
    is taken again, so a negative ``EXP`` penalises frequently chosen classes.

    Args:
        preds: 2-D array with one row per sample and ``num_classes`` columns.
        num_classes: Number of classes (columns) to weight.
        EXP: Exponent applied to the per-class counts.

    Returns:
        1-D array with the reweighted arg-max class per row.
    """
    first_pass = np.argmax(preds, axis=1)

    counts = pd.Series(first_pass).value_counts().reindex(np.arange(num_classes))
    counts = counts.fillna(counts.min())
    weights = counts.astype('float32') ** EXP

    return np.argmax(preds.dot(np.diag(weights.to_numpy())), axis=1)

class Tester:
    """Runs multi-head inference over the test parquet parts and writes a submission CSV.

    The model is built from ``model_name``, its weights are restored from
    ``multihead_ckpt``, and ``test()`` writes ``row_id``/``target`` pairs to
    ``output_dir``.
    """

    # Square input resolution used for each supported backbone name.
    INPUT_SIZES = {
        'tf_efficientnet_b0_ns': (224, 224),
        'tf_efficientnet_b2_ns': (260, 260),
        'tf_efficientnet_b3_ns': (300, 300),
        'tf_efficientnet_b4_ns': (380, 380),
        'tf_efficientnet_b6_ns': (528, 528),
    }
    # Location of the test parquet parts; ``{}`` is replaced by the part index.
    PARQUET_TEMPLATE = '/kaggle/input/bengaliai-cv19/test_image_data_{}.parquet'
    NUM_PARQUET_PARTS = 4

    def __init__(self,
               dataset_path='./drive/MyDrive/datasets/car classification/train_data', 
               batch_size=1, 
               model_name='tf_efficientnet_b3_ns', 
               test_csv='./train_labels.csv', 
               unique_csv='./train_labels.csv',
               output_dir='../drive/MyDrive/ckpt/grapheme/submission.csv',
               ckpt='../drive/MyDrive/ckpt/grapheme/20.pth',
               multihead_ckpt='../'):
        """Store the configuration, build the transform and load the model.

        Args:
            dataset_path: Forwarded to ``BengaliDataset`` as ``train_folder``.
            batch_size: DataLoader batch size.
            model_name: One of the keys of ``INPUT_SIZES``.
            test_csv: CSV listing the test image ids.
            unique_csv: CSV with a ``grapheme`` column.
            output_dir: Path of the submission CSV written by ``test()``
                (a file path despite the name).
            ckpt: Stored on the instance; not used by this class.
            multihead_ckpt: Checkpoint holding ``'model_multihead_state_dict'``.

        Raises:
            Exception: If ``model_name`` is not a supported backbone.
        """
        # initialize attributes
        self.dataset_path = dataset_path
        self.batch_size = batch_size
        self.model_name = model_name
        self.test_csv = test_csv
        self.unique_csv = unique_csv
        self.output_dir = output_dir
        self.ckpt = ckpt
        self.multihead_ckpt = multihead_ckpt

        if model_name not in self.INPUT_SIZES:
            raise Exception('non-valid model name')
        self.input_size = self.INPUT_SIZES[model_name]

        # ndarray -> PIL image -> resized to the backbone's input resolution.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.input_size),
        ])

        # Use the GPU when present but stay runnable on CPU-only hosts.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_multihead = MultiHeadSimpleModel(self.input_size, self.model_name, pretrained=False, dropout=0).to(self.device)

        # map_location lets a checkpoint saved on GPU load on whichever device is in use.
        multihead_ckpt = torch.load(self.multihead_ckpt, map_location=self.device)
        self.model_multihead.load_state_dict(multihead_ckpt['model_multihead_state_dict'])

    def test(self):
        """Predict every test image and write the submission CSV to ``output_dir``."""
        output_roots = []
        output_consonants = []
        output_vowels = []

        self.model_multihead.eval()
        for ii in range(self.NUM_PARQUET_PARTS):
            self.test_dataset = BengaliDataset(self.test_csv, self.unique_csv, self.dataset_path, self.PARQUET_TEMPLATE.format(ii), self.transform, cache=True)
            # ``names`` comes from test_csv, so it lists every test image,
            # not only those contained in this parquet part.
            self.names = self.test_dataset.names
            self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=0, shuffle=False)

            pbar = tqdm.tqdm(self.test_dataloader)
            pbar.set_description('testing process')

            with torch.no_grad():
                for data in pbar:
                    # Replicate the single channel three times for the backbone.
                    inputs = data[0].to(self.device).repeat(1, 3, 1, 1)
                    # The model's forward requires labels; in test mode these are all zeros.
                    uniques = data[4].to(self.device).long()

                    root, consonant, vowel, _ = self.model_multihead(inputs, uniques)

                    # NOTE: postprocess reweights by class frequency within the
                    # current batch, so predictions depend on batch_size.
                    # Only the first 8 / 11 logits of the 168-wide consonant /
                    # vowel heads are considered.
                    root = postprocess(torch.softmax(root, dim=1).cpu().numpy(), 168, -1.2)
                    consonant = postprocess(torch.softmax(consonant[:, :8], dim=1).cpu().numpy(), 8, -0.5)
                    vowel = postprocess(torch.softmax(vowel[:, :11], dim=1).cpu().numpy(), 11, -1.2)

                    output_roots.extend(root.tolist())
                    output_consonants.extend(consonant.tolist())
                    output_vowels.extend(vowel.tolist())

            # Release the pixel table before the next part is loaded.
            del self.test_dataset.p
            del self.test_dataset
            del self.test_dataloader

        if len(self.names) != len(output_roots):
            # zip() below would silently truncate to the shorter sequence.
            import warnings
            warnings.warn(
                'number of image ids (%d) differs from number of predictions (%d); '
                'the submission will be truncated' % (len(self.names), len(output_roots))
            )

        row_id, target = [], []
        for iid, r, c, v in zip(self.names, output_roots, output_consonants, output_vowels):
            row_id.append(iid + '_grapheme_root')
            target.append(int(r))
            row_id.append(iid + '_vowel_diacritic')
            target.append(int(v))
            row_id.append(iid + '_consonant_diacritic')
            target.append(int(c))

        sub_fn = self.output_dir
        sub = pd.DataFrame({'row_id': row_id, 'target': target})
        sub.to_csv(sub_fn, index=False)
        print(f'Done wrote to {sub_fn}')

# Inference entry point: build the tester from the competition inputs and
# write ./submission.csv.
tester_kwargs = dict(
    dataset_path='./test_graphemes',
    batch_size=128,
    model_name='tf_efficientnet_b2_ns',
    test_csv='/kaggle/input/bengaliai-cv19/test.csv',
    unique_csv='/kaggle/input/bengaliai-cv19/train.csv',
    output_dir='./submission.csv',
    ckpt='/kaggle/input/graphemesingle/27.pth',
    multihead_ckpt='/kaggle/input/graphemesingle/10.pth',
)
tester = Tester(**tester_kwargs)
tester.test()