# SWIN Transformer
#### Architecture in torchvision==0.13.0+cu113

### Import headers
Requires:
- `torch` - PyTorch for base-handling
- `torchvision` (0.13.0+cu113) - provides the Swin Transformer model and the CIFAR-100 dataset
- `torch.utils.data.DataLoader` - Efficient dataset loading for training
- `tqdm_notebook` - Progress bars for Jupyter Notebooks
- `prefetch_generator.BackgroundGenerator` - Allows pre-fetching of the next mini-batches for faster training (requires extra RAM; imported, but not currently used below)
- `torchmetrics` - Used for metrics and evaluation
- `torchviz` - Used for visualizing the model architecture
- `torch.utils.tensorboard` - Used for accessing and uploading model-training metrics to TensorBoard
- `os` - Used for simple file-handling operations
- `torchvision.transforms.functional` - Used for exposing torchvision.transforms as functions instead of nn.Module-compatible layers
- `matplotlib.pyplot` - Used for plotting image-based results
- `warnings` - Used for suppressing simple warnings

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from prefetch_generator import BackgroundGenerator
!pip install torchviz torchmetrics -q
import torchmetrics
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter
import os
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
import warnings
# Silence all warnings so notebook output stays readable.
warnings.simplefilter('ignore')

# Fixed global seed so torch-RNG-driven steps (weight init, shuffling,
# random augmentations) are repeatable between runs.
torch.manual_seed(54)

### - Defining Transformations
### - Downloading CIFAR-100 Dataset for Training and Testing
### - Defining DataLoaders for training and testing
(adjust the batch size to fit the available GPU memory)

In [None]:
# Training pipeline: convert to a tensor, then augment with random horizontal
# flips (probability 0.5) and random rotations within +/-45 degrees.
train_transforms = torchvision.transforms.Compose(transforms=[
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomRotation(degrees=45),
])

# Evaluation pipeline: tensor conversion only, no augmentation.
test_transforms = torchvision.transforms.Compose(transforms=[torchvision.transforms.ToTensor()])

# download=True on BOTH splits: the test split lives under its own root
# ("./test"), so without it a fresh run fails with "Dataset not found or
# corrupted". (Pointing the test split at "./train" would reuse the archive
# already downloaded there instead of fetching it a second time.)
train_ds = torchvision.datasets.CIFAR100(root="./train", download=True, train=True, transform=train_transforms)
test_ds = torchvision.datasets.CIFAR100(root="./test", download=True, train=False, transform=test_transforms)

# Batch sizes are tuned for a small GPU; raise them if memory allows.
train_load = DataLoader(train_ds, batch_size=24, shuffle=True, num_workers=4, pin_memory=True)
# Shuffling does not change test accuracy; it only varies which images the
# visualisation cell draws from next(iter(test_load)).
test_load = DataLoader(test_ds, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)

### Reading class names from dataset meta-data

In [None]:
import pickle

# torchvision's CIFAR100(root="./train", download=True) extracts the archive
# into ./train/cifar-100-python/. Its "meta" file holds the label names (the
# dataset constructor reads it from there too). Opening that file by name:
# - avoids re-scanning the whole .tar.gz, which tarfile.getmembers() did;
# - does not rely on the archive's member order (the old member_list[4]);
# - closes the file handle, which the old tarfile handle never was.
META_PATH = os.path.join("./train", "cifar-100-python", "meta")

# NOTE(review): pickle can execute arbitrary code while loading; only load
# files from a trusted source (here, the dataset archive fetched by torchvision).
with open(META_PATH, "rb") as meta_file:
    # latin1 lets Python 3 read the Python-2-era CIFAR pickles.
    content = pickle.load(meta_file, encoding="latin1")
class_names = content['fine_label_names']


def getLabel(index):
    """Return the CIFAR-100 fine-label name for a label-encoded class index."""
    return class_names[index]

### - Instantiating the SWIN Transformer model from torchvision
### - Changing the head of the model from 1000-output to 100-output `nn.Linear()` layer
### - Checking presence of NVIDIA GPU and appropriate CUDA drivers

In [None]:
swin = torchvision.models.swin_b()
swin.head = torch.nn.Linear(1024, 100, bias=True)
!nvidia-smi

### - Testing model by passing 1 sample image through it
### - Visualizing model architecture 
(the model graph is too large to display inline, so `make_dot(...).render()` saves it as a .png file for later viewing)

In [None]:
# Take the first (augmented) training image and give it a leading batch axis.
img = train_ds[0][0]
batch_img = img.float()
# In-place unsqueeze. If img is already float32, .float() returns the same
# tensor, so this reshapes img as well.
batch_img.unsqueeze_(0)

