In [None]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import torch.nn.functional as F
import os
from datetime import datetime
import time
import random
import cv2
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from  torch.cuda.amp import autocast, GradScaler

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import warnings
import cv2
import pydicom
import timm #from efficientnet_pytorch import EfficientNet
from scipy.ndimage.interpolation import zoom
from sklearn.metrics import log_loss

In [None]:
# Run configuration. Many keys (folds, epochs, LR schedule, gradient
# accumulation, ...) are carried over from the training notebook; the cells
# of this inference notebook read 'seed', 'valid_bs', 'num_workers',
# 'device' and 'tta'.
CFG = dict(
    fold_num=5,
    seed=324,
    model_arch='SE-Net',  # earlier candidates: 'tf_efficientnet_b4_ns', 'seresnext50_32x4d'
    img_size=28,
    epochs=10,
    train_bs=16,
    valid_bs=32,
    T_0=10,
    lr=1e-2,
    min_lr=1e-6,
    weight_decay=1e-6,
    num_workers=2,
    accum_iter=4,  # batch accumulation: backprop with an effectively larger batch (train_bs * accum_iter)
    verbose_step=1,
    device='cuda:0',
    tta=15,  # number of repeated inference passes per model
    used_epochs=[8],
    weights=[1, 1, 1, 1, 1],
)

In [None]:
# Display the sample submission as a format reference only; the file that is
# actually written is built from test_data at the end of the notebook.
submission = pd.read_csv(
    '../input/Kannada-MNIST/sample_submission.csv'
)
submission.head()

# Helper Functions

