In [None]:
# Reload the autoreload extension and use mode 2 so edited modules are re-imported
# before each cell executes.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Download the PyTorch/XLA environment setup script, then run it requesting the
# nightly build plus the libomp5 and libopenblas-dev apt packages.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
import glob
import io
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
from PIL import Image
import time
#from tqdm.notebook import tqdm

import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

import warnings
warnings.filterwarnings("ignore")

# Read TFRecord files for PyTorch

Based on https://medium.com/analytics-vidhya/how-to-read-tfrecords-files-in-pytorch-72763786743f

In [None]:
# glob.glob returns matches in arbitrary, filesystem-dependent order. Sorting makes
# the record order (and everything that indexes into it) identical on every run.
train_files = sorted(glob.glob("../input/tpu-getting-started/tfrecords-jpeg-224x224/train/*.tfrec"))
val_files = sorted(glob.glob("../input/tpu-getting-started/tfrecords-jpeg-224x224/val/*.tfrec"))
test_files = sorted(glob.glob("../input/tpu-getting-started/tfrecords-jpeg-224x224/test/*.tfrec"))

In [None]:
# Feature schema used to parse labelled (train/val) TFRecord examples: a scalar
# int64 'class' plus scalar string 'id' and 'image' features.
train_feature_description = {
    'class': tf.io.FixedLenFeature([], tf.int64),
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

In [None]:
def _parse_image_function(example_proto):
    """Parse one serialized example with the labelled schema ``train_feature_description``."""
    parsed_example = tf.io.parse_single_example(
        example_proto,
        train_feature_description,
    )
    return parsed_example

In [None]:
# Parallel lists holding, for every training record, its id, class label and the
# still-encoded image bytes.
train_ids = []
train_class = []
train_images = []
for tfrec_path in train_files:
    train_image_dataset = tf.data.TFRecordDataset(tfrec_path).map(_parse_image_function)

    # Read each file once and fill all three lists together, instead of decoding
    # the whole file separately for ids, classes and images.
    for record in train_image_dataset:
        # Decode the id bytes properly rather than slicing the repr of a bytes object.
        train_ids.append(record['id'].numpy().decode('utf-8'))
        train_class.append(int(record['class'].numpy()))
        train_images.append(record['image'].numpy())

In [None]:
# Parallel lists holding, for every validation record, its id, class label and the
# still-encoded image bytes.
val_ids = []
val_class = []
val_images = []
for tfrec_path in val_files:
    val_image_dataset = tf.data.TFRecordDataset(tfrec_path).map(_parse_image_function)

    # Read each file once and fill all three lists together, instead of decoding
    # the whole file separately for ids, classes and images.
    for record in val_image_dataset:
        # Decode the id bytes properly rather than slicing the repr of a bytes object.
        val_ids.append(record['id'].numpy().decode('utf-8'))
        val_class.append(int(record['class'].numpy()))
        val_images.append(record['image'].numpy())

In [None]:
# Feature schema used to parse unlabelled test TFRecord examples: scalar string
# 'id' and 'image' features, with no 'class' entry.
test_feature_description = {
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

In [None]:
def _parse_image_function_test(example_proto):
    """Parse one serialized example with the unlabelled schema ``test_feature_description``."""
    parsed_example = tf.io.parse_single_example(
        example_proto,
        test_feature_description,
    )
    return parsed_example

In [None]:
# Parallel lists holding, for every test record, its id and the still-encoded
# image bytes.
test_ids = []
test_images = []
for tfrec_path in test_files:
    test_image_dataset = tf.data.TFRecordDataset(tfrec_path).map(_parse_image_function_test)

    # Read each file once and fill both lists together, instead of decoding the
    # whole file separately for ids and images.
    for record in test_image_dataset:
        # Decode the id bytes properly rather than slicing the repr of a bytes object.
        test_ids.append(record['id'].numpy().decode('utf-8'))
        test_images.append(record['image'].numpy())

# Dataset

In [None]:
import cv2

In [None]:
class FlowerDS(Dataset):
    """In-memory dataset over encoded image bytes.

    Args:
        ids: sequence of sample identifiers; its length defines the dataset size.
        cls: sequence of class labels aligned with ``ids``. Ignored when
            ``is_test`` is true.
        imgs: sequence of encoded image bytes aligned with ``ids``.
        transforms: callable applied to the decoded PIL image.
        is_test: when true, ``__getitem__`` returns ``-1`` in place of a label.
    """

    def __init__(self, ids, cls, imgs, transforms, is_test=False):
        self.ids = ids
        # Always define the attribute so test-mode instances do not raise
        # AttributeError when `cls` is inspected.
        self.cls = None if is_test else cls
        self.imgs = imgs
        self.transforms = transforms
        self.is_test = is_test

    def __len__(self):
        """Return the number of samples (one per id)."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label, sample_id)`` for position ``idx``."""
        # Force three channels: a grayscale or CMYK file would otherwise produce a
        # tensor whose channel count does not match the 3-channel normalisation.
        img = Image.open(io.BytesIO(self.imgs[idx])).convert('RGB')
        img = self.transforms(img)
        label = -1 if self.is_test else int(self.cls[idx])
        return img, label, self.ids[idx]

# normalize stats

In [None]:
# Per-channel mean and standard deviation handed to transforms.Normalize in the
# next cell; the values match the widely used ImageNet statistics.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [None]:
# Normalisation transform built from the mean/std lists; reused by the train and
# evaluation pipelines.
normalize = transforms.Normalize(mean=mean, std=std)

# Vision Transformer

We'll try the Vision Transformer from the recent paper https://arxiv.org/abs/2010.11929 <br>
The implementation can be found at https://github.com/nachiket273/VisTrans<br>
I have packaged it as a simple library, which can be installed with<br>
pip install vistrans<br>
Further information can be found at https://pypi.org/project/vistrans/

In [None]:
# Remove any already-installed vistrans, then install it again from the package
# index (no version is pinned here).
!pip uninstall vistrans -y
!pip install vistrans

In [None]:
from vistrans import VisionTransformer

In [None]:
def save_checkpoint(model, is_best, filename='./checkpoint.pth'):
    """Write ``model``'s state dict to ``filename`` when ``is_best`` is truthy.

    When ``is_best`` is falsy nothing is written and a notice is printed instead.
    """
    if not is_best:
        print("=> Validation Accuracy did not improve")
        return
    # Persist through torch_xla's save helper.
    xm.save(model.state_dict(), filename)

In [None]:
def load_checkpoint(model, filename = './checkpoint.pth'):
    """Load the state dict stored at ``filename`` into ``model`` in place.

    A saved key the model does not have, but whose ``<key>_raw`` counterpart it
    does have, is moved to that ``_raw`` name (unless the checkpoint already holds
    one) before the state dict is loaded.
    """
    # Keep every storage where it was deserialised instead of restoring it to the
    # device it was saved from.
    state = torch.load(filename, map_location=lambda storage, loc: storage)
    expected = set(model.state_dict().keys())
    # Iterate over a snapshot of the keys because entries are added and removed.
    for key in list(state):
        raw_key = key + '_raw'
        if key in expected or raw_key not in expected:
            continue
        if raw_key not in state:
            state[raw_key] = state[key]
        del state[key]
    model.load_state_dict(state)

# Model and load weights

In [None]:
# Display what VisionTransformer.list_pretrained() reports (presumably the names
# accepted by create_pretrained; confirm against the vistrans documentation).
VisionTransformer.list_pretrained()

In [None]:
def get_model(name ='vit_b16_224'):
    """Create a pretrained Vision Transformer with every parameter trainable.

    Args:
        name: pretrained model identifier passed to
            ``VisionTransformer.create_pretrained``.

    Returns:
        The model created with ``num_classes=104`` and ``requires_grad`` enabled
        on all of its parameters.
    """
    model = VisionTransformer.create_pretrained(name, num_classes=104)
    # The autograd flag is spelled `requires_grad`. Assigning `require_grad` only
    # attaches an unused attribute and leaves frozen parameters frozen.
    for param in model.parameters():
        param.requires_grad = True
    return model

In [None]:
# Serial executor that fit() uses to build its datasets, and the default
# ('vit_b16_224') model wrapped with MpModelWrapper so it can be handed to the
# spawned training processes through flags['model'].
SERIAL_EXEC = xmp.MpSerialExecutor()
WRAPPED_MODEL = xmp.MpModelWrapper(get_model())

# Stats

In [None]:
class AvgStats:
    """Accumulates per-epoch loss, precision and timing values in three lists."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far and start with empty lists."""
        self.losses, self.precs, self.its = [], [], []

    def append(self, loss, prec, it):
        """Record one entry: ``loss``, ``prec`` and ``it`` go to their own list."""
        for series, value in ((self.losses, loss), (self.precs, prec), (self.its, it)):
            series.append(value)

In [None]:
# Module-level accumulators that fit() appends training and validation results to.
trn_stat = AvgStats()
val_stat = AvgStats()

# Fit

In [None]:
def fit(rank, flags):
    """Train and validate a wrapped model on one XLA device.

    Intended as the per-process entry point for ``xmp.spawn``.

    Args:
        rank: process index supplied by the spawner. It is not used directly; the
            replica ordinal is read from ``xm.get_ordinal()`` instead.
        flags: dict with the keys ``'bs'`` (batch size of this process's loaders),
            ``'epochs'`` (number of epochs) and ``'model'`` (wrapped model to train).

    Side effects:
        Appends one loss/accuracy/time entry per epoch to the module-level
        ``trn_stat`` and ``val_stat``, prints one summary line per epoch and, after
        the last epoch, saves the weights to ``./best_model.pth``.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    bs = flags['bs']
    epochs = flags['epochs']
    wrapped_model = flags['model']
    torch.manual_seed(719)
    device = xm.xla_device()

    def get_dataset():
        """Build the train/validation datasets from the module-level record lists."""
        train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
            transforms.RandomErasing()
        ])

        test_transforms = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.Resize(224),
            transforms.ToTensor(),
            normalize
        ])

        train_ds = FlowerDS(train_ids, train_class, train_images, train_transforms)
        valid_ds = FlowerDS(val_ids, val_class, val_images, test_transforms)

        return train_ds, valid_ds

    train_ds, valid_ds = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_ds,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    # Validation only needs each replica to see its own shard; shuffling it adds
    # nothing.
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_ds,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )

    train_loader = DataLoader(train_ds, bs, sampler=train_sampler, num_workers=1, pin_memory=True)
    valid_loader = DataLoader(valid_ds, bs, sampler=valid_sampler, num_workers=1, pin_memory=True)

    model = wrapped_model.to(device)

    criterion = nn.CrossEntropyLoss()
    # The base learning rate is scaled by the number of replicas.
    optimizer = torch.optim.SGD(model.parameters(), lr=3e-2*xm.xrt_world_size(), momentum=0.9)

    def train(loader, epoch, model, optimizer, criterion):
        """Run one training epoch; return (accuracy in percent, mean batch loss, seconds)."""
        model.train()
        running_loss = 0.
        running_acc = 0.
        tot = 0
        start_time = time.time()
        # Batches are used as yielded by the per-device loader; no explicit
        # .to(device) is done here.
        for i, (ip, tgt, _) in enumerate(loader):
            output = model(ip)
            loss = criterion(output, tgt)
            running_loss += loss.item()
            tot += ip.shape[0]

            # Count correct top-1 predictions.
            _, pred = output.max(dim=1)
            running_acc += torch.sum(pred == tgt.data)

            # Compute gradients, clip their norm to 1.0 and take the optimizer step
            # through torch_xla's helper.
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            xm.optimizer_step(optimizer)

        trn_time = time.time() - start_time
        trn_acc = (running_acc/tot) * 100
        trn_loss = running_loss/len(loader)
        return trn_acc, trn_loss, trn_time

    def test(model, loader, criterion):
        """Evaluate without gradients; return (accuracy in percent, mean batch loss, seconds)."""
        with torch.no_grad():
            model.eval()
            running_loss = 0.
            running_acc = 0.
            tot = 0
            start_time = time.time()
            for i, (ip, tgt, _) in enumerate(loader):
                output = model(ip)
                loss = criterion(output, tgt)
                running_loss += loss.item()
                tot += ip.shape[0]
                _, pred = output.max(dim=1)
                running_acc += torch.sum(pred == tgt.data)

            val_time = time.time() - start_time
            val_acc = (running_acc/tot) * 100
            val_loss = running_loss/len(loader)
            return val_acc, val_loss, val_time

    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 3e-5)
    for j in range(1, epochs+1):
        # DistributedSampler seeds its shuffle from the epoch number; without this
        # call every epoch would iterate the training shard in the same order.
        train_sampler.set_epoch(j)
        para_loader = pl.ParallelLoader(train_loader, [device])
        trn_acc, trn_losses, trn_time = train(para_loader.per_device_loader(device), j, model,
                                             optimizer, criterion)
        trn_stat.append(trn_losses, trn_acc, trn_time)
        para_loader = pl.ParallelLoader(valid_loader, [device])
        val_acc, val_losses, val_time = test(model, para_loader.per_device_loader(device), criterion)
        val_stat.append(val_losses, val_acc, val_time)
        sched.step()
        # The reported metric is top-1 accuracy in percent, so label it as such.
        print("Epoch::{}, Trn_loss::{:06.8f}, Val_loss::{:06.8f}, Trn_acc::{:06.8f}, Val_acc::{:06.8f}"
            .format(j, trn_losses, val_losses, trn_acc, val_acc))

    # is_best is hard-coded to True, so the weights after the final epoch are always
    # written, regardless of validation accuracy.
    save_checkpoint(model, True, './best_model.pth')

In [None]:
# Settings dictionary handed to every spawned fit() process.
flags = {}

In [None]:
# Settings for the first training run.
flags.update(epochs=25, bs=32, model=WRAPPED_MODEL)

In [None]:
# Launch fit() in 8 forked processes, passing flags as its extra argument.
xmp.spawn(fit, args=(flags,), nprocs=8, start_method='fork')

In [None]:
# Rename the checkpoint written by the run above so the next fit() call, which
# also writes ./best_model.pth, does not overwrite it.
!mv best_model.pth best_model_vit_b16_224.pth

In [None]:
# Second model, created from the 'vit_l16_224' pretrained weights and wrapped the
# same way as WRAPPED_MODEL.
WRAPPED_MODEL1 = xmp.MpModelWrapper(get_model('vit_l16_224'))

In [None]:
# Settings for the second training run: batch size 16 and the second model.
# flags['epochs'] is left as it was.
flags.update(bs=16, model=WRAPPED_MODEL1)

In [None]:
# Launch fit() again in 8 forked processes, now with the updated flags.
xmp.spawn(fit, args=(flags,), nprocs=8, start_method='fork')

In [None]:
# Deterministic preprocessing for inference: crop, resize, tensor conversion and
# normalisation, with no random augmentation.
test_transforms = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.Resize(224),
    transforms.ToTensor(),
    normalize,
])

In [None]:
# Test-mode dataset: no labels are available, so an empty list is passed for cls
# and is_test is set to True.
test_ds = FlowerDS(test_ids, [], test_images, test_transforms, True)

In [None]:
# XLA device used for inference in the cells that follow.
device = xm.xla_device()

In [None]:
# Unshuffled loader over the test set: batch size 16, 4 worker processes.
testloader = DataLoader(test_ds, 16, num_workers=4, pin_memory=True, shuffle=False)

In [None]:
def predict(loader, device, models=None):
    """Predict one class per sample by ensembling several models.

    The class probabilities (softmax of each model's output) are summed over the
    models and the arg-max of that sum is the prediction. Averaging the predicted
    class *indices* instead would be wrong: the mean of class 10 and class 50 is
    class 30, which neither model chose.

    Args:
        loader: iterable yielding ``(inputs, labels, ids)`` batches; labels are
            ignored.
        device: device the inputs are moved to before the forward pass.
        models: optional sequence of models to ensemble. Defaults to the
            notebook-level ``model`` and ``model1``.

    Returns:
        dict mapping each sample id to its predicted class index (``int``).

    Raises:
        ValueError: if ``models`` is an empty sequence.
    """
    if models is None:
        models = (model, model1)
    models = tuple(models)
    if not models:
        raise ValueError("predict() needs at least one model")

    for net in models:
        net.eval()

    preds = dict()
    with torch.no_grad():
        for ip, _, ids in loader:
            ip = ip.to(device)
            # Sum of per-model class probabilities; dividing by the number of
            # models would not change the arg-max.
            probs = sum(torch.softmax(net(ip), dim=1) for net in models)
            labels = probs.argmax(dim=1).cpu().tolist()
            for sample_id, label in zip(ids, labels):
                preds[sample_id] = int(label)
    return preds

In [None]:
# `model` is the wrapped 'vit_l16_224' network moved to the inference device.
model = WRAPPED_MODEL1.to(device)

In [None]:
# ./best_model.pth is the file written by the second (vit_l16_224) fit() run.
load_checkpoint(model, './best_model.pth')

In [None]:
# `model1` is the wrapped default ('vit_b16_224') network moved to the same device.
model1 = WRAPPED_MODEL.to(device)

In [None]:
# Checkpoint of the first run, renamed to this filename right after that run.
load_checkpoint(model1, './best_model_vit_b16_224.pth')

In [None]:
# Mapping from test sample id to predicted class, as returned by predict().
preds = predict(testloader, device)

In [None]:
import csv

In [None]:
# Load the competition's sample submission; its 'id' and 'label' columns are
# filled in with the predictions below.
sub_csv = pd.read_csv('../input/tpu-getting-started/sample_submission.csv')

In [None]:
# Preview the first rows before the labels are replaced.
sub_csv.head()

In [None]:
# Vectorised id -> label lookup instead of one boolean-mask scan of the whole
# frame per prediction. Ids without a prediction keep the label they already
# had, and the column keeps its original dtype.
sub_csv['label'] = (
    sub_csv['id']
    .map(preds)
    .fillna(sub_csv['label'])
    .astype(sub_csv['label'].dtype)
)

In [None]:
# Preview the first rows again, now with the predicted labels.
sub_csv.head()

In [None]:
# Write the submission file without the DataFrame index column.
sub_csv.to_csv('submission.csv', index=False)