<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/modeling_MobileNetV2/test_sample_MobileNetV2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install datasets
! pip install wget

In [1]:
# ! rm -rf PyTorch-Architectures/
# ! git clone https://github.com/vishal-burman/PyTorch-Architectures.git
%cd PyTorch-Architectures/

/content/PyTorch-Architectures


In [2]:
from tqdm.auto import tqdm
import torch
from toolkit.custom_dataset_cv import DataLoaderCIFAR10Classification
from toolkit.metrics import cv_compute_accuracy
from toolkit.utils import get_optimal_batchsize, dict_to_device, EarlyStopping
from toolkit.utils import get_linear_schedule_with_warmup
from modeling_MobileNetV2.model import MobileNetV2
from modeling_MobileNetV2.config import MobileNetV2Config

In [None]:
# Build MobileNetV2 from its default config and place it on the first GPU
# when CUDA is available, otherwise on the CPU.
config = MobileNetV2Config()
model = MobileNetV2(config)

use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
model.to(device)

In [4]:
# Report how many scalar weights will receive gradients (frozen ones are skipped).
trainable_tensors = (p for p in model.parameters() if p.requires_grad)
params = sum(map(torch.Tensor.numel, trainable_tensors))
print('Trainable Parameters: ', params)

Trainable Parameters:  2634560


In [5]:
# Create the CIFAR-10 wrappers for the train split first, then the held-out
# split, both with resize=224. These are wrapper objects; the iterable
# loaders are obtained later through `return_dataloader`.
train_loader, valid_loader = (
    DataLoaderCIFAR10Classification(resize=224, train=is_train_split)
    for is_train_split in (True, False)
)

cifar10 exists...
cifar10 exists...


In [6]:
# Batch-size probe, left commented out (presumably to avoid re-running it).
# The trailing note records the value it reported, which is the value used
# for BS in the hyperparameter cell below.
# get_optimal_batchsize(train_loader.dataset, model) --> 64

64

In [6]:
# Training configuration
MAX_EPOCHS = 100  # upper bound on epochs; the early-stopping check in the training loop can end the run sooner
BS = 64           # batch size handed to both dataloaders
LR = 0.005        # learning rate passed to AdamW

In [7]:
# Turn the dataset wrappers into batched loaders: shuffled for training,
# fixed order for validation.
# NOTE(review): both names are rebound from the wrapper to whatever
# `return_dataloader` returns, so re-running this cell presumably requires
# re-running the cell that builds the wrappers first -- confirm.
train_loader = train_loader.return_dataloader(batch_size=BS, shuffle=True)
valid_loader = valid_loader.return_dataloader(batch_size=BS, shuffle=False)
for split_name, split_loader in (('Train', train_loader), ('Valid', valid_loader)):
    print(f'Length of {split_name} Loader: ', len(split_loader))

Length of Train Loader:  782
Length of Valid Loader:  157


In [8]:
# Smoke test: push exactly one training batch through the network in
# inference mode and print the logits shape together with the scalar loss.
# NOTE(review): the recorded output shows 320 logits per sample although
# CIFAR-10 has 10 classes -- confirm the class count in MobileNetV2Config.
model.eval()
with torch.no_grad():
  for sample in train_loader:
    batch_on_device = dict_to_device(sample, device)
    outputs = model(**batch_on_device)
    loss = outputs[0]
    logits = outputs[1]
    print(logits.shape, loss.item())
    break  # one batch is enough for the check

torch.Size([64, 320]) 5.768321514129639


In [9]:
# Early-stopping helper that tracks validation accuracy and logs every update.
# The training loop calls it once per epoch with the accuracy and the model.
# NOTE(review): 'checkpoint.pt' is loaded later in this notebook but is never
# written by notebook code; presumably this helper saves it -- confirm.
early_stop = EarlyStopping(
    verbose=True,
    metric="val_accuracy",
)

In [10]:
# One scheduler step is taken per optimizer step, so the schedule length is
# the total number of batches across all epochs (assuming no early stop).
steps_per_epoch = len(train_loader)
num_training_steps = MAX_EPOCHS * steps_per_epoch

optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,  # no warmup phase
    num_training_steps=num_training_steps,
)

In [11]:
# Train for at most MAX_EPOCHS epochs. After every epoch the model is scored
# on the validation loader and the result is handed to the early-stopping
# helper, which decides whether to end the run.
progress_bar = tqdm(range(num_training_steps))