In [None]:
def seed_everything(seed):
    """Seed every RNG the notebook uses and make cuDNN deterministic.

    Seeds Python's ``random``, sets ``PYTHONHASHSEED`` (only affects
    interpreters started afterwards, not the running one), seeds NumPy's
    global RNG and torch's CPU and all CUDA generators.

    Args:
        seed: integer seed.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and select conv algorithms per run,
    # which defeats the reproducibility this function is meant to provide.
    torch.backends.cudnn.benchmark = False

def show_batch_imgs(examples, df, transforms=None, img_shape=(28, 28)):
    """Plot ``examples`` randomly chosen rows of ``df`` as images in one row.

    Column 0 of ``df`` is skipped (the id/label column in the Kannada-MNIST
    CSVs -- TODO confirm for other inputs); the remaining columns are the
    flattened pixels, cast to uint8 and reshaped to ``img_shape``.

    Args:
        examples: number of images to draw. Nothing is plotted when it is
            ``<= 0``; ``DataFrame.sample`` raises ``ValueError`` when it
            exceeds ``len(df)``.
        df: DataFrame of flattened images.
        transforms: optional albumentations-style pipeline, called as
            ``transforms(image=img)['image']`` and expected to return a torch
            tensor (e.g. one ending in ``ToTensorV2``).
        img_shape: shape a single flattened row is reshaped to.
    """
    if examples <= 0:
        return
    # Sample only the rows that are shown instead of shuffling (and copying)
    # the whole frame with ``sample(frac=1)``.
    rows = df.sample(n=examples)
    fig, axes = plt.subplots(1, examples, figsize=(8, 8), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = rows.iloc[i, 1:].to_numpy().astype(np.uint8).reshape(img_shape)
        if transforms:
            # The tensor output carries an extra (presumably channel) dim;
            # squeeze it away rather than assuming the transform kept 28x28.
            img = np.squeeze(transforms(image=img)['image'].numpy())
        ax.axis('off')
        ax.imshow(img)

# Dataset

In [None]:
class KMnistDataset(Dataset):
    """Map-style dataset over a Kannada-MNIST style DataFrame.

    Column 0 of ``data`` is skipped (``id`` in test.csv; presumably ``label``
    in train.csv -- TODO confirm). The remaining columns are the flattened
    pixels. Items are float32 images only; no label is returned.

    Args:
        data: DataFrame whose columns ``1:`` are the flattened pixels.
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=img)['image']``.
        do_fmix, fmix_params, do_cutmix, cutmix_params: stored for API
            compatibility only -- ``__getitem__`` does not apply FMix/CutMix.
        img_shape: shape each flattened row is reshaped to.
    """

    def __init__(self, data, transforms=None, do_fmix=None, fmix_params=None,
                 do_cutmix=None, cutmix_params=None, img_shape=(28, 28)):
        self.data = data
        self.transforms = transforms
        self.do_fmix = do_fmix
        self.fmix_params = fmix_params
        self.do_cutmix = do_cutmix
        # Fresh dict per instance: a dict literal as the default would be one
        # shared, mutable object across every dataset using the default.
        self.cutmix_params = {'alpha': 1} if cutmix_params is None else cutmix_params
        self.img_shape = tuple(img_shape)
        # Snapshot the pixels once as float32 instead of a pandas ``iloc``
        # lookup per item. Columns added to ``data`` afterwards (e.g. a
        # predicted 'label') therefore cannot leak into the images. Reshaping
        # with an explicit row count makes a column-count mismatch fail here.
        self._images = (
            data.iloc[:, 1:]
            .to_numpy(dtype=np.float32)
            .reshape((len(data),) + self.img_shape)
        )

    def __len__(self):
        return len(self._images)

    def __getitem__(self, index):
        # Copy so transforms (and tensors sharing numpy memory) never alias
        # the shared snapshot.
        img = self._images[index].copy()
        if self.transforms:
            img = self.transforms(image=img)['image']
        return img

# Define Train/Validation Image Augmentations

In [None]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    """Training augmentation pipeline.

    Random shift/scale/rotate (rotation capped at 30 degrees), occasional
    patch dropout, then conversion to a torch tensor. Flips, transposes,
    random resized crops and brightness/contrast jitter were tried and left
    out: per the author's notes they confuse the characters or do not help
    on grayscale digits.

    NOTE(review): CoarseDropout and Cutout do much the same thing, and Cutout
    is deprecated in newer albumentations releases -- confirm against the
    installed version.
    """
    steps = [
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=30, p=0.5),
        CoarseDropout(p=0.1),
        Cutout(p=0.1),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
  
        
def get_valid_transforms():
    """Validation pipeline: converts the image to a torch tensor, nothing else.

    Cropping/resizing is deliberately omitted: a center crop may cut off part
    of the character, and the model works on the native image size
    (CFG['img_size']).
    """
    to_tensor = ToTensorV2(p=1.0)
    return Compose([to_tensor], p=1.0)

def get_inference_transforms():
    """Test-time pipeline: tensor conversion only, no crop/resize or augmentation.

    NOTE(review): nothing in this pipeline is random, so repeated (TTA)
    passes over a dataset built with it see identical inputs. Add stochastic
    transforms if TTA is meant to average over augmentations.
    """
    pipeline = [ToTensorV2(p=1.0)]
    return Compose(pipeline, p=1.)

# Model

In [None]:
class Sq_Ex_Block(nn.Module):
    """Squeeze-and-Excitation block: rescale each channel by a learned gate.

    The input is average-pooled over its last two dims (one value per
    channel), passed through a bottleneck MLP in_ch -> in_ch // r -> in_ch
    and a sigmoid; the resulting per-channel weights multiply the input.

    Args:
        in_ch: number of channels of the incoming feature map.
        r: bottleneck reduction ratio.
    """

    def __init__(self, in_ch, r):
        super().__init__()
        hidden = in_ch // r
        # Keep this exact Sequential layout: its indices form the state_dict
        # keys (se.1.*, se.3.*) that saved checkpoints are loaded into.
        self.se = nn.Sequential(
            GlobalAvgPool(),
            nn.Linear(in_ch, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_ch),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # (N, C) gate -> (N, C, 1, 1) so it broadcasts over the spatial dims.
        gate = self.se(x)[..., None, None]
        return x * gate

class GlobalAvgPool(nn.Module):
    """Average over the last two (spatial) dimensions, e.g. (N, C, H, W) -> (N, C).

    Parameter-free, so the implementation can change without affecting
    checkpoint compatibility.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Reducing over both dims works for any memory layout, whereas
        # ``x.view(..., -1)`` raises on non-contiguous inputs (e.g. after a
        # transpose of the spatial dims).
        return x.mean(dim=(-2, -1))

