you can run this notebook diretly on kaggle from following link  
https://www.kaggle.com/chtalhaanwar/pytorch-lightning-mixup-tta-84

# install libraries

In [50]:
%%capture
!pip install pytorch-lightning
!pip install timm
!pip install ttach

# import packages

In [51]:
#import packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from pytorch_lightning import seed_everything, LightningModule, Trainer
from sklearn.utils import class_weight
import torch.nn as nn
import torch
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torchvision
from sklearn.metrics import classification_report
from PIL import Image
from torch.utils.data import DataLoader, Dataset,random_split
import timm
import torchmetrics
import torchvision.models as models
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
import sklearn
from sklearn.metrics import classification_report,accuracy_score

In [52]:
print('torch version',torch.__version__)
print('pytorch lightnging version',pl.__version__)
print('sklearn version',sklearn.__version__)
print('torchvision version',torchvision.__version__)
print('albumentations version',A.__version__)
print('torchmetrics version',torchmetrics.__version__)


torch version 1.11.0
pytorch lightnging version 1.6.3
sklearn version 1.0.2
torchvision version 0.12.0
albumentations version 1.1.0
torchmetrics version 0.6.2


# augmentation

In [53]:
#create data augmentation
img_size=224
aug= A.Compose([
            A.Resize(img_size,img_size),
            A.HorizontalFlip(0.5),
#             A.VerticalFlip(),
            #A.RandomRotate90(),
            A.Rotate(10),
            A.ColorJitter(0.2,0.2,0,0),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ], p=1.0)

# DataReader

