In [1]:
# Environment setup that must run before anything else.
# Every notebook in this repository should start with this cell.
import os
# Select the GPU before torch is imported (torch is imported in a later cell).
# PCI_BUS_ID makes CUDA enumerate devices in PCI bus order, so index "0" below
# refers to a stable physical device.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import sys

# Locate the repository root by walking the components of the current working
# directory. os.sep is used instead of a hard-coded '/' so the lookup also works
# on platforms with a different path separator.
_r = os.getcwd().split(os.sep)
if 'gate-decorator-pruning' not in _r:
    # Same exception type the bare list.index() call would raise, but with a
    # message that tells the user what to do about it.
    raise ValueError(
        "Cannot find a 'gate-decorator-pruning' directory in the current working "
        "directory path %r; start the notebook from inside the repository." % os.getcwd()
    )
_p = os.sep.join(_r[:_r.index('gate-decorator-pruning')+1])
print('Change dir from %s to %s' % (os.getcwd(), _p))
os.chdir(_p)
# Re-running this cell must not keep growing sys.path with duplicate entries.
if _p not in sys.path:
    sys.path.append(_p)

# `config` is a project module; it only becomes importable after the sys.path
# tweak above.
# NOTE(review): the recorded output shows an argparse "unrecognized arguments:
# -f ...kernel-*.json" message while the config is parsed — presumably the
# project's CLI parser sees the Jupyter kernel's argv. Confirm it is harmless.
from config import parse_from_dict
parse_from_dict({
    "base": {
        "task_name": "resnet56_finetune",
        "cuda": True,
        "seed": 0,
        # Left empty: the checkpoint is loaded by hand in a later cell.
        "checkpoint_path": "",
        "epoch": 0,
        # NOTE(review): later cells access `pack.net.module` and the printed
        # model is wrapped in DataParallel — presumably this flag enables that
        # wrapping, even though only one GPU is made visible above. Confirm
        # before switching it off.
        "multi_gpus": True,
        "fp16": False
    },
    "model": {
        "name": "cifar.resnet56",
        "num_class": 10,
        "pretrained": False
    },
    "train": {
        "trainer": "normal",
        "max_epoch": 160,
        "optim": "sgd",
        # presumably [epoch boundary, learning rate] pairs — TODO confirm in main.py
        "steplr": [
            [80, 0.1],
            [120, 0.01],
            [160, 0.001]
        ],
        # momentum / weight_decay / nesterov are read when the SGD optimizer is
        # rebuilt in a later cell.
        "weight_decay": 5e-4,
        "momentum": 0.9,
        "nesterov": False
    },
    "data": {
        "type": "cifar10",
        "shuffle": True,
        "batch_size": 64,
        "test_batch_size": 128,
        "num_workers": 8
    },
    "loss": {
        "criterion": "softmax"
    },
    # Passed to finetune() in a later cell as T, lr_min and lr_max.
    "gbn": {
        "finetune_epoch": 40,
        "lr_min": 1e-3,
        "lr_max": 1e-2
    }
})
# Import the populated config object only after parse_from_dict() has run.
from config import cfg

Change dir from /root/code/gate-decorator-pruning/run/resnet-56 to /root/code/gate-decorator-pruning
Parsing config file...
** Assert in demo mode. **


usage: ipykernel_launcher.py [-h] [--config CONFIG]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-ad976665-63c8-4e67-af41-4200de9de48d.json


In [2]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim

from logger import logger
from main import set_seeds, recover_pack, adjust_learning_rate, _step_lr, _sgdr
from models import get_model
from utils import dotdict

from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
from prune.utils import analyse_model, finetune

In [3]:
# Seed first, then build the training pack, so that whatever recover_pack()
# initialises runs after seeding.
# NOTE(review): set_seeds() takes no arguments here — presumably it reads
# cfg.base.seed; confirm in main.py.
set_seeds()
# `pack` is the object every later cell works on: it exposes .net, .trainer and
# .optimizer. The recorded output shows this call also prepares the CIFAR-10 data.
pack = recover_pack()

==> Preparing Cifar10 data..
Files already downloaded and verified
Files already downloaded and verified


  init.kaiming_normal(m.weight)


In [4]:
# Apply the GatedBatchNorm2d transform to the whole network and keep the
# returned collection of gate modules.
# NOTE(review): presumably transform() replaces each BatchNorm2d in place and
# returns the replacements — confirm in prune/universal.py.
GBNs = GatedBatchNorm2d.transform(pack.net)
for gbn in GBNs:
    # presumably folds the existing BN parameters into the gate — TODO confirm
    gbn.extract_from_bn()

-------

#### 70% FLOPs reduced

In [5]:
# Checkpoint path is relative to the repository root, which the setup cell
# made the working directory.
_ckpt_path = './ckps/resnet56_cifar10_70percent_flops_reduced.ckp'
# Map the stored tensors onto the GPU only when CUDA is enabled in the config.
_ckpt_device = 'cuda' if cfg.base.cuda else 'cpu'
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model_dict = torch.load(_ckpt_path, map_location=_ckpt_device)
# Load into the wrapped model (`.module`), not the outer wrapper held in pack.net.
pack.net.module.load_state_dict(model_dict)