# Single forward pass to check that the model accepts the input.
yhat = swin(batch_img)

# Render the autograd graph to swin.png; it is too large to view inline.
param_map = dict(swin.named_parameters())
make_dot(yhat, params=param_map).render("swin", format="png")

### Defining optimizer and Loss function

In [None]:
optimizer = torch.optim.Adam(swin.parameters(), amsgrad=True)
# CIFAR-100 is single-label, 100-way classification. CrossEntropyLoss applies
# log-softmax over the class logits and takes integer class indices as targets.
# BCEWithLogitsLoss would instead treat each output as an independent binary
# problem and needs one-hot float targets shaped like the logits.
loss_function = torch.nn.CrossEntropyLoss()

### Defining a helper function that returns the class name for a label-encoded integer

In [None]:
def getLabel(index):
    """Translate a label-encoded class index into its CIFAR-100 fine-label name."""
    label_name = class_names[index]
    return label_name

### Selecting the device: use a GPU if one is available, otherwise fall back to the CPU

In [None]:
def try_gpu_else_cpu():
    """Return every visible CUDA device, or ``[torch.device('cpu')]`` when there are none."""
    n_gpus = torch.cuda.device_count()
    if n_gpus == 0:
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{idx}') for idx in range(n_gpus)]


# First GPU when available, otherwise the CPU.
device = try_gpu_else_cpu()[0]
print(device)

### Training function - takes the number of epochs to train for
- Optionally logs to TensorBoard (`tensorboard=True`), which increases processing time considerably
- Saves the model and optimizer state dicts to `./swin-epochs` after every epoch

In [None]:
def _compute_loss(loss_fn, logits, lbl):
    """Apply ``loss_fn`` to raw model logits and integer class labels.

    ``BCEWithLogitsLoss``/``BCELoss`` need a float target with the same shape
    as the logits, so the labels are one-hot encoded for them. Every other loss
    (e.g. ``CrossEntropyLoss``) receives the class indices directly.
    """
    if isinstance(loss_fn, (torch.nn.BCEWithLogitsLoss, torch.nn.BCELoss)):
        target = torch.nn.functional.one_hot(lbl.long(), num_classes=logits.shape[1])
        return loss_fn(logits, target.to(logits.dtype))
    return loss_fn(logits, lbl.long())


def fit(n_epochs, model, train_loader, optimizer, loss_fn, tensorboard=False):
    """Train ``model`` for ``n_epochs`` epochs, checkpointing after every epoch.

    Args:
        n_epochs: number of passes over ``train_loader`` (0 trains nothing).
        model: classifier expected to return logits of shape (batch, n_classes).
        train_loader: iterable with ``len()`` yielding ``(images, int_labels)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss called as ``loss_fn(logits, target)``; see ``_compute_loss``.
        tensorboard: if True, log the per-step training loss (and the model
            graph, once) with ``SummaryWriter``.

    Each checkpoint is saved to ``./swin-epochs/model_ckpt_epoch{N}.pt`` as a
    dict with keys ``"model.state_dict"``, ``"optimizer.state_dict"`` and
    ``"epoch"``. Runs on the module-level ``device``.
    """
    model = model.to(device)
    # predict() calls model.eval(); switch training-only layers back on.
    model.train()

    epoch_dir = os.path.join(os.getcwd(), 'swin-epochs')
    os.makedirs(epoch_dir, exist_ok=True)

    writer = SummaryWriter() if tensorboard else None
    global_step = 0  # monotonically increasing so TensorBoard curves don't overlap across epochs
    try:
        for epoch in range(n_epochs):
            running_loss = 0.0
            n_batches = 0
            progress = tqdm_notebook(train_loader, total=len(train_loader), desc=f'Epoch {epoch+1}/{n_epochs}')
            for img, lbl in progress:
                img = img.float().to(device=device)
                lbl = lbl.to(device=device)

                optimizer.zero_grad()
                logits = model(img)
                # The loss must be computed on the logits: argmax has no
                # gradient, so a loss on it never updates the weights.
                loss = _compute_loss(loss_fn, logits, lbl)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                n_batches += 1
                if writer is not None:
                    if global_step == 0:
                        # Tracing the graph is expensive; do it only once.
                        writer.add_graph(model, img)
                    writer.add_scalar('Loss/Train', loss.item(), global_step)
                global_step += 1

            checkpoint_path = os.path.join(epoch_dir, f'model_ckpt_epoch{epoch+1}.pt')
            torch.save({
                "model.state_dict": model.state_dict(),
                "optimizer.state_dict": optimizer.state_dict(),
                "epoch": epoch
            }, checkpoint_path)

            mean_loss = running_loss / max(n_batches, 1)
            print(f'Epoch {epoch+1}/{n_epochs} completed - mean loss: {mean_loss:.4f}')
    finally:
        # Close the writer even if training is interrupted.
        if writer is not None:
            writer.flush()
            writer.close()