In [54]:
#create a class to read data from folders and apply augmentation from albumentation
class DataReader(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        x=self.dataset[index][0]#read image
        y=self.dataset[index][1] #read label
        if self.transform:#apply augmentations
            x=np.array(x)
            x=self.transform(image=x)['image']
        return x, y
    
    def __len__(self):
        return len(self.dataset)

# Pytorch Lightning Model

In [55]:
class OurModel(LightningModule):
    def __init__(self):
        super(OurModel,self).__init__()
        
 
        #parameters
        self.lr=1e-3
        self.batch_size=128
        self.numworker=2
        self.acc = torchmetrics.Accuracy() #metric
        self.criterion=nn.CrossEntropyLoss() #loss function
        #list to store loss and accuracy
        self.trainacc,self.valacc=[],[]
        self.trainloss,self.valloss=[],[]
        #load data        
        self.train_path='../input/human-action-detection-artificial-intelligence/emirhan_human_dataset/datasets/human_data/train_data'
        self.test_path='../input/human-action-detection-artificial-intelligence/emirhan_human_dataset/datasets/human_data/test_data'
        self.dataset=torchvision.datasets.ImageFolder(self.train_path)
        #split data
        self.train_set, self.val_set =random_split(self.dataset,
                            [int(len(self.dataset)*0.7), int(len(self.dataset)*0.3)],
                                                  generator=torch.Generator().manual_seed(42))
        self.test_set=self.dataset=torchvision.datasets.ImageFolder(self.test_path)
   
        # model architecute
        '''refernce:https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py'''
        self.model =  timm.create_model(model_name,pretrained=True,num_classes=len(self.dataset.classes))
    def forward(self,x):
        x= self.model(x)
        return x

    def mixup_data(self,x, y, alpha=1.0):
        '''
        Returns mixed inputs, pairs of targets, and lambda
        reference: mixup: Beyond Empirical Risk Minimization
        '''
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]
        index = torch.randperm(batch_size)
        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam


    def mixup_criterion(self, pred, y_a, y_b, lam):
        return lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b)
    
    
    def configure_optimizers(self):
        #optimizer and scheduler
        opt=torch.optim.AdamW(params=self.parameters(),lr=self.lr )
        scheduler=CosineAnnealingWarmRestarts(opt,T_0=5, T_mult=1, eta_min=1e-6, last_epoch=-1)
        return {'optimizer': opt,'lr_scheduler':scheduler}

    def train_dataloader(self):#load train 
        return DataLoader(DataReader(self.train_set,aug), batch_size = self.batch_size, 
                          num_workers=self.numworker,
                          pin_memory=True,shuffle=True)

    def training_step(self,batch,batch_idx):
        image,label=batch
        mixed_x, y_a, y_b, lam=self.mixup_data(image,label)#apply mixup
        out = self(mixed_x)#pass images to model
        loss=self.mixup_criterion(out,y_a, y_b, lam) #calculate loss
        acc=self.acc(out,label)#calculate accuracy
        return {'loss':loss,'acc':acc}

    def training_epoch_end(self, outputs):
        #average loss and accuracy in all batches of train data
        loss=torch.stack([x["loss"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        acc=torch.stack([x["acc"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        self.trainacc.append(acc)
        self.trainloss.append(loss)
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        
    def val_dataloader(self):
        ds=DataLoader(DataReader(self.val_set,aug), batch_size = self.batch_size,
                      num_workers=self.numworker,pin_memory=True, shuffle=False)
        return ds

    def validation_step(self,batch,batch_idx):
        image,label=batch
        out=self(image)
        loss=self.criterion(out,label)
        acc=self.acc(out,label)
        return {'loss':loss,'acc':acc}

    def validation_epoch_end(self, outputs):
        loss=torch.stack([x["loss"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        acc=torch.stack([x["acc"] for x in outputs]).mean().detach().cpu().numpy().round(2)
        self.valacc.append(acc)
        self.valloss.append(loss)
        print('validation loss accuracy ',self.current_epoch,loss, acc)
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        
    def test_dataloader(self):
        ds=DataLoader(DataReader(self.test_set,aug), batch_size = self.batch_size,
                      num_workers=self.numworker,pin_memory=True, shuffle=False)
        return ds    
    def test_step(self,batch,batch_idx):
        image,label=batch
        pred = self(image)
        
        return {'label':label,'pred':pred}

    def test_epoch_end(self, outputs):

        label=torch.cat([x["label"] for x in outputs])
        pred=torch.cat([x["pred"] for x in outputs])
        pred=torch.argmax(pred,1)
        acc=self.acc(pred.flatten(),label)
        pred=pred.detach().cpu().numpy().ravel()
        label=label.detach().cpu().numpy().ravel()

        print('torch acc',acc)
        print(classification_report(label,pred,target_names=model.dataset.classes))
        print('sklearn',accuracy_score(label,pred))

In [56]:
model_name='efficientnetv2_rw_s'
model=OurModel()

In [57]:
lr_monitor = LearningRateMonitor(logging_interval='epoch')
checkpoint=ModelCheckpoint(dirpath='checkpoints',filename='file', monitor='val_acc', verbose=False, save_last=False, mode='max')
trainer = Trainer(max_epochs=15, auto_lr_find=False, auto_scale_batch_size=False,
                deterministic=True,
                gpus=-1,precision=16,
                accumulate_grad_batches=2,
                stochastic_weight_avg=False,
                enable_progress_bar = False,
                num_sanity_val_steps=2,
                callbacks=[lr_monitor,checkpoint]
                 )

# Training and Testing

In [58]:
trainer.fit(model)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


validation loss accuracy  0 3.63 0.07
validation loss accuracy  0 0.98 0.75
validation loss accuracy  1 0.78 0.79
validation loss accuracy  2 0.72 0.81
validation loss accuracy  3 0.71 0.82
validation loss accuracy  4 0.69 0.82
validation loss accuracy  5 0.74 0.8
validation loss accuracy  6 0.68 0.82
validation loss accuracy  7 0.7 0.81
validation loss accuracy  8 0.65 0.83
validation loss accuracy  9 0.65 0.84
validation loss accuracy  10 0.69 0.82
validation loss accuracy  11 0.73 0.81
validation loss accuracy  12 0.68 0.83
validation loss accuracy  13 0.61 0.84
validation loss accuracy  14 0.62 0.83


In [59]:
torch.save(model.state_dict(), 'model.pt')
model.load_state_dict(torch.load('model.pt'))

<All keys matched successfully>

In [60]:
trainer.validate(model)

validation loss accuracy  15 0.63 0.84


[{'val_loss': 0.6299999952316284, 'val_acc': 0.8399999737739563}]

In [61]:
trainer.test(model)

torch acc tensor(0.8250, device='cuda:0')
                    precision    recall  f1-score   support

           calling       0.67      0.81      0.73       200
          clapping       0.83      0.78      0.80       200
           cycling       0.97      0.97      0.97       200
           dancing       0.88      0.81      0.85       200
          drinking       0.83      0.84      0.83       200
            eating       0.94      0.89      0.91       200
          fighting       0.87      0.85      0.86       200
           hugging       0.82      0.90      0.86       200
          laughing       0.81      0.77      0.79       200
listening_to_music       0.76      0.76      0.76       200
           running       0.85      0.94      0.89       200
           sitting       0.73      0.70      0.72       200
          sleeping       0.87      0.84      0.85       200
           texting       0.78      0.72      0.75       200
      using_laptop       0.80      0.81      0.80       2

[{}]

# Testing TTA

In [62]:
loader=model.test_dataloader()
model.cuda().eval()
labels,preds=[],[]
with torch.no_grad():
    for batch in loader:
        image,label=batch
        pred=model(image.cuda())
        pred=torch.argmax(pred,dim=1).detach().cpu().numpy()
        labels.append(label.cpu().numpy())
        preds.append(pred)

In [63]:
from sklearn.metrics import classification_report
print(classification_report(np.hstack(labels),np.hstack(preds),target_names=model.dataset.classes))

                    precision    recall  f1-score   support

           calling       0.67      0.79      0.72       200
          clapping       0.84      0.77      0.80       200
           cycling       0.97      0.96      0.97       200
           dancing       0.86      0.81      0.83       200
          drinking       0.85      0.82      0.84       200
            eating       0.92      0.89      0.91       200
          fighting       0.85      0.85      0.85       200
           hugging       0.81      0.89      0.85       200
          laughing       0.80      0.77      0.79       200
listening_to_music       0.79      0.74      0.76       200
           running       0.85      0.95      0.90       200
           sitting       0.70      0.70      0.70       200
          sleeping       0.88      0.84      0.86       200
           texting       0.75      0.72      0.74       200
      using_laptop       0.79      0.81      0.80       200

          accuracy                    

In [64]:
#test time augmentation
import ttach as tta
transforms = tta.Compose(
    [#more tta can be added
        tta.HorizontalFlip(),
   ]
)

tta_model = tta.ClassificationTTAWrapper(model, transforms)

In [65]:
loader=model.test_dataloader()
model.cuda().eval()
labels,preds=[],[]
with torch.no_grad():
    for batch in loader:
        image,label=batch
        pred=tta_model(image.cuda())
        pred=torch.argmax(pred,dim=1).detach().cpu().numpy()
        labels.append(label.cpu().numpy())
        preds.append(pred)

In [66]:
from sklearn.metrics import classification_report
print(classification_report(np.hstack(labels),np.hstack(preds),target_names=model.dataset.classes))

                    precision    recall  f1-score   support

           calling       0.69      0.81      0.74       200
          clapping       0.85      0.79      0.82       200
           cycling       0.97      0.97      0.97       200
           dancing       0.86      0.83      0.85       200
          drinking       0.85      0.85      0.85       200
            eating       0.92      0.87      0.89       200
          fighting       0.88      0.86      0.87       200
           hugging       0.84      0.92      0.87       200
          laughing       0.81      0.79      0.80       200
listening_to_music       0.78      0.77      0.77       200
           running       0.87      0.94      0.90       200
           sitting       0.74      0.73      0.74       200
          sleeping       0.89      0.85      0.87       200
           texting       0.79      0.73      0.76       200
      using_laptop       0.81      0.82      0.81       200

          accuracy                    