class SE_Net(nn.Module):
    """Small VGG-style CNN with one squeeze-and-excitation block and a 10-way head.

    Three conv stages (64 -> 128 -> 256 channels), each ending in max-pool and
    dropout(0.4), then a 256-unit fully connected layer and a linear layer
    returning 10 logits. For a 28x28 input the spatial size goes
    28 -> 26 -> 24 -> 24 -> 12 -> 10 -> 8 -> 8 -> 4 -> 2 -> 1, which is what
    ``fc1``'s 256*1*1 input width assumes.

    The per-layer order is conv -> LeakyReLU(0.1) -> BatchNorm (BN after the
    activation). It, and every attribute name (they are the state_dict keys),
    must stay as-is to match the trained checkpoints.

    Args:
        in_channels: number of input image channels (1 for grayscale).
    """

    def __init__(self, in_channels):
        super().__init__()
        # nn.Conv2d(in, out, kernel_size, stride, padding);
        # nn.BatchNorm2d(num_features, eps, momentum).
        self.c1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(num_features=64, eps=1e-3, momentum=0.01)
        self.c2 = nn.Conv2d(64, 64, 3, 1, 0)
        self.bn2 = nn.BatchNorm2d(64, 1e-3, 0.01)
        self.c3 = nn.Conv2d(64, 64, 5, 1, 2)
        self.bn3 = nn.BatchNorm2d(64, 1e-3, 0.01)
        self.m1 = nn.MaxPool2d(2)
        self.d1 = nn.Dropout(0.4)

        self.c4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.bn4 = nn.BatchNorm2d(128, 1e-3, 0.01)
        self.c5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.bn5 = nn.BatchNorm2d(128, 1e-3, 0.01)
        self.c6 = nn.Conv2d(128, 128, 5, 1, 2)
        self.bn6 = nn.BatchNorm2d(128, 1e-3, 0.01)
        self.m2 = nn.MaxPool2d(2)
        self.d2 = nn.Dropout(0.4)

        self.c7 = nn.Conv2d(128, 256, 3, 1, 0)
        self.bn7 = nn.BatchNorm2d(256, 1e-3, 0.01)
        self.se3 = Sq_Ex_Block(in_ch=256, r=8)
        self.m3 = nn.MaxPool2d(2)
        self.d3 = nn.Dropout(0.4)

        self.fc1 = nn.Linear(256 * 1 * 1, 256)
        self.bn8 = nn.BatchNorm1d(256, 1e-3, 0.01)

        self.out = nn.Linear(256, 10)

        self.init_linear_weights()

    def forward(self, x):
        x = self.bn1(F.leaky_relu(self.c1(x), 0.1))
        x = self.bn2(F.leaky_relu(self.c2(x), 0.1))
        x = self.bn3(F.leaky_relu(self.c3(x), 0.1))
        x = self.d1(self.m1(x))

        x = self.bn4(F.leaky_relu(self.c4(x), 0.1))
        x = self.bn5(F.leaky_relu(self.c5(x), 0.1))
        x = self.bn6(F.leaky_relu(self.c6(x), 0.1))
        x = self.d2(self.m2(x))

        x = self.bn7(F.leaky_relu(self.c7(x), 0.1))
        x = self.se3(x)
        x = self.d3(self.m3(x))

        # flatten(1) keeps the batch axis intact. Unlike view(-1, 256), it never
        # folds leftover spatial positions into the batch, so an unexpected
        # input size fails loudly in fc1 instead of misaligning predictions.
        x = torch.flatten(x, 1)
        x = self.bn8(F.leaky_relu(self.fc1(x), 0.1))
        return self.out(x)

    def init_linear_weights(self):
        """Kaiming-normal (fan_in) init for the two fully connected layers."""
        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_in')
        nn.init.kaiming_normal_(self.out.weight, mode='fan_in')

# Main Loop

