In [None]:
%%capture
# Install the PyTorch/XLA 1.11 runtime for the TPU.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.11 --apt-packages libomp5 libopenblas-dev
# Keep the torch stack on the 1.11 release line: an unpinned `pip install -U`
# installs the newest torch, which no longer matches the torch_xla 1.11 wheel
# installed by the setup script above.
!pip install torch==1.11.0
!pip install torchvision==0.12.0
!pip install torchtext==0.12.0
# NOTE(review): the remaining packages are still unpinned; pin them once the
# known-good versions for this environment are confirmed.
!pip install -U wandb
!pip install -U albumentations
!pip install -U opencv-python
!pip install -U tensorboardX

In [None]:
# `!export` only sets the variable inside a throwaway subshell, so the kernel
# never saw it. `%env` sets it in the kernel process itself, before the cell
# below imports torch_xla.
%env XLA_USE_BF16=1

In [None]:
import os
import gc
from datetime import datetime
from collections import defaultdict

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler, AdamW, Adam
from torch.optim.lr_scheduler import _LRScheduler, StepLR
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from tqdm import tqdm
import wandb

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Hyperparameters and paths shared by the training and inference cells.
config = dict(
    epochs=450,
    # Base learning rate; train_model scales it by xm.xrt_world_size().
    lr=1e-4,
    # Passed to AdamW unchanged.
    # NOTE(review): 0.998 is unusually large for a weight-decay coefficient — confirm.
    weight_decay=0.998,
    # Batch size handed to every DataLoader.
    batch_size=128,
    # Directory given to test_utils.get_summary_writer.
    # NOTE(review): "/" is the filesystem root — confirm this is intended.
    logdir="/",
)

In [None]:
def set_seed(seed=43):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.

    Python's own `random` module is seeded as well, so libraries that draw
    from it are reproducible too (albumentations 1.x appears to be one of
    them — TODO confirm for the installed version).
    """
    # Local import: the notebook's import cell does not bring `random` into scope.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Only affects interpreters started after this point; the hash seed of the
    # running kernel was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed()

In [None]:
# Dataset wrapper around the digit-recognizer CSV frames: each row is one flattened 28x28 image
class MNIST(Dataset):
    """Dataset over a DataFrame whose rows are flattened 28x28 images.

    Args:
        df: DataFrame with 784 pixel columns and, when `istrain` is true, an
            additional `label` column.
        transforms: Optional callable invoked as `transforms(image=array)` and
            expected to return a mapping with an `'image'` entry (the
            albumentations calling convention). When None, the raw pixels are
            returned as a uint8 tensor of shape (1, 28, 28).
        istrain: When true, items are `(image, label)` pairs; otherwise only
            the image is returned.
    """

    def __init__(self, df, transforms=None, istrain=True):
        super().__init__()
        self.df = df
        self.transforms = transforms
        self.istrain = istrain

        if self.istrain:
            self.y = df.label
            self.X = df.drop(['label'], axis=1)
        else:
            self.X = df

    def __len__(self):
        """Return the number of rows (images)."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the image at position `index`, plus its label when training."""
        X = self.X.iloc[index].values.reshape((28, 28)).astype(np.uint8)

        if self.transforms is None:
            # `transforms` defaults to None; calling it unconditionally raised a
            # TypeError. Fall back to the untransformed pixels with a channel axis.
            X = torch.from_numpy(X).unsqueeze(0)
        else:
            X = self.transforms(image=X)['image']

        if self.istrain:
            y = torch.tensor(self.y.iloc[index])
            return X, y
        return X

In [None]:
# Albumentations pipelines keyed by split name. Both splits normalise with the
# same single-channel statistics and end with ToTensorV2; only "train" adds
# random geometric augmentation.
# NOTE(review): 0.485 / 0.229 look like ImageNet channel statistics rather than
# values computed on this dataset — confirm.
data_transforms = dict(
    train=A.Compose(
        [
            A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0, p=1.0),
            A.Rotate(limit=10, p=0.5),
            A.Affine(scale=(0.9, 1.1), translate_percent=0.1, shear=15, p=0.30),
            A.Resize(28, 28),
            ToTensorV2(),
        ],
        p=1.,
    ),
    test=A.Compose(
        [
            A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0, p=1.0),
            ToTensorV2(),
        ],
        p=1.,
    ),
)

