In [1]:
#setup estential
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import os
from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
from prune.utils import analyse_model, finetune
from utils import *

import json 
from config import cfg

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def get_pack():
    set_seeds()
    pack = recover_pack()

    # model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg['base']['cuda'] else 'cuda')
    # pack.net.module.load_state_dict(model_dict)

    GBNs = GatedBatchNorm2d.transform(pack.net)
    for gbn in GBNs:
        gbn.extract_from_bn()

    pack.optimizer = optim.SGD(
        pack.net.parameters() ,
        lr=2e-3,
        momentum=cfg['train']['momentum'],
        weight_decay=cfg['train']['weight_decay'],
        nesterov=cfg['train']['nesterov']
    )

    return pack, GBNs

def clone_model(net):
    model = get_model()
    gbns = GatedBatchNorm2d.transform(model.module)
    model.load_state_dict(net.state_dict())
    return model, gbns

def eval_prune(pack):   # for evaluate the pruned model
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned.module)
    cloned.module.classifier = FinalLinearObserver(cloned.module.classifier)
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    Meltable.melt_all(cloned_pack.net)
    flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32).cuda())
    del cloned
    del cloned_pack
    
    return flops, params

def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg["gbn"]["sparse_lambda"], flops_eta = cfg["gbn"]["flops_eta"], minium_filter = 3)
    prune_agent.tock(lr_min=cfg["gbn"]["lr_min"], lr_max=cfg["gbn"]["lr_max"], tock_epoch=cfg["gbn"]['tock_epoch'])
    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        num_to_prune = int(left_filter * cfg["gbn"]["p"])
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg["gbn"]['lr_min'])
        flops, params = eval_prune(pack)
        info.update({
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' % 
            (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))
        
        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0:
            print('Tocking:')
            prune_agent.tock(lr_min=cfg["gbn"]['lr_min'], lr_max=cfg["gbn"]['lr_max'], tock_epoch=cfg['gbn']['tock_epoch'])

        flops_ratio = flops/BASE_FLOPS * 100
        for point in [i for i in list(flops_save_points)]:
            if flops_ratio <= point:
                torch.save(pack.net.module.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                flops_save_points.remove(point)

        if len(flops_save_points) == 0:
            break


In [7]:
pack, GBNs = get_pack()  #initial before pruning
cloned, _ = clone_model(pack.net)
BASE_FLOPS, BASE_PARAM = analyse_model(cloned.module, torch.randn(1, 3, 32, 32).cpu())
print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
print('%.3f M' % (BASE_PARAM / 1e6))
del cloned

prune(pack, GBNs, BASE_FLOPS, BASE_PARAM)

==> Preparing Cifar10 data..
Files already downloaded and verified
Files already downloaded and verified




odict_keys(['module'])
odict_keys(['features', 'classifier'])
odict_keys(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44'])
odict_keys(['features', 'classifier'])
odict_keys(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44'])
314.308 MFLOPS
14.728 M


100%|██████████| 391/391 [12:26<00:00,  1.91s/it]


Tock - 0,	 Test Loss: 1.3953,	 Test Acc: 53.07, Final LR: 0.00280


100%|██████████| 391/391 [12:21<00:00,  1.90s/it]


Tock - 1,	 Test Loss: 0.9007,	 Test Acc: 69.41, Final LR: 0.00460


100%|██████████| 391/391 [11:13<00:00,  1.72s/it]


Tock - 2,	 Test Loss: 0.7410,	 Test Acc: 75.10, Final LR: 0.00640


100%|██████████| 391/391 [13:27<00:00,  2.07s/it]


Tock - 3,	 Test Loss: 0.7679,	 Test Acc: 74.76, Final LR: 0.00820


100%|██████████| 391/391 [11:37<00:00,  1.78s/it]


Tock - 4,	 Test Loss: 0.6178,	 Test Acc: 78.86, Final LR: 0.01000


100%|██████████| 391/391 [11:40<00:00,  1.79s/it]


Tock - 5,	 Test Loss: 0.5262,	 Test Acc: 81.71, Final LR: 0.00820


100%|██████████| 391/391 [11:48<00:00,  1.81s/it]


Tock - 6,	 Test Loss: 0.4688,	 Test Acc: 84.00, Final LR: 0.00640


100%|██████████| 391/391 [12:02<00:00,  1.85s/it]


Tock - 7,	 Test Loss: 0.3957,	 Test Acc: 86.74, Final LR: 0.00460


100%|██████████| 391/391 [12:06<00:00,  1.86s/it]


Tock - 8,	 Test Loss: 0.3514,	 Test Acc: 88.29, Final LR: 0.00280


100%|██████████| 391/391 [12:09<00:00,  1.87s/it]


Tock - 9,	 Test Loss: 0.3342,	 Test Acc: 88.63, Final LR: 0.00100


100%|██████████| 391/391 [08:54<00:00,  1.37s/it]


odict_keys(['features', 'classifier'])
odict_keys(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44'])


AssertionError: Torch not compiled with CUDA enabled

In [None]:
test_vgg

DataParallel(
  (module): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ReLU(inplace=True)
      (13): MaxPool2d(kernel_size=2, stride=2, padding=0

In [11]:
pack.net.module

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): GatedBatchNorm2d(
      64 -> 64 | ID: 6fc0a556-df90-11ee-b60b-fc34974a25ab
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): GatedBatchNorm2d(
      64 -> 64 | ID: 6fc0cc5c-df90-11ee-98a7-fc34974a25ab
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): GatedBatchNorm2d(
      128 -> 128 | ID: 6fc0cc5d-df90-11ee-b082-fc34974a25ab
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 12

In [None]:
GBNs

In [2]:
import torch
from torchviz import make_dot
from pytorch_model_summary import summary
incep_v3_model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
incep_v3_model.eval()

Using cache found in C:\Users\User/.cache\torch\hub\pytorch_vision_v0.10.0


Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [4]:
yhat = incep_v3_model(torch.rand((1,3,299,299)))

In [None]:
# from torchviz import make_dot
# 
# make_dot(yhat, params=dict(list(incep_v3_model.named_parameters()))).render("rnn_torchviz", format="png")

In [9]:
pytorch_total_params = sum(p.numel() for p in incep_v3_model.parameters())
print(pytorch_total_params)

27161264


In [10]:
# pip uninstall graphviz


SyntaxError: invalid syntax (189567022.py, line 2)

In [8]:
pip install graphviz

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
# test_01 = get_model()

GBNs = GatedBatchNorm2d.transform(incep_v3_model)

odict_keys(['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'maxpool1', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'maxpool2', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'AuxLogits', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'avgpool', 'dropout', 'fc'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['branch1x1', 'branch5x5_1', 'branch5x5_2', 'branch3x3dbl_1', 'branch3x3dbl_2', 'branch3x3dbl_3', 'branch_pool'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['branch1x1', 'branch5x5_1', 'branch5x5_2', 'branch3x3dbl_1', 'branch3x3dbl_2', 'branch3x3dbl_3', 'branch_pool'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn'])
odict_keys(['conv', 'bn

In [7]:
count = 0
for i in incep_v3_model.modules():
    print('==============================')
    print(i)
    print('==============================')
    count+=1
    print(count)

Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [8]:
GBNs = GatedBatchNorm2d.transform(incep_v3_model)

22
2
2
2
2
2
7
2
2
2
2
2
2
2
7
2
2
2
2
2
2
2
7
2
2
2
2
2
2
2
4
2
2
2
2
10
2
2
2
2
2
2
2
2
2
2
10
2
2
2
2
2
2
2
2
2
2
10
2
2
2
2
2
2
2
2
2
2
10
2
2
2
2
2
2
2
2
2
2
3
2
2
6
2
2
2
2
2
2
9
2
2
2
2
2
2
2
2
2
9
2
2
2
2
2
2
2
2
2


In [10]:
incep_v3_model.modules

<bound method Module.modules of Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): GatedBatchNorm2d(
      32 -> 32 | ID: 6f40accb-df82-11ee-84b7-fc34974a25ab
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): GatedBatchNorm2d(
      32 -> 32 | ID: 6f424e06-df82-11ee-adc6-fc34974a25ab
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): GatedBatchNorm2d(
      64 -> 64 | ID: 6f424e07-df82-11ee-beb3-fc34974a25ab
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0,