# Ask how many epochs to run, then start training with the objects defined above.
requested_epochs = input("Enter no. of epochs to train for: ")
n_epoch = int(requested_epochs)
fit(n_epoch, model=swin, train_loader=train_load, optimizer=optimizer, loss_fn=loss_function)

### Function for predicting on the test set and calculating accuracy

In [None]:
def _accuracy(preds, targets):
    """Return the fraction of positions where ``preds`` equals ``targets`` (0.0 if empty)."""
    if len(preds) == 0:
        return 0.0
    preds = torch.as_tensor(preds)
    targets = torch.as_tensor(targets)
    return (preds == targets).float().mean().item()


def predict(test_loader, model):
    """Evaluate ``model`` on ``test_loader``, print and return its top-1 accuracy.

    Args:
        test_loader: iterable with ``len()`` yielding ``(images, int_labels)``.
        model: classifier returning logits of shape (batch, n_classes).

    Returns:
        Top-1 accuracy as a float in [0, 1].

    Leaves ``model`` in eval mode on the module-level ``device``.
    """
    model = model.to(device)
    model.eval()
    preds = []
    actuals = []
    # Inference only: skipping autograd bookkeeping saves a lot of GPU memory.
    with torch.no_grad():
        for img, lbl in tqdm_notebook(test_loader, total=len(test_loader), desc='Evaluating'):
            img = img.float().to(device=device)
            logits = model(img)
            preds.extend(torch.argmax(logits, dim=1).tolist())
            actuals.extend(lbl.tolist())
    # Computed directly: torchmetrics.functional.accuracy requires a `task`
    # argument in newer torchmetrics releases, which this call did not pass.
    accuracy = _accuracy(preds, actuals)
    print('Accuracy: ', accuracy*100, '%')
    return accuracy
predict(test_loader=test_load, model=swin)

### Function to display images along with predicted and actual labels

In [None]:
def show(imgs, lbls, true_lb):
    """Plot images two per row, titled with their predicted and actual class names.

    Args:
        imgs: batch of image tensors, one per image; presumably (N, C, H, W)
            floats in [0, 1] as produced by ToTensor — TODO confirm.
        lbls: tensor of predicted class indices, one per image.
        true_lb: tensor of ground-truth class indices, one per image.

    An odd number of images leaves the last grid cell blank. An empty batch
    draws nothing.
    """
    imgs = [F.to_pil_image(i) for i in imgs.cpu()]
    if not imgs:
        return
    lbls_list = [getLabel(i) for i in lbls.tolist()]
    true_lb_list = [getLabel(i) for i in true_lb.tolist()]

    n_cols = 2
    # Round up so an odd final image still gets a cell (and 1 image -> 1 row).
    n_rows = (len(imgs) + n_cols - 1) // n_cols
    fig, ax = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(10, 10))
    fig.subplots_adjust(bottom=0.15)
    for i, img in enumerate(imgs):
        cell = ax[i // n_cols, i % n_cols]
        cell.set_title(f"Prediction : {lbls_list[i]}\nActual : {true_lb_list[i]}")
        cell.imshow(img)
        cell.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    # Hide the unused cell left over by an odd image count.
    for cell in ax.flat[len(imgs):]:
        cell.axis('off')
    # Lay out after titles exist, so tight_layout can make room for them.
    fig.tight_layout()

In [None]:
# Predict one (shuffled) test batch and plot it next to the true labels.
swin.eval()  # inference mode: disable training-only layer behaviour
# Follow the model's device instead of hard-coding .cuda(), so this also runs
# on CPU-only machines and before fit() has moved the model.
model_device = next(swin.parameters()).device
img, lbl = next(iter(test_load))
img = img.float().to(model_device)
with torch.no_grad():  # no autograd graph needed for plotting
    preds = swin(img)
preds = torch.argmax(preds, axis=1)
show(img, preds, lbl)

### Concluding thoughts:

The model is fairly impaired by the small image size of CIFAR-100 (32x32 pixels, 3 channels). With a dataset of larger images and the extra computing power of stronger GPUs, it can become much stronger and predict far more reliably.
Training for 1 epoch over the 50,000 CIFAR-100 training samples took 01:58 minutes on an NVIDIA GTX 1650 Ti Mobile (4 GB) GPU.