In [None]:
class Model(nn.Module):
    """Small CNN classifier producing log-probabilities over 10 classes.

    For a (N, 1, 28, 28) input the feature maps are:
    cb1 -> (N, 16, 13, 13), cb2 -> (N, 64, 5, 5), lcl -> (N, 128, 3, 3),
    which flattens to the 1152 features expected by `dense`.
    """

    def __init__(self):
        super().__init__()

        self.cb1 = self.conv_block(1, 16, kernel_size=3, stride=1, bias=False)
        self.cb2 = self.conv_block(16, 64, kernel_size=3, stride=1, bias=False)
        self.lcl = nn.Conv2d(64, 128, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.do = nn.Dropout(0.1)
        self.dense = nn.Linear(1152, 256, bias=False)
        self.bn2 = nn.BatchNorm1d(256)
        self.ll = nn.Linear(256, 10)

    def conv_block(self, in_f, out_f, *args, **kwargs):
        """Build a Conv2d -> BatchNorm2d -> ReLU -> Dropout(0.2) -> MaxPool2d(2) stack.

        Args:
            in_f: Number of input channels.
            out_f: Number of output channels.
            *args, **kwargs: Forwarded to `nn.Conv2d`.

        Returns:
            The layers wrapped in an `nn.Sequential`.
        """
        return nn.Sequential(
            nn.Conv2d(in_f, out_f, *args, **kwargs),
            nn.BatchNorm2d(out_f),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10) for a batch `x`."""
        x = self.cb2(self.cb1(x))
        x = F.relu(self.bn1(self.lcl(x)))
        x = self.do(x)
        # A single flatten is enough; the earlier `view` did the same reshape and
        # additionally required a contiguous tensor.
        x = torch.flatten(x, 1)
        x = F.relu(self.bn2(self.dense(x)))
        x = self.ll(x)

        return F.log_softmax(x, dim=1)

In [None]:
def _train_update(device, x, loss, writer):
    test_utils.print_training_update(
        device, 
        x, 
        loss.item(),
        summary_writer=writer
    )

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, epoch, device, writer):
    """Run one training epoch and step the LR scheduler once at its end.

    Args:
        model: Module returning log-probabilities (the loss is `F.nll_loss`).
        dataloader: Iterable of `(image, label)` batches.
        optimizer: Optimizer stepped through `xm.optimizer_step`.
        scheduler: Scheduler whose `step(metrics=...)` receives the epoch loss.
        epoch, device, writer: Accepted for call-site compatibility; unused here.

    Returns:
        `(epoch_loss, train_accuracy)`: the sample-weighted mean loss and the
        accuracy in percent, each combined across replicas with
        `xm.mesh_reduce(..., np.mean)`.
    """
    model.train()

    sample_count = 0
    running_loss = 0.0
    train_corr = 0

    for image, label in dataloader:

        optimizer.zero_grad()

        output = model(image)

        loss = F.nll_loss(output, label)
        loss.backward()

        xm.optimizer_step(optimizer)

        # Weight by the real number of samples in this batch rather than the
        # configured batch size, so a short batch cannot skew the averages.
        n_samples = label.size(0)

        train_pred = output.argmax(dim=1, keepdim=True)
        train_corr += train_pred.eq(label.view_as(train_pred)).sum()

        running_loss += (loss.item() * n_samples)
        sample_count += n_samples

    # Per-step logging through xm.add_step_closure(_train_update, ...) is
    # intentionally left disabled.
    epoch_loss = running_loss / sample_count
    epoch_loss = xm.mesh_reduce('train_loss', epoch_loss, np.mean)

    train_accuracy = 100 * train_corr.item() / sample_count
    train_accuracy = xm.mesh_reduce('train_accuracy', train_accuracy, np.mean)

    # Step on the reduced loss: every replica then sees the same metric and
    # makes the same learning-rate decision. The previous per-replica running
    # sum could differ between replicas.
    scheduler.step(metrics=epoch_loss)

    return epoch_loss, train_accuracy

In [None]:
def valid_epoch(model, dataloader, epoch, device):
    """Evaluate the model on a validation loader.

    Args:
        model: Module returning log-probabilities (the loss is `F.nll_loss`).
        dataloader: Iterable of `(image, label)` batches; the last batch may be
            shorter than the others.
        epoch, device: Accepted for call-site compatibility; unused here.

    Returns:
        `(epoch_loss, valid_accuracy)`: the sample-weighted mean loss and the
        accuracy in percent, each combined across replicas with
        `xm.mesh_reduce(..., np.mean)`.
    """
    model.eval()

    sample_count = 0
    running_loss = 0.0
    test_corr = 0

    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for image, label in dataloader:

            output = model(image)
            loss = F.nll_loss(output, label)

            # Count the samples actually present: the validation loader keeps
            # its final short batch, and treating it as a full batch inflated
            # the denominator and understated the accuracy.
            n_samples = label.size(0)

            test_pred = output.argmax(dim=1, keepdim=True)
            test_corr += test_pred.eq(label.view_as(test_pred)).sum()

            running_loss += (loss.item() * n_samples)
            sample_count += n_samples

    epoch_loss = running_loss / sample_count
    epoch_loss = xm.mesh_reduce('valid_loss', epoch_loss, np.mean)

    valid_accuracy = 100 * test_corr.item() / sample_count
    valid_accuracy = xm.mesh_reduce('valid_accuracy', valid_accuracy, np.mean)

    return epoch_loss, valid_accuracy

In [None]:
def train(model, loaders, optimizer, scheduler, device):
    """Run the full training loop for `config['epochs']` epochs.

    Args:
        model: Model handed to `train_epoch` / `valid_epoch`.
        loaders: `(train_loader, valid_loader)` pair.
        optimizer: Optimizer forwarded to `train_epoch`.
        scheduler: LR scheduler forwarded to `train_epoch`.
        device: Forwarded to the per-epoch helpers.

    Returns:
        The same `model` object that was passed in.
    """
    # A summary writer is only created on the master ordinal; elsewhere it
    # stays None and is handed to the test_utils helpers as-is (presumably
    # they ignore None — confirm).
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(config['logdir'])

    train_loader, valid_loader = loaders

    try:
        for epoch in range(1, config['epochs']+1):

            xm.master_print('Epoch {}'.format(epoch))
            start = datetime.now()

            train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, scheduler, epoch, device, writer)

            end = datetime.now()
            time_s = (end-start).total_seconds()
            xm.master_print('Train >> Time(s) {}, Accuracy={:.2f}, Loss={:.5f}'.format(time_s, train_accuracy, train_loss))

            valid_loss, valid_accuracy = valid_epoch(model, valid_loader, epoch, device)
            xm.master_print('Validation >> Accuracy={:.2f} Loss={:.5f}'.format(valid_accuracy, valid_loss))

            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={'Accuracy/valid': valid_accuracy, 'Accuracy/train': train_accuracy},
                write_xla_metrics=True
            )
    finally:
        # Close the writer even when an epoch raises, so it is not left open.
        test_utils.close_summary_writer(writer)

    return model

In [None]:
def get_loaders(df, device):
    """Split `df` 80/20 and build per-replica train and validation loaders.

    Args:
        df: Labelled DataFrame accepted by `MNIST(..., istrain=True)`.
        device: Device handed to `pl.MpDeviceLoader`.

    Returns:
        `(train_device_loader, valid_device_loader)`.
    """
    def _sampler_for(dataset, shuffle):
        # Each replica samples its own slice, selected by its ordinal.
        return torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=shuffle
        )

    df_train, df_valid = train_test_split(df, test_size=0.2, random_state=42, shuffle=True)

    # The validation split is labelled too, but uses the augmentation-free pipeline.
    train_dataset = MNIST(df_train, transforms=data_transforms["train"], istrain=True)
    valid_dataset = MNIST(df_valid, transforms=data_transforms["test"], istrain=True)

    train_sampler = _sampler_for(train_dataset, shuffle=True)
    valid_sampler = _sampler_for(valid_dataset, shuffle=False)

    # Only the training loader drops its final incomplete batch.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        sampler=train_sampler,
        num_workers=4,
        drop_last=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=config['batch_size'],
        sampler=valid_sampler,
        num_workers=4,
    )

    return pl.MpDeviceLoader(train_loader, device), pl.MpDeviceLoader(valid_loader, device)

In [None]:
def train_model(config, model, **kwargs):
    """Set up the optimizer, scheduler and data for this replica, then train.

    Args:
        config: Mapping providing at least `'lr'` and `'weight_decay'`.
        model: Module to train; it is moved to the XLA device.
        **kwargs: Accepted for compatibility; unused.

    Returns:
        The trained model. Previously nothing was returned, so callers that
        assigned the result received None.
    """
    device=xm.xla_device()

    model = xmp.MpModelWrapper(model).to(device)

    # Scale the base learning rate by the number of replicas.
    lr = config['lr'] * xm.xrt_world_size()

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=config['weight_decay'])
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, min_lr=1e-6, factor=0.5)

    df = pd.read_csv('../input/digit-recognizer/train.csv')
    loaders = get_loaders(df, device)

    model = train(model, loaders, optimizer, scheduler, device)

    return model

In [None]:
def _map_fn(index, config, model):
    torch.set_default_tensor_type('torch.FloatTensor')
    model = train_model(config, model)

In [None]:
# File used to hand the trained weights from the worker processes back to the
# notebook's main process.
CHECKPOINT_PATH = 'model.pt'


def _train_and_checkpoint(index, config, model):
    """Run `_map_fn` in a spawned process, then save the trained weights.

    `nn.Module.to()` moves a module's parameters in place, so after training
    the `model` object passed in here holds the trained parameters (this
    relies on `xmp.MpModelWrapper.to` calling it on the wrapped module, as in
    torch_xla 1.11 — confirm if the version changes).
    """
    _map_fn(index, config, model)
    # Called by every process; xm.save moves the tensors to CPU and, by
    # default, only the master ordinal writes the file.
    xm.save(model.state_dict(), CHECKPOINT_PATH)


model = Model()
xmp.spawn(_train_and_checkpoint, args=(config, model,), nprocs=8, start_method='fork')

# The workers are forked copies, so training never touches this process's
# `model`. Load the saved weights so later cells predict with the trained
# network instead of the freshly initialised one.
model.load_state_dict(torch.load(CHECKPOINT_PATH))

In [None]:
def test_loader():
    """Build the device loader over the unlabelled test CSV.

    Returns:
        A `pl.MpDeviceLoader` yielding image batches in file order
        (the sampler is created with `shuffle=False`).
    """
    device = xm.xla_device()

    frame = pd.read_csv('../input/digit-recognizer/test.csv')
    dataset = MNIST(frame, transforms=data_transforms["test"], istrain=False)

    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )

    loader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        num_workers=8,
        sampler=sampler,
    )
    return pl.MpDeviceLoader(loader, device)

In [None]:
def predict(model, dataloader):
    """Return the predicted class index for every sample in `dataloader`.

    Args:
        model: Module producing per-class scores; moved to the XLA device and
            switched to eval mode.
        dataloader: Iterable of image batches.

    Returns:
        1-D CPU tensor of argmax class indices, in loader order.
    """
    device = xm.xla_device()
    model = xmp.MpModelWrapper(model).to(device)
    model.eval()

    # Gather each batch's scores on the CPU without tracking gradients.
    with torch.no_grad():
        batch_scores = [model(image).cpu().detach() for image in dataloader]

    return torch.argmax(torch.cat(batch_scores, dim=0), axis=1)

In [None]:
# Build the test loader, then run inference with `model` over it.
# NOTE(review): this runs in the notebook's main process — confirm that `model`
# here actually carries the weights trained in the processes started by xmp.spawn.
test_device_loader = test_loader()
preds = predict(model, test_device_loader)

In [None]:
# One row per prediction with 1-based image ids. The id range is derived from
# the number of predictions instead of a hard-coded 28000, so the frame still
# builds if the test set size differs.
submission = pd.DataFrame({
        "ImageId": list(range(1, len(preds) + 1)),
        "Label": preds.tolist()
    })

submission.to_csv('submission.csv', index=False)