In [6]:
# Attach observers to the model, run one observation pass, then melt.
# The four statements are order-dependent: observers first, then observe, then melt.
# NOTE(review): presumably the observers record which channels are still in use
# so that melt_all() can rebuild compact layers — confirm in prune/universal.py.
# The model printed further down does show plain Conv2d/BatchNorm2d layers with
# reduced channel counts.
# NOTE(review): this cell looks non-idempotent (re-running would wrap `linear`
# again) — re-run from a fresh kernel rather than re-executing it.
_ = Conv2dObserver.transform(pack.net.module)
pack.net.module.linear = FinalLinearObserver(pack.net.module.linear)
# 0.001 is an undocumented magic number here — TODO confirm its meaning in Meltable.observe.
Meltable.observe(pack, 0.001)
Meltable.melt_all(pack.net)

In [7]:
# Rebuild the optimizer over the network's current parameters; this cell runs
# after the melt step, so it picks up whatever parameters the model has by then.
# NOTE(review): lr=1 looks like a placeholder base rate — the fine-tuning log
# reports learning rates between lr_min and lr_max, so presumably finetune()
# overrides it. Confirm in prune/utils.py.
_sgd_options = dict(
    lr=1,
    momentum=cfg.train.momentum,
    weight_decay=cfg.train.weight_decay,
    nesterov=cfg.train.nesterov,
)
pack.optimizer = optim.SGD(pack.net.parameters(), **_sgd_options)

In [8]:
# Baseline evaluation before fine-tuning; the returned metrics dict is shown as
# the cell output (the recorded run reports acc@1 of about 91.79).
pack.trainer.test(pack)

{'test_loss': 0.2779307614398908, 'acc@1': 91.79193037974683}

In [9]:
# Fine-tune for cfg.gbn.finetune_epoch epochs (40 in this config).
# In the recorded log the LR climbs from about lr_min to lr_max over the first
# 20 epochs and falls back to about lr_min over the last 20.
# The return value is bound to `_` so the cell shows only the per-epoch logs.
_ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

