In [1]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import pprint

import torch
import torch.backends
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from agents import Agent, attach_agents
from models import VGG16
from training_utils import validate, train

Hyperparameters

In [2]:
class Args:
    """Experiment configuration, read everywhere through the `args` instance."""
    checkpoint = 'best.pth.tar'  # Pretrained VGG16 weights for CIFAR-10.
    num_workers = 0  # DataLoader worker processes (0 = load in the main process).
    batch_size = 256
    lr_agents = 0.01  # Adam learning rate for the agents' weights (agent.w).
    lr_model = 0.001  # Adam learning rate for the VGG16 parameters.
    penalty = 50  # lambda
    init_weight = 6.9  # Agent's initial weight value; sigmoid(6.9) ~= 0.999.
    seed = 0  # Seed for the torch and NumPy RNGs (augmentation, shuffling, training).
    # Epoch counts are set in the training cell.

    # Prefer CUDA, then Apple MPS (guarded for PyTorch builds without
    # torch.backends.mps), then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

args = Args()

# Seed torch (all devices) and NumPy so augmentation, shuffling and training are
# repeatable. Bit-exact GPU reproducibility would additionally need
# deterministic cuDNN kernels.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

Load CIFAR-10 dataset

In [3]:
# CIFAR-10 train/test splits. Only ToTensor() is applied (no mean/std
# normalization), presumably matching how the pretrained checkpoint was trained.
train_data = datasets.CIFAR10('./data', train=True, download=True,
                              transform=transforms.Compose([
                                  # Random augmentation, training split only.
                                  transforms.RandomHorizontalFlip(),
                                  transforms.RandomCrop(32, padding=4),
                                  transforms.RandomRotation(20),
                                  transforms.ToTensor(),
                              ]))

val_data = datasets.CIFAR10('./data', train=False, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                            ]))

train_loader = DataLoader(train_data,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=args.num_workers)

# Accuracy does not depend on sample order, so the validation set is not
# shuffled. This keeps evaluation batches fixed. It also stops every validation
# pass from drawing on the global torch RNG, which would otherwise perturb the
# training randomness between epochs.
val_loader = DataLoader(val_data,
                        batch_size=args.batch_size,
                        shuffle=False,
                        num_workers=args.num_workers)

Files already downloaded and verified
Files already downloaded and verified


Load pretrained VGG16 model

In [4]:
# Restore the pretrained VGG16 weights, move the network to the target device
# and switch it to inference mode.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer PyTorch versions also accept weights_only=True).
checkpoint = torch.load(args.checkpoint, map_location=args.device)
model = VGG16()
model.load_state_dict(checkpoint['state_dict'])
model = model.to(args.device).eval()  # .to()/.eval() return the same module

# Untouched copy of the pretrained network, kept as a clean template/fallback.
model_orig = copy.deepcopy(model)

Print the baseline accuracy

In [5]:
# Accuracy of the unpruned pretrained model -- the reference point for pruning.
val_acc = validate(model, val_loader, args.device)
print(f'Baseline accuracy: {val_acc}')

Val: 100%|██████████| 40/40 [00:02<00:00, 17.73it/s, acc=0.919]

Baseline accuracy: 0.9194





Attach agents

In [6]:
# Groups of module names handed to attach_agents. Each feature group pairs a
# layer with the one right after it. Per the FLOPs table further down,
# features.0/1, 3/4 and 7/8 are (Conv2d, BatchNorm2d). The other pairs
# presumably follow the same pattern. The gaps in the indices are the layers
# in between, which are not pruned directly.
modules_to_prune = [
    [f'features.{idx}', f'features.{idx + 1}']
    for idx in (0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40)
] + [['classifier.0', 'classifier.1']]

# Attach pruning agents to the module groups above. Besides the agent list,
# attach_agents returns a lookup (presumably module name -> agent) and the
# agent / subagent counts checked in the next cell.
agents, name_to_agent, num_agents, num_subagents = attach_agents(
    model,
    modules_to_prune,
    args.device,
    args.init_weight,
)

In [7]:
# Sanity check: the subagent count reported by attach_agents must equal the sum
# of the per-agent counts.
total_subagents = sum(agent.num_subagents for agent in agents)
print('Total agents: {}'.format(total_subagents))
assert num_subagents == total_subagents

Total agents: 4736


Train the model: joint agent/model training, then fine-tuning

In [8]:
# Joint training of the agents' policies and the model, followed by fine-tuning
# of the model alone with the agents frozen into binary decisions.
NUM_JOINT_EPOCHS = 260  # Epochs [0, 260): agents and model optimized together.
NUM_TOTAL_EPOCHS = 300  # Epochs [260, 300): model-only fine-tuning.
PROB_THRESHOLD = 0.5  # Binary threshold on p = sigmoid(agent.w) for fine-tuning.

optimizer_agents = optim.Adam([agent.w for agent in agents], lr=args.lr_agents)
optimizer_model = optim.Adam(model.parameters(), lr=args.lr_model)
criterion_model = nn.CrossEntropyLoss().to(args.device)

# Phase 1: joint training of policies and the model.
for epoch in range(NUM_JOINT_EPOCHS):
    print('Epoch: {}'.format(epoch))

    train_logs = train(model, agents, train_loader, optimizer_model,
                       optimizer_agents, criterion_model, args.penalty,
                       args.device, optimize_agents=True, optimize_model=True)

    # NOTE(review): agents are switched to eval mode for validation and not
    # switched back in this cell. The next epoch therefore relies on `train`
    # re-enabling training mode for the agents -- confirm in training_utils.
    for agent in agents:
        agent.eval()
    val_acc = validate(model, val_loader, args.device)

pprint.pprint(train_logs)
print('Val accuracy: {}'.format(val_acc))

# Phase 2: fine-tune the model only; policy training is stopped.
print('-'*100)
print('Fine-tuning...')

# Read the probabilities without recording an autograd graph through agent.w.
with torch.no_grad():
    p_list = []
    for agent in agents:
        p_list += torch.sigmoid(agent.w).tolist()
print('Agents p <= {}: {}'.format(PROB_THRESHOLD,
                                  sum(p <= PROB_THRESHOLD for p in p_list)))

# Freeze every agent to a hard binary decision at PROB_THRESHOLD.
for agent in agents:
    agent.eval(prob_threshold=PROB_THRESHOLD, threshold_type='BINARY')

for epoch in range(NUM_JOINT_EPOCHS, NUM_TOTAL_EPOCHS):
    print('Epoch: {}'.format(epoch))

    # optimizer_agents is still passed, but optimize_agents=False keeps the
    # policies fixed (presumably that optimizer is simply not stepped).
    train_logs = train(model, agents, train_loader, optimizer_model,
                       optimizer_agents, criterion_model, args.penalty,
                       args.device, optimize_agents=False, optimize_model=True)
    val_acc = validate(model, val_loader, args.device)

pprint.pprint(train_logs)
print('Val accuracy: {}'.format(val_acc))

Val: 100%|██████████| 40/40 [00:02<00:00, 16.48it/s, acc=0.914]


Epoch: 296


Train: 100%|██████████| 196/196 [00:46<00:00,  4.23it/s, acc=0.993, dropped=530, loss_a=0.102, loss_m=8.97e-5, p=0.859, p_max=1, p_min=0.00075, w=3.78] 
Val: 100%|██████████| 40/40 [00:02<00:00, 16.50it/s, acc=0.914]


Epoch: 297


Train: 100%|██████████| 196/196 [00:46<00:00,  4.24it/s, acc=0.993, dropped=530, loss_a=0.103, loss_m=8.64e-5, p=0.859, p_max=1, p_min=0.00075, w=3.78]
Val: 100%|██████████| 40/40 [00:02<00:00, 16.53it/s, acc=0.917]


Epoch: 298


Train: 100%|██████████| 196/196 [00:46<00:00,  4.23it/s, acc=0.994, dropped=530, loss_a=0.115, loss_m=6.87e-5, p=0.859, p_max=1, p_min=0.00075, w=3.78]
Val: 100%|██████████| 40/40 [00:02<00:00, 16.25it/s, acc=0.915]


Epoch: 299


Train: 100%|██████████| 196/196 [00:46<00:00,  4.23it/s, acc=0.993, dropped=530, loss_a=0.105, loss_m=8.07e-5, p=0.859, p_max=1, p_min=0.00075, w=3.78]
Val: 100%|██████████| 40/40 [00:02<00:00, 16.55it/s, acc=0.915]

{'accuracy': 0.99304,
 'accuracy_reward_avg': 0.64504,
 'compression_reward_avg': 60.0,
 'full_reward_avg': 38.7024,
 'loss_agents_avg': 0.10515757954597472,
 'loss_model_avg': 8.068947711959481e-05,
 'num_channels_dropped': 530.0,
 'p_max': 0.9999513626098633,
 'p_min': 0.0007500865031033754,
 'probabilities_avg': 0.8589694457394736,
 'total_channels': 4736,
 'total_samples': 50000,
 'weights_avg': 3.779762898172651}
Val accuracy: 0.9148





In [9]:
# Summary statistics of p = sigmoid(agent.w) over all subagents. Agents with
# p <= 0.5 correspond to the channels reported as dropped during training.
with torch.no_grad():  # read-only: don't record an autograd graph through agent.w
    p_list = []
    for agent in agents:
        p_list += torch.sigmoid(agent.w).tolist()

total_subagents = sum(agent.num_subagents for agent in agents)

print('Avg p: {}'.format(sum(p_list)/len(p_list)))
print('Min p: {}'.format(min(p_list)))
print('Max p: {}'.format(max(p_list)))
print('Total agents: {}'.format(total_subagents))
# One line per threshold; prints exactly "Agents p <= 0.5: N", etc.
for threshold in (0.5, 0.9, 0.998, 0.999):
    print('Agents p <= {}: {}'.format(threshold, sum(p <= threshold for p in p_list)))

assert len(p_list) == total_subagents

Avg p: 0.8644156797998074
Min p: 0.0007500865031033754
Max p: 0.9999513626098633
Total agents: 4736
Agents p <= 0.5: 530
Agents p <= 0.9: 1173
Agents p <= 0.998: 4127
Agents p <= 0.999: 4467


Prune the model

In [10]:
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils import count_flops_params

In [11]:
# Masks for pruning channels: freeze each agent to a binary decision at p = 0.5,
# then merge the per-agent masks into a single dict for NNI's ModelSpeedup
# (presumably shaped {module_name: {'weight': mask, ...}}).
masks = {}
for pruning_agent in agents:
    pruning_agent.eval(prob_threshold=0.5, threshold_type='BINARY')
    masks.update(pruning_agent.get_masks())

  self.masks[module_name]['weight'][torch.nonzero(a)] = 1


In [12]:
# Build a physically smaller network. Start from a clean copy of the original
# VGG16, load the jointly trained weights, then let NNI's ModelSpeedup strip
# the masked channels in place. The dummy input is used to trace the model.
model_pruned = copy.deepcopy(model_orig)
model_pruned.load_state_dict(model.state_dict())
dummy_input = torch.ones(1, 3, 32, 32, device=args.device)

ModelSpeedup(model_pruned, dummy_input, masks).speedup_model();

[2023-03-12 23:14:34] [32mstart to speedup the model[0m
[2023-03-12 23:14:34] [32minfer module masks...[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.0[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.1[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.2[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.3[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.4[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.5[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.6[0m
[2023-03-12 23:14:34] [32mUpdate mask for features.7[0m
[2023-03-12 23:14:35] [32mUpdate mask for features.8[0m
[2023-03-12 23:14:35] [32mUpdate mask for features.9[0m
[2023-03-12 23:14:35] [32mUpdate mask for features.10[0m
[2023-03-12 23:14:35] [32mUpdate mask for features.11[0m
[2023-03-12 23:14:35] [32mUpdate mask for features.12[0m
[2023-03-12 23:14:35] [32mUpdate mask for features.13[0m
[2023-03-12 23:14:35] [32mUpdate mask for features.14[0m
[2023-03-12 23

In [13]:
# FLOPs / parameter count of the physically pruned model for one 32x32 RGB input.
# mode='full' prints the per-layer table shown in the output. The returned
# triple is presumably (total FLOPs, total params, per-layer results).
flops1, params1, results1 = count_flops_params(model_pruned, (1, 3, 32, 32), mode='full')

+-------+--------------+-------------+------------------+------------------+------------------+----------+---------+
| Index | Name         |     Type    |   Weight Shape   |    Input Size    |   Output Size    |  FLOPs   | #Params |
+-------+--------------+-------------+------------------+------------------+------------------+----------+---------+
|   0   | features.0   |    Conv2d   |  (48, 3, 3, 3)   |  (1, 3, 32, 32)  | (1, 48, 32, 32)  | 1376256  |   1344  |
|   1   | features.1   | BatchNorm2d |      (48,)       | (1, 48, 32, 32)  | (1, 48, 32, 32)  |  98304   |    96   |
|   2   | features.3   |    Conv2d   |  (57, 48, 3, 3)  | (1, 48, 32, 32)  | (1, 57, 32, 32)  | 25273344 |  24681  |
|   3   | features.4   | BatchNorm2d |      (57,)       | (1, 57, 32, 32)  | (1, 57, 32, 32)  |  116736  |   114   |
|   4   | features.7   |    Conv2d   | (109, 57, 3, 3)  | (1, 57, 16, 16)  | (1, 109, 16, 16) | 14342656 |  56026  |
|   5   | features.8   | BatchNorm2d |      (109,)      | (1, 10

In [14]:
# Same measurement for the unpruned baseline, for comparison with flops1/params1.
flops2, params2, results2 = count_flops_params(model_orig, (1, 3, 32, 32), mode='full')

+-------+--------------+-------------+------------------+------------------+------------------+----------+---------+
| Index | Name         |     Type    |   Weight Shape   |    Input Size    |   Output Size    |  FLOPs   | #Params |
+-------+--------------+-------------+------------------+------------------+------------------+----------+---------+
|   0   | features.0   |    Conv2d   |  (64, 3, 3, 3)   |  (1, 3, 32, 32)  | (1, 64, 32, 32)  | 1835008  |   1792  |
|   1   | features.1   | BatchNorm2d |      (64,)       | (1, 64, 32, 32)  | (1, 64, 32, 32)  |  131072  |   128   |
|   2   | features.3   |    Conv2d   |  (64, 64, 3, 3)  | (1, 64, 32, 32)  | (1, 64, 32, 32)  | 37814272 |  36928  |
|   3   | features.4   | BatchNorm2d |      (64,)       | (1, 64, 32, 32)  | (1, 64, 32, 32)  |  131072  |   128   |
|   4   | features.7   |    Conv2d   | (128, 64, 3, 3)  | (1, 64, 16, 16)  | (1, 128, 16, 16) | 18907136 |  73856  |
|   5   | features.8   | BatchNorm2d |      (128,)      | (1, 12

Verify accuracy of the pruned model

In [15]:
# Compare the physically pruned network with the masked (agent-gated) model.
# Matching accuracies indicate ModelSpeedup preserved the masked computation.
val_acc_pruned = validate(model_pruned, val_loader, args.device)
val_acc = validate(model, val_loader, args.device)
print(f'Pruned acc: {val_acc_pruned}')
print(f'acc: {val_acc}')

Val: 100%|██████████| 40/40 [00:02<00:00, 19.09it/s, acc=0.915]
Val: 100%|██████████| 40/40 [00:02<00:00, 16.69it/s, acc=0.915]

Pruned acc: 0.9148
acc: 0.9148





In [16]:
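# Results of runs with different penalty (lambda) values.
# NOTE: the penalties below are written with a negative sign. The run in this
# notebook uses Args.penalty = 50 and matches the accuracy of the "-50" entry (0.9148).
# PR = pruning ratio in percent: 100 * (1 - pruned / original).

# penalty = -40
# FLOPs total: 214681294
# #Params total: 10671760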

# penalty = -4
# FLOPs total: 43417836
# #Params total: 2016099
# Acc: 0.8941
# PR FLOPs: 86.185633097
# PR Params: 86.552119384

# penalty = -50
# FLOPs total: 242051914
# #Params total: 11911537
# Acc: 0.9148
# PR FLOPs: 22.9857069
# PR Params: 20.5470924

# original (unpruned baseline)
# FLOPs total: 314294794
# #Params total: 14991946
# Acc: 0.9194