for epoch in range(MAX_EPOCHS):
  # ---- optimisation pass over the training split ----
  model.train()
  for sample in train_loader:
    batch_on_device = dict_to_device(sample, device)
    loss = model(**batch_on_device)[0]  # first output element is the loss
    loss.backward()

    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()  # clear gradients before the next batch
    progress_bar.update(1)

  # ---- validation pass; no gradients needed ----
  model.eval()
  with torch.no_grad():
    valid_acc = cv_compute_accuracy(model, valid_loader, device)
    early_stop(valid_acc, model)

  if early_stop.early_stop:
    print("Early Stopping!")
    break

HBox(children=(FloatProgress(value=0.0, max=78200.0), HTML(value='')))

Validation accuracy increased from -inf% to 49.50%
Validation accuracy increased from 49.50% to 65.71%
EarlyStopping counter: 1 out of 3
Validation accuracy increased from 65.71% to 71.30%
Validation accuracy increased from 71.30% to 75.38%
Validation accuracy increased from 75.38% to 80.54%
EarlyStopping counter: 1 out of 3
Validation accuracy increased from 80.54% to 80.86%
Validation accuracy increased from 80.86% to 81.11%
Validation accuracy increased from 81.11% to 82.18%
EarlyStopping counter: 1 out of 3
EarlyStopping counter: 2 out of 3
Validation accuracy increased from 82.18% to 83.02%
Validation accuracy increased from 83.02% to 83.70%
Validation accuracy increased from 83.70% to 84.45%
Validation accuracy increased from 84.45% to 84.73%
EarlyStopping counter: 1 out of 3
EarlyStopping counter: 2 out of 3
Validation accuracy increased from 84.73% to 85.65%
EarlyStopping counter: 1 out of 3
EarlyStopping counter: 2 out of 3
EarlyStopping counter: 3 out of 3
Early Stopping!


**Loading the saved state_dict into a fresh fp32 model**

In [14]:
# Rebuild the architecture and restore the weights stored in 'checkpoint.pt'.
# map_location='cpu' makes the load independent of the device the tensors
# were saved from, so a checkpoint written during GPU training can still be
# opened on a CPU-only runtime; the model is moved to `device` in a later cell.
model_fp32 = MobileNetV2(config)
fp32_state_dict = torch.load('checkpoint.pt', map_location='cpu')
model_fp32.load_state_dict(fp32_state_dict)

<All keys matched successfully>

**Validation accuracy for fp32 MobileNetV2**

In [16]:
# Score the restored fp32 model on the validation loader; this number is the
# baseline the quantized model is compared against.
model_fp32 = model_fp32.to(device).eval()
valid_acc = cv_compute_accuracy(model_fp32, valid_loader, device)
print(f"Valid Accuracy: {valid_acc}")

Valid Accuracy: 85.64999389648438


**Dynamic Quantization to INT8 for MobileNetV2**

In [17]:
# Dynamic quantization: only modules of the listed types (here nn.Linear)
# are swapped for int8 versions; every other layer type is left untouched.
# NOTE(review): the recorded outputs further down show the same accuracy and
# the same 11M file size for both checkpoints, which suggests that little or
# nothing was actually quantized -- confirm the model contains nn.Linear layers.
layers_to_quantize = {torch.nn.Linear}
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,
    qconfig_spec=layers_to_quantize,
    dtype=torch.qint8,
)

In [19]:
# Persist the quantized weights so their on-disk size can be compared with
# the fp32 checkpoint.
int8_state_dict = model_int8.state_dict()
torch.save(int8_state_dict, "checkpoint_int8.pt")

In [21]:
# Score the quantized model on the validation loader.
# Dynamically quantized modules only provide CPU kernels, so the model is
# evaluated on the CPU regardless of which device was used for training.
# Presumably cv_compute_accuracy moves each batch to the device it is given
# -- confirm against toolkit.metrics.
cpu_device = torch.device('cpu')
model_int8.to(cpu_device)
model_int8.eval()
valid_acc = cv_compute_accuracy(model_int8, valid_loader, cpu_device)
print(f"Valid Accuracy: {valid_acc}")

Valid Accuracy: 85.64999389648438


**Size difference between fp32 and int8**

In [22]:
# Print the human-readable on-disk size of every .pt file in the working
# directory, to compare the fp32 and int8 checkpoints.
! du -sh *.pt

11M	checkpoint_int8.pt
11M	checkpoint.pt
