In [95]:
!pip install --upgrade pip
!pip install --upgrade torch torchvision



In [96]:
!nvidia-smi

Fri Nov 26 14:25:17 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    73W / 149W |   2764MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [97]:
import torch
import os
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from tqdm import tqdm
import numpy as np
from time import time

In [98]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [99]:
print('==> Preparing data..')

transform_train = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [100]:
def train(num_epoch, net, device, criterion, optimizer, dataloader):
  net = net.train()
  net = net.to(device)
  criterion = criterion.to(device)
  
  for epoch in range(num_epoch):
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    
    for batch_idx, (inputs, targets) in pbar:
      inputs, targets = inputs.to(device), targets.to(device)
      outputs = net(inputs)
      loss = criterion(outputs, targets)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      with torch.no_grad():
        train_loss += loss.data
        outputs = outputs.argmax(dim=1)
        correct += (targets == outputs).sum()
        total += inputs.size(0)

      pbar.set_description(f'[Epoch {epoch}] Loss: {train_loss / total:.4f}, Accuracy: {correct / total * 100:.4f}%')


In [101]:
def test(net, device, criterion, dataloader):
  net.eval()
  net = net.to(device)
  test_loss = 0
  correct = 0
  num_data = 0
  cur = time()
  
  torch.save(net.state_dict(), "tmp.pt")
  model_size = "%.2f MB" %(os.path.getsize("tmp.pt") / 1e6)


  with torch.no_grad():
    pbar = tqdm(dataloader, total=len(dataloader))
    for data, target in pbar:
      data, target = data.to(device), target.to(device)
      output = net(data)
      test_loss += criterion(output, target).data
      output = output.argmax(dim=1)
      correct += (target == output).sum()
      num_data += data.size(0)

      pbar.set_description(f'Test set: Average loss: {test_loss / num_data:.4f}, Accuracy: {correct / num_data * 100:.4f}%, Time cost: {time() - cur:.4f}, Model size: {model_size}')

  os.remove("tmp.pt")
    

In [102]:
def pruning(net):
  for name, module in net.named_modules():
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
      prune.l1_unstructured(module, 'weight', amount=0.7)
      prune.remove(module, 'weight')

In [108]:
net = models.quantization.resnet50(pretrained=True)

net.fc = nn.Linear(2048, 10)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
epoch = 1

for i in range(5):
  train(epoch, net, device, criterion, optimizer, trainloader)
  print("Before Pruning...")
  test(net, device, criterion, testloader)
  pruning(net)
  print("After Pruning...")
  test(net, device, criterion, testloader)

Epoch 0 Loss: 0.0136, Accuracy: 70.0060%: 100%|██████████| 782/782 [02:02<00:00,  6.39it/s]


Before Pruning...


Test set: Average loss: 0.0085, Accuracy: 81.2300%, Time cost: 7.4376, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.54it/s]


After Pruning...


Test set: Average loss: 0.0225, Accuracy: 49.5300%, Time cost: 7.4069, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.62it/s]
Epoch 0 Loss: 0.0087, Accuracy: 81.2900%: 100%|██████████| 782/782 [02:04<00:00,  6.28it/s]


Before Pruning...


Test set: Average loss: 0.0081, Accuracy: 82.2600%, Time cost: 7.3360, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.83it/s]


After Pruning...


Test set: Average loss: 0.0113, Accuracy: 76.6700%, Time cost: 7.3624, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.72it/s]
Epoch 0 Loss: 0.0070, Accuracy: 84.7900%: 100%|██████████| 782/782 [02:04<00:00,  6.29it/s]


Before Pruning...


Test set: Average loss: 0.0077, Accuracy: 83.6800%, Time cost: 7.4134, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.62it/s]


After Pruning...


Test set: Average loss: 0.0085, Accuracy: 82.3000%, Time cost: 7.3398, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.92it/s]
Epoch 0 Loss: 0.0058, Accuracy: 87.2920%: 100%|██████████| 782/782 [02:04<00:00,  6.27it/s]


Before Pruning...


Test set: Average loss: 0.0076, Accuracy: 83.5300%, Time cost: 7.4562, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.50it/s]


After Pruning...


Test set: Average loss: 0.0075, Accuracy: 83.7100%, Time cost: 7.2901, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 22.12it/s]
Epoch 0 Loss: 0.0047, Accuracy: 89.8080%: 100%|██████████| 782/782 [02:04<00:00,  6.29it/s]


Before Pruning...


Test set: Average loss: 0.0076, Accuracy: 84.1100%, Time cost: 7.3470, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.84it/s]


After Pruning...


Test set: Average loss: 0.0070, Accuracy: 85.0300%, Time cost: 7.4521, Model size: 94.43 MB: 100%|██████████| 157/157 [00:07<00:00, 21.51it/s]


In [109]:
net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
net.fuse_model()
net = torch.quantization.prepare_qat(net)
train(epoch, net, device, criterion, optimizer, trainloader)
pruning(net)

# models.quantization.utils.quantize_model(net, "fbgemm")
test(net, 'cpu', criterion, testloader)
torch.quantization.convert(net, inplace=True)
test(net, 'cpu', criterion, testloader)

  reduce_range will be deprecated in a future release of PyTorch."
Epoch 0 Loss: 0.0026, Accuracy: 96.4960%: 100%|██████████| 782/782 [02:29<00:00,  5.23it/s]
Test set: Average loss: 0.0070, Accuracy: 85.0700%, Time cost: 159.2476, Model size: 94.78 MB: 100%|██████████| 157/157 [02:39<00:00,  1.01s/it]
Test set: Average loss: 0.0070, Accuracy: 84.9600%, Time cost: 32.4488, Model size: 24.12 MB: 100%|██████████| 157/157 [00:32<00:00,  4.85it/s]