In [None]:
def inference_one_epoch(model, data_loader, device):
    """Run ``model`` over every batch of ``data_loader`` and return softmax probabilities.

    Switches ``model`` to eval mode (a persistent side effect) and disables
    autograd for the pass, so it is safe to call without an outer
    ``torch.no_grad()``.

    Args:
        model: network mapping a float image batch to logits, presumably of
            shape (batch, num_classes).
        data_loader: iterable yielding image batches only (no labels), with
            ``len()`` for the progress bar.
        device: device each batch is moved to before the forward pass.

    Returns:
        np.ndarray of per-sample class probabilities, rows in loader order.

    Raises:
        ValueError: from ``np.concatenate`` if the loader yields no batches.
    """
    model.eval()
    batch_probs = []
    with torch.no_grad():
        for imgs in tqdm(data_loader):
            logits = model(imgs.to(device).float())
            batch_probs.append(torch.softmax(logits, 1).cpu().numpy())
    return np.concatenate(batch_probs, axis=0)

In [None]:
# Seed first so anything stochastic downstream is reproducible, then build the
# test-set pipeline.
seed_everything(CFG['seed'])
device = torch.device(CFG['device'])

test_data = pd.read_csv("../input/Kannada-MNIST/test.csv")
test_ds = KMnistDataset(test_data, transforms=get_inference_transforms())

tst_preds = []

In [None]:
# Peek at the first rows (presumably an 'id' column followed by the flattened pixels).
test_data.head(n=5)

In [None]:
# One checkpoint per CV fold (folds 0-4). The trailing "_9" is presumably the
# saved epoch index -- NOTE(review): CFG['used_epochs'] is [8]; confirm which
# epoch is intended.
model_paths = [f'../input/eesm5720/SE-Net_fold_{fold}_9' for fold in range(5)]

In [None]:
# Ensemble the fold models: collect the softmax output of every
# (fold, TTA pass) and average them.
#
# NOTE(review): get_inference_transforms() only converts to a tensor, and the
# model runs in eval mode, so all CFG['tta'] passes over test_ds are
# identical. TTA > 1 only multiplies runtime unless those transforms are
# made stochastic.

# The loader does not depend on the model, so build it once.
tst_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=CFG['valid_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,  # predictions must stay aligned with test_data's row order
    pin_memory=False,
)

# Local list (not the tst_preds from an earlier cell), so re-running this
# cell starts from scratch instead of appending to stale results.
fold_tta_preds = []
for model_path in model_paths:
    model = SE_Net(in_channels=1).to(device)
    print("using the model from {}".format(model_path))
    # map_location lets checkpoints saved on another device load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))

    with torch.no_grad():
        for _ in range(CFG['tta']):
            fold_tta_preds.append(inference_one_epoch(model, tst_loader, device))

# Plain mean over all (fold, TTA) predictions, so each row is still a
# probability distribution. Do not also pre-scale by 1/tta: combined with the
# mean, that would shrink every row to sum to 1/tta.
tst_preds = np.mean(fold_tta_preds, axis=0)

In [None]:
# Hard prediction = most probable class per test row. ``assign`` returns a new
# frame instead of adding a column in place, so the DataFrame wrapped by
# test_ds keeps its original columns and the inference cells stay re-runnable.
test_data = test_data.assign(label=np.argmax(tst_preds, axis=1))

In [None]:
# Keep only the submission columns, in submission order.
test_data = test_data.reindex(columns=['id', 'label'])

In [None]:
# Write the Kaggle submission file without the pandas index column.
test_data.to_csv(path_or_buf='submission.csv', index=False)

The models loaded here were trained in a separate notebook: [EESM 5720 SE-Net baseline for Kannada MNIST](https://www.kaggle.com/dongjai04/se-net-baseline-eesm5720-train).

Thanks to [khyeh0719](https://www.kaggle.com/khyeh0719) for the [PyTorch EfficientNet baseline inference/TTA notebook](https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-inference-tta) and to [ccchang801023](https://www.kaggle.com/ccchang801023) for the [SE-Net baseline notebook](https://www.kaggle.com/ccchang801023/se-net-my-top1-baseline-model-with-pytorch).