<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/modeling_MLPMixer/test_sample_MLPMixer(JAX_Port).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! nvidia-smi

In [None]:
! rm -rf PyTorch-Architectures/
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git
%cd PyTorch-Architectures/

In [3]:
import time
import torch
import torch.nn as nn
from toolkit.custom_dataset_cv import DataLoaderCIFAR10Classification
from modeling_MLPMixer.model_jax_port import MLPMixer
from modeling_MLPMixer.model_jax_port_config import MLPMixerConfig
from toolkit.metrics import cv_compute_accuracy

In [4]:
# Hyperparameters
BS = 64
LR = 3e-4
EPOCHS = 3

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
config = MLPMixerConfig()
config.num_classes = 10
model = MLPMixer(config)
model.to(device)

In [6]:
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Trainable Parameters: ', params)

Trainable Parameters:  191018


In [None]:
train_loader = DataLoaderCIFAR10Classification(train=True).return_dataloader(batch_size=BS)
valid_loader = DataLoaderCIFAR10Classification(train=False).return_dataloader(batch_size=BS)

print('Length of Train Loader: ', len(train_loader))
print('Length of Valid Loader: ', len(valid_loader))

In [8]:
# Sanity check DataLoaders
for sample in train_loader:
  assert sample[0].dim() == 4, 'Images should be 4-dimensional'
  assert sample[0].size(0) == sample[1].size(0)
  break

In [10]:
# Sanity check forward pass
model.eval()
with torch.set_grad_enabled(False):
  outputs = model(sample[0].to(device), labels=sample[1].to(device))
  assert outputs[1].size(1) == config.num_classes

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

In [13]:
start_time = time.time()
for epoch in range(EPOCHS):
  model.train()
  for idx, sample in enumerate(train_loader):
    imgs = sample[0].to(device)
    labels = sample[1].to(device)
    outputs = model(imgs, labels=labels)

    optimizer.zero_grad()
    loss = outputs[0]
    loss.backward()
    optimizer.step()

    # LOGGING
    if idx % 300 == 0:
      print('Epochs: %04d/%04d || Batch: %04d/%04d || Loss: %.2f' % (epoch+1,
                                                                     EPOCHS,
                                                                     idx,
                                                                     len(train_loader),
                                                                     loss.item()))
  model.eval()
  with torch.set_grad_enabled(False):
    train_acc = cv_compute_accuracy(model, train_loader, device)
    valid_acc = cv_compute_accuracy(model, valid_loader, device)
  print('Train Accuracy: %.2f%% || Valid Accuracy: %.2f%%' % (train_acc,
                                                              valid_acc))
  epoch_elapsed_time = (time.time() - start_time) / 60
  print('Epoch Elapsed Time: %.2f min' % (epoch_elapsed_time))
total_training_time = (time.time() - start_time) / 60
print('Total Training Time: %.2f min' % (total_training_time))

Epochs: 0001/0003 || Batch: 0000/0782 || Loss: 2.34
Epochs: 0001/0003 || Batch: 0300/0782 || Loss: 1.89
Epochs: 0001/0003 || Batch: 0600/0782 || Loss: 1.64
Train Accuracy: 49.62% || Valid Accuracy: 48.69%
Epoch Elapsed Time: 3.86 min
Epochs: 0002/0003 || Batch: 0000/0782 || Loss: 1.45
Epochs: 0002/0003 || Batch: 0300/0782 || Loss: 1.42
Epochs: 0002/0003 || Batch: 0600/0782 || Loss: 1.30
Train Accuracy: 56.59% || Valid Accuracy: 53.85%
Epoch Elapsed Time: 7.64 min
Epochs: 0003/0003 || Batch: 0000/0782 || Loss: 1.19
Epochs: 0003/0003 || Batch: 0300/0782 || Loss: 1.36
Epochs: 0003/0003 || Batch: 0600/0782 || Loss: 1.16
Train Accuracy: 60.71% || Valid Accuracy: 55.89%
Epoch Elapsed Time: 11.40 min
Total Training Time: 11.40 min
