### Import Libraries

In [1]:
from __future__ import print_function

import random
from itertools import chain

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# import Linformer
from linformer import Linformer
#sklearn to split the data
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT

### Config Params

In [None]:
# Training hyperparameters: batch size, number of epochs, learning rate and
# the StepLR decay factor (gamma).
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7  # multiplicative learning-rate decay factor used by StepLR
seed = 10
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def seed_everything(seed):
    """Seed Python's and PyTorch's random number generators.

    Without this, `seed` would only reach `train_test_split`; weight
    initialisation, DataLoader shuffling and random augmentation would still
    differ between runs.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op when CUDA is unavailable


seed_everything(seed)

# Training-time augmentation: random crops and horizontal flips.
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
# Evaluation must be deterministic, so validation/test images are only
# resized and converted to tensors -- no random cropping or flipping.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

### Download the Data and Split It

In [None]:
from utils import download_data
from utils import img_show

# NOTE(review): assumes download_data() returns (train, test) lists of image
# file paths -- inferred from how the paths are parsed below; confirm in utils.
train_list, test_list = download_data()

# Class label of every training image: the part of the file name before the
# first '.'. It must exist *before* the split because `stratify` uses it to
# keep the class proportions equal in the train and validation subsets.
labels = [path.split('/')[-1].split('.')[0] for path in train_list]

# Split the training list into train (80%) and validation (20%) subsets.
train_list, valid_list = train_test_split(
    train_list, test_size=0.2, stratify=labels, random_state=seed
)
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

#### Display images

In [None]:
# Recompute one label per (post-split) training image from its file name --
# everything before the first '.' in the last path component -- and display
# the images with those labels.
labels = []
for img_path in train_list:
    file_name = img_path.rsplit('/', 1)[-1]
    labels.append(file_name.partition('.')[0])
img_show(train_list, labels)

In [None]:
from utils import CatsDogsDataset

# One dataset per split; each uses the transform pipeline defined for it.
train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=val_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)

# Only the training loader shuffles. Evaluation metrics do not depend on
# sample order, and a fixed order keeps validation/test passes reproducible.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Patch geometry shared by the Linformer and the ViT: a 224x224 image cut
# into 16x16 patches yields (224 // 16) ** 2 = 14 * 14 = 196 patches.
image_size = 224
patch_size = 16
num_patches = (image_size // patch_size) ** 2

# Linformer self-attention stack, used as the ViT's transformer encoder.
# seq_len must match the token sequence the ViT produces: one token per
# patch plus one cls token (196 + 1 = 197), so derive it instead of
# hard-coding a value that disagrees with the patch grid.
efficient_transformer = Linformer(
    dim=256,
    seq_len=num_patches + 1,
    depth=12,
    heads=8,
    k=64,
)

# Vision Transformer for 2-class classification on 3-channel images,
# built on the Linformer encoder above.
model = ViT(
    dim=256,
    image_size=image_size,
    patch_size=patch_size,
    num_classes=2,
    transformer=efficient_transformer,
    channels=3,
).to(device)

In [None]:
# Training components: cross-entropy loss, Adam over all model parameters,
# and a StepLR schedule that scales the learning rate by `gamma` every
# `step_size` (= 1) scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=gamma)

In [None]:
from utils import train

for epoch in range(epochs):
    # NOTE(review): assumes utils.train runs one training epoch (and a
    # validation pass over valid_loader) -- confirm against utils.py.
    train(train_loader, valid_loader, device, model, criterion, optimizer, epoch)
    # `train` is never given the scheduler, so step it here once per epoch;
    # otherwise the StepLR decay (gamma) is never applied.
    scheduler.step()
print("Training Complete")