100%|█████████████████████████████████████████████████████████████| 782/782 [00:41<00:00, 18.75it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08736122048953, 'epoch_time': 41.72056436538696, 'test_loss': 0.2590870365877695, 'acc@1': 92.41495253164557, 'LR': 0.0014494245524296675}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:40<00:00, 19.34it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08371398926181409, 'epoch_time': 40.44017148017883, 'test_loss': 0.27484140394231943, 'acc@1': 92.23694620253164, 'LR': 0.0018994245524296676}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.08it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07866653482265332, 'epoch_time': 38.942352056503296, 'test_loss': 0.26378578221118903, 'acc@1': 92.43473101265823, 'LR': 0.0023494245524296677}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.16it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08044928345648224, 'epoch_time': 38.80302429199219, 'test_loss': 0.27720553116707863, 'acc@1': 91.97982594936708, 'LR': 0.0027994245524296676}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.78it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08015105671361279, 'epoch_time': 39.533297061920166, 'test_loss': 0.28315314490206633, 'acc@1': 92.04905063291139, 'LR': 0.003249424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.12it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0777250489367701, 'epoch_time': 38.87986135482788, 'test_loss': 0.31421341539560993, 'acc@1': 91.2381329113924, 'LR': 0.0036994245524296673}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:41<00:00, 19.04it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07942619433869487, 'epoch_time': 41.07600688934326, 'test_loss': 0.29303259298771245, 'acc@1': 91.9501582278481, 'LR': 0.004149424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.85it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07626302736094388, 'epoch_time': 39.41067028045654, 'test_loss': 0.3127180535959292, 'acc@1': 91.25791139240506, 'LR': 0.004599424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:40<00:00, 19.50it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07806578878780156, 'epoch_time': 40.10741591453552, 'test_loss': 0.3028579484435576, 'acc@1': 91.74248417721519, 'LR': 0.005049424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.89it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07613613003450434, 'epoch_time': 39.33325266838074, 'test_loss': 0.2798599084909958, 'acc@1': 92.00949367088607, 'LR': 0.005499424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.11it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08179526460711914, 'epoch_time': 43.185001611709595, 'test_loss': 0.3468982267983352, 'acc@1': 90.59533227848101, 'LR': 0.005949424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.25it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08035189525969803, 'epoch_time': 38.62661695480347, 'test_loss': 0.34918082486602325, 'acc@1': 90.67444620253164, 'LR': 0.006399424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.94it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07886561876177178, 'epoch_time': 39.22112989425659, 'test_loss': 0.28857195339625397, 'acc@1': 92.0193829113924, 'LR': 0.006849424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.83it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07900707956279635, 'epoch_time': 39.44291973114014, 'test_loss': 0.33388656811623635, 'acc@1': 90.97112341772151, 'LR': 0.007299424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:41<00:00, 18.83it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08019499626496564, 'epoch_time': 41.54308867454529, 'test_loss': 0.6180328722996048, 'acc@1': 85.38370253164557, 'LR': 0.0077494245524296675}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.96it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08141366296621692, 'epoch_time': 39.18882393836975, 'test_loss': 0.3626729026436806, 'acc@1': 90.32832278481013, 'LR': 0.008199424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:40<00:00, 19.51it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0805086620923732, 'epoch_time': 40.09125351905823, 'test_loss': 0.31225008045948005, 'acc@1': 91.49525316455696, 'LR': 0.008649424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.51it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0827122603964699, 'epoch_time': 38.12613868713379, 'test_loss': 0.4951764718641209, 'acc@1': 87.10443037974683, 'LR': 0.009099424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.42it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08133491148214664, 'epoch_time': 38.308671712875366, 'test_loss': 0.3105162973456745, 'acc@1': 91.18868670886076, 'LR': 0.009549424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:40<00:00, 19.38it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08178044203907023, 'epoch_time': 40.36092281341553, 'test_loss': 0.3464855530971213, 'acc@1': 90.76344936708861, 'LR': 0.009999424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.69it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08133753058036118, 'epoch_time': 39.73035407066345, 'test_loss': 0.2971889487927473, 'acc@1': 91.82159810126582, 'LR': 0.009550575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.14it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07861012025543339, 'epoch_time': 38.830283403396606, 'test_loss': 0.37870327260675307, 'acc@1': 90.19976265822785, 'LR': 0.009100575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.85it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0759652434702953, 'epoch_time': 39.40379762649536, 'test_loss': 0.400169717077213, 'acc@1': 89.65585443037975, 'LR': 0.008650575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.22it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0702243430201736, 'epoch_time': 38.689780473709106, 'test_loss': 0.30230171638953535, 'acc@1': 91.8117088607595, 'LR': 0.008200575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 20.05it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.06342756760585339, 'epoch_time': 39.01131224632263, 'test_loss': 0.3268708971482289, 'acc@1': 91.65348101265823, 'LR': 0.007750575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.10it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.05761700121642988, 'epoch_time': 38.91428470611572, 'test_loss': 0.3085207801453675, 'acc@1': 91.87104430379746, 'LR': 0.0073005754475703325}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.25it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.05695188578094363, 'epoch_time': 38.62928342819214, 'test_loss': 0.31053954447749293, 'acc@1': 91.91060126582279, 'LR': 0.006850575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.49it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.055821682352696536, 'epoch_time': 42.30193614959717, 'test_loss': 0.2985222577294217, 'acc@1': 92.13805379746836, 'LR': 0.006400575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.34it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.05252815320935396, 'epoch_time': 38.45939493179321, 'test_loss': 0.2973584127199801, 'acc@1': 92.06882911392405, 'LR': 0.005950575447570334}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.57it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.047342439946692315, 'epoch_time': 38.02919960021973, 'test_loss': 0.3015447472092472, 'acc@1': 92.17761075949367, 'LR': 0.005500575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.89it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.04624745885238928, 'epoch_time': 39.32641816139221, 'test_loss': 0.30456703231681753, 'acc@1': 92.37539556962025, 'LR': 0.005050575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.45it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.04509912198290343, 'epoch_time': 38.251078367233276, 'test_loss': 0.2981601983686037, 'acc@1': 92.51384493670886, 'LR': 0.004600575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.61it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0413818648442283, 'epoch_time': 39.89181065559387, 'test_loss': 0.28899698374392113, 'acc@1': 92.69185126582279, 'LR': 0.0041505754475703325}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:44<00:00, 17.73it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.040663603774231415, 'epoch_time': 44.10174918174744, 'test_loss': 0.2909717359874822, 'acc@1': 92.7314082278481, 'LR': 0.003700575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:43<00:00, 18.06it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.03695526665738781, 'epoch_time': 43.31573486328125, 'test_loss': 0.3589607465870773, 'acc@1': 90.88212025316456, 'LR': 0.0032505754475703327}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:42<00:00, 18.57it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.037520130981912696, 'epoch_time': 42.11086893081665, 'test_loss': 0.2894347329871564, 'acc@1': 92.75118670886076, 'LR': 0.002800575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 20.03it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.033381814075171795, 'epoch_time': 39.04433298110962, 'test_loss': 0.28974121846730194, 'acc@1': 92.58306962025317, 'LR': 0.002350575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.70it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.03265679780574863, 'epoch_time': 39.70557951927185, 'test_loss': 0.2891138912756232, 'acc@1': 92.76107594936708, 'LR': 0.0019005754475703326}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.25it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.03141301143866823, 'epoch_time': 38.631301403045654, 'test_loss': 0.2897222333505184, 'acc@1': 92.7314082278481, 'LR': 0.001450575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.10it/s]


{'train_loss': 0.02813051395651782, 'epoch_time': 38.917781352996826, 'test_loss': 0.28455836571092846, 'acc@1': 92.85007911392405, 'LR': 0.0010005754475703327}


In [10]:
# Display the module tree of the fine-tuned network to inspect per-layer channel counts.
pack.net

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(1, 16, kernel_size=(