This script is based on this [Discussion](https://www.kaggle.com/c/bengaliai-cv19/discussion/126054) and its author's original [notebook](https://www.kaggle.com/pestipeti/fast-ensemble-5-folds-20-minutes). 

I made a single-model version of that script. All credit goes to the original author; any implementation errors in this version are my own.

In [None]:
import gc
import numpy as np
import pandas as pd
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pyarrow.parquet as pq
import cv2

import sys
sys.path.append('/kaggle/input/utilities/')
import SeResNeXt


In [None]:
def crop_image_only_outside(img, tol=0):
    """Crop ``img`` to the bounding box of pixels strictly greater than ``tol``.

    Args:
        img: 2-D array.
        tol: threshold; only pixels with value > tol count as content.

    Returns:
        A view of ``img`` trimmed to the rows/columns that contain content.
        If no pixel exceeds ``tol`` the full image is returned unchanged.
    """
    height, width = img.shape
    above = img > tol
    cols_hit = above.any(axis=0)
    rows_hit = above.any(axis=1)
    # argmax on a boolean vector gives the index of the first True (0 if none),
    # so scanning the reversed vector yields the distance of the last hit from the end.
    first_col = cols_hit.argmax()
    last_col = width - cols_hit[::-1].argmax()
    first_row = rows_hit.argmax()
    last_row = height - rows_hit[::-1].argmax()
    return img[first_row:last_row, first_col:last_col]

In [None]:
# Root of the competition data; input files are resolved against it so the
# notebook does not depend on the current working directory.
DATA_PATH = '/kaggle/input/bengaliai-cv19/'
sample_submission = pd.read_csv(DATA_PATH + 'sample_submission.csv')
# The submission holds three rows (one per target) for every test image.
num_samples = sample_submission.shape[0] // 3

In [None]:
# Side length (pixels) of the square images fed to the model.
TARGET_SIZE = 64
# Inference batch size and number of DataLoader worker processes.
BATCH_SIZE, N_WORKERS = 96, 4
# Height and width of each raw image as laid out in the parquet pixel columns.
HEIGHT, WIDTH = 137, 236

class GraphemeValidationDataset(Dataset):
    """In-memory dataset of preprocessed test images.

    Every test parquet shard is read once in ``__init__``. Each image is
    inverted, contrast-stretched so its brightest pixel is 255, cropped to the
    bounding box of pixels above a threshold, and resized to
    ``TARGET_SIZE x TARGET_SIZE``. Items are uint8 tensors of shape
    ``(1, TARGET_SIZE, TARGET_SIZE)``.
    """

    def __init__(self, num_samples):
        """Read and preprocess all test images.

        Args:
            num_samples: number of image slots to allocate; expected to equal
                the total row count across the test parquet files (more rows
                than this raises an IndexError, fewer leave zero-filled slots).
        """
        self.num_samples = num_samples
        self.images = torch.zeros(num_samples, TARGET_SIZE, TARGET_SIZE, dtype=torch.uint8)
        # Pixel columns are named '0' .. str(HEIGHT * WIDTH - 1).
        pixel_columns = [str(x) for x in range(HEIGHT * WIDTH)]
        img_id = 0

        print('start reading in datas.')
        for i in range(4):
            datafile = '{}/test_image_data_{}.parquet'.format(DATA_PATH.rstrip('/'), i)
            parq = pq.read_pandas(datafile, columns=pixel_columns).to_pandas()
            # Invert pixel values; presumably turns dark-on-light glyphs into
            # light-on-dark so the `> tol` crop below selects the ink.
            parq = 255 - parq.to_numpy().reshape(-1, HEIGHT, WIDTH).astype(np.uint8)

            for image in parq:
                peak = image.max()
                # Skip the stretch for an all-zero image: 255 / 0 would yield
                # inf/NaN, whose uint8 cast is undefined.
                if peak > 0:
                    image = (image * (255.0 / peak)).astype(np.uint8)
                image = crop_image_only_outside(image, 80)
                image = cv2.resize(image, (TARGET_SIZE, TARGET_SIZE))
                self.images[img_id, :, :] = torch.from_numpy(image)
                img_id += 1

            # Release this shard before the next read to keep peak memory down.
            del parq
            gc.collect()
        print('finish reading in datas.')

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Add a leading channel axis: (H, W) -> (1, H, W).
        return self.images[idx].unsqueeze(0)

In [None]:
class Identity(torch.nn.Module):
    """Parameter-free pass-through module: ``forward`` returns its input as-is.

    Used to replace a sub-module (e.g. a classifier layer) without changing
    the surrounding module structure.
    """

    def forward(self, x):
        return x
    
    
class MyNet(nn.Module):
    """SE-ResNeXt-101 backbone with three linear heads over 2048-d features.

    The backbone's stem convolution is replaced by a single-input-channel one
    and its ``fc`` by a pass-through, so the flattened backbone output feeds
    one classification head per target.
    """

    def __init__(self):
        super(MyNet, self).__init__()
        backbone = SeResNeXt.se_resnext101()
        # Single-channel input: swap the stem conv for a 1-in-channel version.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop the backbone's own classifier; the heads below replace it.
        backbone.fc = Identity()
        self.backbone = backbone
        self.fc1 = nn.Linear(2048, 11)  # vowel_diacritic
        self.fc2 = nn.Linear(2048, 168)  # grapheme_root
        self.fc3 = nn.Linear(2048, 7)  # consonant_diacritic

    def forward(self, x):
        """Return logits as (vowel_diacritic, grapheme_root, consonant_diacritic)."""
        # Scale raw 0-255 pixel values to [0, 1] before the backbone.
        features = self.backbone(x / 255.)
        features = features.view(features.size(0), -1)
        return self.fc1(features), self.fc2(features), self.fc3(features)


In [None]:
# Preprocess every test image up front, then serve them in batches.
bengali_dataset = GraphemeValidationDataset(num_samples=num_samples)
data_loader_test = DataLoader(
    bengali_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # predictions are written back by position, so keep order
    num_workers=N_WORKERS,
)

In [None]:
model = MyNet()

MODEL_PATH = '/kaggle/input/submission10/0597.pth'
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Load to CPU first: the weights are copied into the (still CPU-resident) model
# anyway, so mapping straight to the GPU would only hold a redundant device copy.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
# Accept either a bare state_dict or a checkpoint dict that wraps one.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)

model.eval()
model.to(device)

del checkpoint, state_dict

In [None]:
print('start inference')

# Row 0: consonant_diacritic, row 1: grapheme_root, row 2: vowel_diacritic.
# np.int was removed in NumPy 1.24; use an explicit integer dtype.
results = np.zeros((3, num_samples), dtype=np.int64)

tic = time.perf_counter()
start = 0
with torch.no_grad():
    for images in data_loader_test:
        # Transfer the compact uint8 batch first, then cast to float on the device.
        images = images.to(device).float()
        vowel_diacritic, grapheme_root, consonant_diacritic = model(images)

        # Advance by the actual batch length so a short final batch is handled
        # without relying on BATCH_SIZE arithmetic.
        end = start + images.size(0)
        results[0, start:end] = consonant_diacritic.argmax(1).cpu().numpy()
        results[1, start:end] = grapheme_root.argmax(1).cpu().numpy()
        results[2, start:end] = vowel_diacritic.argmax(1).cpu().numpy()
        start = end

        del images
        del vowel_diacritic, grapheme_root, consonant_diacritic

del model
gc.collect()

print('finish inference in {:.2f} sec.'.format(time.perf_counter()-tic))

In [None]:
# results is (3, num_samples); flattening in Fortran (column-major) order
# interleaves the three targets per image: r[0,0], r[1,0], r[2,0], r[0,1], ...
# presumably matching the row order of sample_submission -- verify if the
# template changes.
result_reshape = results.reshape(3 * num_samples, order='F')
# sample_submission was loaded earlier and is unmodified, so no re-read is
# needed. Item assignment (not attribute assignment) guarantees a column is set.
sample_submission['target'] = result_reshape

sample_submission.to_csv('submission.csv', index=False)
print('finish writing submission.csv')