This notebook is adapted from [this kernel](https://www.kaggle.com/abhinand05/vision-transformer-vit-tutorial-baseline).

In [None]:
!pip install ../input/pytorchlightning/tensorboard-2.2.0-py3-none-any.whl
!pip install ../input/pytorchlightning/pytorch_lightning-0.9.0-py3-none-any.whl

In [None]:
import os
import sys
sys.path.append("../input/timm-pytorch-models")

import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.metrics.functional.classification import accuracy

import timm
#timm.list_models("vit*")

import torch
from torch.utils.data import Dataset,DataLoader
from torch import nn

import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensor

# Input Constants

In [None]:
# State-dict file loaded with torch.load into the timm ViT backbone when the model is built.
MODEL_PATH = "../input/vit-base-models-pretrained-pytorch/jx_vit_base_p32_384-830016f5.pth"
# Number of target classes; written to the classifier head's out_features attribute.
CLASSES = 5
# Directory that each image_id is joined onto to locate a training image.
IMG_DIR = "../input/cassava-leaf-disease-classification/train_images"
# Height and width handed to A.Resize / A.RandomCrop in the transforms.
IMAGE_SIZE = 384 
# CSV read by the data module; its image_id and label columns are used.
TRAIN_FILE = "../input/cassava-leaf-disease-classification/train.csv"

# Model

In [None]:
class CassavViT(pl.LightningModule):
  """Lightning wrapper around a timm ``vit_base_patch32_384`` classifier.

  The backbone is created with ``pretrained=False`` and then initialised from
  the state dict stored at ``MODEL_PATH``. Training and validation steps both
  compute cross-entropy loss and accuracy and log them.
  """

  def __init__(self):
    super().__init__()
    self.loss_fn = nn.CrossEntropyLoss()
    backbone = timm.create_model("vit_base_patch32_384", pretrained=False)
    backbone.load_state_dict(torch.load(MODEL_PATH))
    # NOTE(review): this only overwrites an attribute; the head's weight and
    # bias tensors keep the shape loaded from MODEL_PATH, so the model still
    # emits the pretrained number of logits. Swapping in a new
    # nn.Linear(in_features, CLASSES) would be the real fix, but checkpoints
    # trained with the current layout would then no longer load.
    backbone.head.out_features = CLASSES
    self.model = backbone

  def forward(self, x):
    """Return the backbone's logits for the image batch ``x``."""
    logits = self.model(x)
    return logits

  def configure_optimizers(self):
    """Optimise all parameters with Adam at a fixed learning rate of 1e-5."""
    return torch.optim.Adam(self.parameters(), lr=1e-5)

  def _loss_and_accuracy(self, batch):
    """Run the forward pass on a ``{"x", "y"}`` batch; return ``(loss, acc)``."""
    images, targets = batch["x"], batch["y"]
    logits = self(images)
    return self.loss_fn(logits, targets), accuracy(logits, targets)

  def training_step(self, batch, batch_idx):
    """Log ``train_loss`` / ``train_acc`` and return the loss to optimise."""
    loss, acc = self._loss_and_accuracy(batch)
    self.log("train_loss", loss)
    self.log("train_acc", acc, on_epoch=True, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log ``val_acc`` and ``val_loss`` for one validation batch."""
    loss, acc = self._loss_and_accuracy(batch)
    self.log("val_acc", acc, on_epoch=True, prog_bar=True)
    self.log("val_loss", loss, prog_bar=True)

# Dataset Generator

In [None]:
class CassavData(Dataset):
  """Labelled image dataset that reads files from disk with OpenCV.

  Args:
    path: Directory that contains the image files.
    image_ids: Indexable sequence of file names relative to ``path``.
    labels: Indexable sequence of targets, aligned with ``image_ids``.
    transform: Callable invoked as ``transform(image=img)`` whose result is
      indexed with ``"image"`` to obtain the processed image.
  """

  def __init__(self,path,image_ids,labels,transform):
    super().__init__()
    self.path = path
    self.image_ids = image_ids
    self.labels = labels
    self.transform = transform

  def __len__(self):
    """Return the number of images in the dataset."""
    return len(self.image_ids)

  def __getitem__(self,item):
    """Load, colour-convert and transform the sample at position ``item``.

    Returns:
      A dict with the transformed image under ``"x"`` and its label under
      ``"y"``.

    Raises:
      FileNotFoundError: If the image file is missing or cannot be decoded.
    """
    label = self.labels[item]
    img_path = os.path.join(self.path, str(self.image_ids[item]))
    img = cv2.imread(img_path)
    if img is None:
      # cv2.imread reports a missing/unreadable file by returning None rather
      # than raising; fail here with the offending path instead of letting
      # cvtColor raise an opaque assertion error.
      raise FileNotFoundError(f"Could not read image: {img_path}")
    # OpenCV decodes to BGR channel order; convert to RGB before transforming.
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = self.transform(image=img)["image"]
    return {
        "x":img,
        "y":label
    }



# Data Module

In [None]:
  
class CassavDataModule(pl.LightningDataModule):
  """Builds the train/validation datasets and loaders from ``TRAIN_FILE``.

  The rows are shuffled, split into 5 stratified folds on ``label``, and one
  fold is held out for validation.

  Args:
    val_fold: Index (0-4) of the fold used for validation. Defaults to 1.
    seed: Seed for the shuffle that precedes fold assignment, so that every
      call to ``setup`` yields the same split. Defaults to 42.

  Raises:
    ValueError: If ``val_fold`` is not a valid fold index.
  """

  N_SPLITS = 5

  def __init__(self, val_fold=1, seed=42):
    super().__init__()
    if not 0 <= val_fold < self.N_SPLITS:
      raise ValueError(
          f"val_fold must be in [0, {self.N_SPLITS - 1}], got {val_fold}")
    self.val_fold = val_fold
    self.seed = seed
    # NOTE(review): RandomCrop uses the same size as the preceding Resize, so
    # it always returns the full image; kept to preserve the pipeline.
    self.train_transform = A.Compose([
                                 A.Resize(IMAGE_SIZE,IMAGE_SIZE),
                                 A.HorizontalFlip(),
                                 A.VerticalFlip(),
                                 A.RandomCrop(IMAGE_SIZE,
                                                   IMAGE_SIZE),
                                 A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                                 ToTensor()])
    # Validation uses resize + normalisation only (no random augmentation).
    self.test_transform = A.Compose([
                                 A.Resize(IMAGE_SIZE,IMAGE_SIZE),
                                 A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                                 ToTensor()])

  def setup(self,stage=None):
    """Read the CSV, assign stratified folds and create both datasets."""
    dfx = pd.read_csv(TRAIN_FILE)
    dfx["kfold"] = -1
    # Seeded shuffle: setup() can be invoked more than once, and an unseeded
    # shuffle would give each invocation a different train/validation split.
    dfx = dfx.sample(frac=1, random_state=self.seed).reset_index(drop=True)
    skfold = StratifiedKFold(n_splits=self.N_SPLITS)
    for fold,(_,v_idx) in enumerate(skfold.split(dfx.image_id,y=dfx.label)):
      dfx.loc[v_idx,"kfold"]=fold
    is_val = dfx.kfold == self.val_fold
    validation = dfx.loc[is_val]
    train = dfx.loc[~is_val]
    self.train_dataset = CassavData(path=IMG_DIR,
                               image_ids = train.image_id.values,
                               labels = train.label.values,
                               transform = self.train_transform)
    self.val_dataset = CassavData(path=IMG_DIR,
                               image_ids = validation.image_id.values,
                               labels = validation.label.values,
                               transform = self.test_transform)

  def train_dataloader(self):
    """Return the shuffled training loader (incomplete last batch dropped)."""
    return DataLoader(dataset=self.train_dataset,
                      batch_size=16,
                      num_workers=4,
                      drop_last=True,
                      shuffle=True)

  def val_dataloader(self):
    """Return the validation loader covering every validation sample."""
    # drop_last=False so the trailing partial batch still contributes to the
    # validation metrics.
    return DataLoader(dataset=self.val_dataset,
                      batch_size=16,
                      num_workers=4,
                      drop_last=False,
                      shuffle=False)

# Training & Save model

In [None]:
# Training run, left commented out; uncomment to retrain and save a checkpoint.
#dm = CassavDataModule()
#net = CassavViT()
#trainer = pl.Trainer(gpus=-1, max_epochs=20)
#trainer.fit(model = net,
#            datamodule = dm)
#trainer.save_checkpoint("ViT_model.pth")
class CassavTestData(Dataset):
  """Unlabelled image dataset used for inference.

  Args:
    path: Directory that contains the image files.
    image_ids: Indexable sequence of file names relative to ``path``.
    transform: Callable invoked as ``transform(image=img)`` whose result is
      indexed with ``"image"`` to obtain the processed image.
  """

  def __init__(self,path,image_ids,transform):
    super().__init__()
    self.path = path
    self.image_ids = image_ids
    self.transform = transform

  def __len__(self):
    """Return the number of images in the dataset."""
    return len(self.image_ids)

  def __getitem__(self,item):
    """Load, colour-convert and transform the image at position ``item``.

    Returns:
      A dict with the transformed image under ``"x"``.

    Raises:
      FileNotFoundError: If the image file is missing or cannot be decoded.
    """
    img_path = os.path.join(self.path, str(self.image_ids[item]))
    img = cv2.imread(img_path)
    if img is None:
      # cv2.imread reports a missing/unreadable file by returning None rather
      # than raising; fail here with the offending path instead of letting
      # cvtColor raise an opaque assertion error.
      raise FileNotFoundError(f"Could not read image: {img_path}")
    # OpenCV decodes to BGR channel order; convert to RGB before transforming.
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = self.transform(image=img)["image"]
    return {
        "x":img,
    }


# Inference inputs: the test image folder and the sample submission, whose
# image_id column lists the files to predict.
TEST_IMAGE_DIRS = "../input/cassava-leaf-disease-classification/test_images/"
test = pd.read_csv(
    "../input/cassava-leaf-disease-classification/sample_submission.csv"
)

# Resize and normalise only; no random augmentation at test time.
test_transform = A.Compose(
    [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensor(),
    ]
)

test_dataset = CassavTestData(
    path=TEST_IMAGE_DIRS,
    image_ids=test.image_id.values,
    transform=test_transform,
)

# Default loader settings apart from the batch size of 32.
test_loader = DataLoader(test_dataset, batch_size=32)

In [None]:
# Load the trained Lightning checkpoint and put the model in inference mode.
best_checkpoints = "../input/vit-cassava-trained/ViT_model.pth"
# Use the GPU when present, otherwise fall back to the CPU instead of failing
# on a hard-coded "cuda" device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_model = CassavViT.load_from_checkpoint(checkpoint_path = best_checkpoints)
pretrained_model = pretrained_model.to(device)
pretrained_model.eval()
pretrained_model.freeze()

# Predict a label for every test image, batch by batch.
fin_out = []
with torch.no_grad():
    for data in test_loader:
        y_hat = pretrained_model(data["x"].to(device))
        # Only the first CLASSES logits map to real labels. CassavViT merely
        # overwrites head.out_features, which does not resize the layer, so
        # the head may emit more logits than CLASSES; restricting the argmax
        # guarantees a valid label in the submission.
        y_hat = torch.argmax(y_hat[:, :CLASSES],dim=1)
        fin_out.extend(y_hat.cpu().tolist())

# Write the predictions in the submission format and preview the first rows.
test["label"] = fin_out
test[["image_id","label"]].to_csv("submission.csv",index=False)
test.head()
