In [1]:
# Setup that has to run before anything else; every notebook should include this cell.
#
# It does three things, in this order:
#   1. selects the GPU through environment variables,
#   2. changes the working directory to the repository root and makes that
#      root importable,
#   3. parses the experiment configuration and exposes it as `cfg`.
import os

# GPU selection is done through the environment, ahead of the `import torch`
# that happens in a later cell.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys

# Name of the repository's top-level directory; the notebook is expected to be
# started from somewhere inside it (e.g. <repo>/run/resnet-56).
_REPO_DIR_NAME = 'gate-decorator-pruning'

# Split on the platform separator instead of a hard-coded '/' so the lookup
# also works where paths are not '/'-separated.
_r = os.getcwd().split(os.sep)
if _REPO_DIR_NAME not in _r:
    # Same exception type that list.index() would raise, but with a message
    # that says what is actually wrong.
    raise ValueError(
        "Current directory %r is not inside a %r checkout; "
        "start the notebook from within the repository."
        % (os.getcwd(), _REPO_DIR_NAME)
    )
_p = os.sep.join(_r[:_r.index(_REPO_DIR_NAME) + 1])
print('Change dir from %s to %s' % (os.getcwd(), _p))
os.chdir(_p)
# Re-running this cell must not keep stacking the same entry onto sys.path.
if _p not in sys.path:
    sys.path.append(_p)

# `config` lives in the repository root, so it is only importable after the
# sys.path update above.
from config import parse_from_dict
parse_from_dict({
    "base": {
        "task_name": "resnet56_finetune",
        "cuda": True,
        "seed": 1,
        "checkpoint_path": "",
        "epoch": 0,
        "multi_gpus": True,
        "fp16": False
    },
    "model": {
        "name": "cifar.resnet56",
        "num_class": 10,
        "pretrained": False
    },
    "train": {
        "trainer": "normal",
        "max_epoch": 160,
        "optim": "sgd",
        "steplr": [
            [80, 0.1],
            [120, 0.01],
            [160, 0.001]
        ],
        "weight_decay": 5e-4,
        "momentum": 0.9,
        "nesterov": False
    },
    "data": {
        "type": "cifar10",
        "shuffle": True,
        "batch_size": 64,
        "test_batch_size": 128,
        "num_workers": 8
    },
    "loss": {
        "criterion": "softmax"
    },
    # Read by the finetune() calls further down in this notebook
    # (finetune_epoch -> T, lr_min / lr_max -> learning-rate bounds).
    "gbn": {
        "finetune_epoch": 40,
        "lr_min": 1e-3,
        "lr_max": 1e-2
    }
})
# Imported only after parse_from_dict() has run, matching the original order.
from config import cfg

Change dir from /root/code/gate-decorator-pruning/run/resnet-56 to /root/code/gate-decorator-pruning
Parsing config file...
** Assert in demo mode. **


usage: ipykernel_launcher.py [-h] [--config CONFIG]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-5c91900f-5696-47bc-a457-c038486324d5.json


In [2]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim

from logger import logger
from main import set_seeds, recover_pack, adjust_learning_rate, _step_lr, _sgdr
from models import get_model
from utils import dotdict

from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
from prune.utils import analyse_model, finetune

In [3]:
# Seed first, then build the experiment bundle `pack`; later cells use
# pack.net, pack.optimizer and pack.trainer.
# Both helpers are imported from `main`. The captured output of this cell shows
# the CIFAR-10 data being prepared during this step.
# NOTE(review): presumably set_seeds() uses cfg.base.seed and recover_pack()
# builds the model/data described by cfg -- confirm in main.py.
set_seeds()
pack = recover_pack()

==> Preparing Cifar10 data..
Files already downloaded and verified
Files already downloaded and verified


  init.kaiming_normal(m.weight)


In [4]:
# Apply the GatedBatchNorm2d transform to the network and keep the returned
# collection, then call extract_from_bn() on every element of it.
# NOTE(review): presumably transform() swaps the network's BatchNorm2d layers
# for gated wrappers and extract_from_bn() copies the BN parameters into
# them -- confirm in prune/universal.py.
GBNs = GatedBatchNorm2d.transform(pack.net)
for gbn in GBNs:
    gbn.extract_from_bn()

-------

#### 70% FLOPs reduced

In [5]:
# Read the checkpoint from ./ckps and load it into the wrapped model
# (pack.net.module). Tensors are mapped to the GPU when cfg.base.cuda is set,
# otherwise to the CPU.
# NOTE(review): going by its file name this is the ResNet-56 / CIFAR-10
# checkpoint with 70% of the FLOPs removed -- confirm its provenance.
ckpt_device = 'cuda' if cfg.base.cuda else 'cpu'
model_dict = torch.load(
    './ckps/resnet56_cifar10_70percent_flops_reduced.ckp',
    map_location=ckpt_device
)
pack.net.module.load_state_dict(model_dict)

In [6]:
# Prepare the wrapped model for melting, in four order-dependent steps:
#   1. apply the Conv2dObserver transform to the wrapped model (result discarded),
#   2. wrap the final linear layer in a FinalLinearObserver,
#   3. run Meltable.observe on the pack,
#   4. melt every Meltable module in the network.
# The network printed at the end of this notebook consists of plain Conv2d /
# BatchNorm2d layers with reduced channel counts.
# NOTE(review): the meaning of the 0.001 passed to Meltable.observe is not
# visible here (presumably a learning rate) -- confirm in prune/universal.py.
_ = Conv2dObserver.transform(pack.net.module)
pack.net.module.linear = FinalLinearObserver(pack.net.module.linear)
Meltable.observe(pack, 0.001)
Meltable.melt_all(pack.net)

In [7]:
# Replace the pack's optimizer with a fresh SGD instance over the network's
# current parameters (this cell runs after the melt step above).
# Momentum, weight decay and the Nesterov flag come from cfg.train.
# NOTE(review): lr=1 looks like a base value rather than the real step size;
# the learning rates logged by finetune() below lie between 1e-3 and 1e-2 --
# confirm how finetune() derives them.
sgd_settings = dict(
    lr=1,
    momentum=cfg.train.momentum,
    weight_decay=cfg.train.weight_decay,
    nesterov=cfg.train.nesterov,
)
pack.optimizer = optim.SGD(pack.net.parameters(), **sgd_settings)

In [8]:
# Evaluate the network before finetuning; the displayed dict (test_loss,
# acc@1) is the baseline for the finetuning runs in the following cells.
pack.trainer.test(pack)

{'test_loss': 0.2779307614398908, 'acc@1': 91.79193037974683}

In [9]:
# Main finetuning run: T = cfg.gbn.finetune_epoch (40) with learning-rate
# bounds cfg.gbn.lr_min (1e-3) and cfg.gbn.lr_max (1e-2).
# According to the logged 'LR' values, the rate climbs from about 1e-3 to
# about 1e-2 during the first half of the run and falls back to about 1e-3
# during the second half.
# The return value is bound to `_` so the cell does not display it.
_ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.48it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0859906351446267, 'epoch_time': 38.19228506088257, 'test_loss': 0.26257172123163564, 'acc@1': 92.51384493670886, 'LR': 0.0014494245524296675}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.48it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08416774630775233, 'epoch_time': 38.19319152832031, 'test_loss': 0.2739955025571811, 'acc@1': 92.16772151898734, 'LR': 0.0018994245524296676}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.77it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08062715783639027, 'epoch_time': 37.6591157913208, 'test_loss': 0.2670000918869731, 'acc@1': 92.32594936708861, 'LR': 0.0023494245524296677}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.67it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07984122698721678, 'epoch_time': 37.83485126495361, 'test_loss': 0.26516500247430197, 'acc@1': 92.41495253164557, 'LR': 0.0027994245524296676}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.26it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07775474074856399, 'epoch_time': 38.60597896575928, 'test_loss': 0.29833116014546984, 'acc@1': 91.73259493670886, 'LR': 0.003249424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.33it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07926799653722044, 'epoch_time': 38.47671055793762, 'test_loss': 0.3360251865432232, 'acc@1': 90.44699367088607, 'LR': 0.0036994245524296673}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.49it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07778495687353032, 'epoch_time': 38.18007206916809, 'test_loss': 0.26586916071327427, 'acc@1': 92.36550632911393, 'LR': 0.004149424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.71it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07778559496049839, 'epoch_time': 37.775842905044556, 'test_loss': 0.29883205683170994, 'acc@1': 91.51503164556962, 'LR': 0.004599424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.57it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07901717988712251, 'epoch_time': 38.015756368637085, 'test_loss': 0.3086648947453197, 'acc@1': 91.36669303797468, 'LR': 0.005049424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.47it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07954291995529018, 'epoch_time': 38.21129751205444, 'test_loss': 0.288397289246698, 'acc@1': 91.94026898734177, 'LR': 0.005499424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.29it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07939148757158948, 'epoch_time': 38.53926396369934, 'test_loss': 0.3268880275419996, 'acc@1': 91.18868670886076, 'LR': 0.005949424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.22it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07956523685465994, 'epoch_time': 38.67717361450195, 'test_loss': 0.5103803397733954, 'acc@1': 86.28362341772151, 'LR': 0.006399424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.53it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07924813777684708, 'epoch_time': 38.10012769699097, 'test_loss': 0.29654051816161675, 'acc@1': 91.9501582278481, 'LR': 0.006849424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.10it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0800099972768894, 'epoch_time': 38.92156362533569, 'test_loss': 0.30265500896339176, 'acc@1': 91.86115506329114, 'LR': 0.007299424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.37it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08187006331999283, 'epoch_time': 38.397090911865234, 'test_loss': 0.32993217642548717, 'acc@1': 91.06012658227849, 'LR': 0.0077494245524296675}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.75it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07584094027023944, 'epoch_time': 37.685349225997925, 'test_loss': 0.5898690853692308, 'acc@1': 86.13528481012658, 'LR': 0.008199424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.76it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0824220284981572, 'epoch_time': 37.674084186553955, 'test_loss': 0.3141148318595524, 'acc@1': 91.62381329113924, 'LR': 0.008649424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.24it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08370529632548541, 'epoch_time': 38.64585471153259, 'test_loss': 0.6009437015539483, 'acc@1': 85.21558544303798, 'LR': 0.009099424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.41it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08132910733218388, 'epoch_time': 38.31469988822937, 'test_loss': 0.3225360195848006, 'acc@1': 91.15901898734177, 'LR': 0.009549424552429667}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.21it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08203476231516627, 'epoch_time': 38.69735908508301, 'test_loss': 0.38888421239732185, 'acc@1': 89.71518987341773, 'LR': 0.009999424552429668}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.66it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.08073182081055763, 'epoch_time': 37.864201068878174, 'test_loss': 0.43324236269993116, 'acc@1': 88.92405063291139, 'LR': 0.009550575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.63it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07775000311300882, 'epoch_time': 37.90449666976929, 'test_loss': 0.3563081159999099, 'acc@1': 90.51621835443038, 'LR': 0.009100575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.09it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.07238657785403302, 'epoch_time': 38.93989396095276, 'test_loss': 0.3005907417664045, 'acc@1': 92.15783227848101, 'LR': 0.008650575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.40it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.06765160214660874, 'epoch_time': 38.33421015739441, 'test_loss': 0.32201182285818875, 'acc@1': 91.69303797468355, 'LR': 0.008200575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.83it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.06425989445899148, 'epoch_time': 37.55430293083191, 'test_loss': 0.34474485527865495, 'acc@1': 91.13924050632912, 'LR': 0.007750575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.22it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.0618073859101976, 'epoch_time': 38.68104338645935, 'test_loss': 0.33500105953669246, 'acc@1': 91.19857594936708, 'LR': 0.0073005754475703325}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.82it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.06024209699114723, 'epoch_time': 37.568342208862305, 'test_loss': 0.29607175677260267, 'acc@1': 92.1182753164557, 'LR': 0.006850575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.38it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.05481064845057552, 'epoch_time': 38.38657855987549, 'test_loss': 0.3112023651977129, 'acc@1': 91.91060126582279, 'LR': 0.006400575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.24it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.05167926366076521, 'epoch_time': 38.64375114440918, 'test_loss': 0.2935829552102692, 'acc@1': 92.48417721518987, 'LR': 0.005950575447570334}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.11it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.04938280738680564, 'epoch_time': 38.90058922767639, 'test_loss': 0.3146475397899181, 'acc@1': 92.02927215189874, 'LR': 0.005500575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.26it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.046452013956730626, 'epoch_time': 38.61159324645996, 'test_loss': 0.41355924987340276, 'acc@1': 90.26898734177215, 'LR': 0.005050575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.13it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.04587310114565789, 'epoch_time': 38.850093364715576, 'test_loss': 0.2964700212395644, 'acc@1': 92.69185126582279, 'LR': 0.004600575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.51it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.04220182882845783, 'epoch_time': 38.12518572807312, 'test_loss': 0.3129058468002307, 'acc@1': 92.00949367088607, 'LR': 0.0041505754475703325}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.75it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.03983215707094621, 'epoch_time': 37.700111865997314, 'test_loss': 0.28966658213470553, 'acc@1': 92.7314082278481, 'LR': 0.003700575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.80it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.03746543604108836, 'epoch_time': 37.60990786552429, 'test_loss': 0.3142848335489442, 'acc@1': 92.24683544303798, 'LR': 0.0032505754475703327}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.78it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.03424491639942159, 'epoch_time': 37.644187688827515, 'test_loss': 0.29440756208157237, 'acc@1': 92.51384493670886, 'LR': 0.002800575447570332}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.41it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.03410469871633651, 'epoch_time': 38.31650257110596, 'test_loss': 0.3006515582151051, 'acc@1': 92.71162974683544, 'LR': 0.002350575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.88it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.030563229463918283, 'epoch_time': 39.35120964050293, 'test_loss': 0.29495988408975965, 'acc@1': 92.83030063291139, 'LR': 0.0019005754475703326}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.54it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.02765191834696266, 'epoch_time': 38.081024408340454, 'test_loss': 0.2885598743075057, 'acc@1': 92.88963607594937, 'LR': 0.001450575447570333}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.33it/s]


{'train_loss': 0.028634040556905216, 'epoch_time': 38.46471190452576, 'test_loss': 0.29290160272694843, 'acc@1': 92.84018987341773, 'LR': 0.0010005754475703327}


In [10]:
# Second finetuning pass: 10 more epochs with both LR bounds set to
# cfg.gbn.lr_min, so the logged 'LR' stays at 0.001 for every epoch.
# NOTE(review): lr_max=cfg.gbn.lr_min reads like a typo, but the constant-LR
# log suggests it is deliberate -- confirm. T=10 is a hard-coded epoch count
# with no matching entry in cfg.gbn.
_ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_min, T=10)

100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.68it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.028377088579966132, 'epoch_time': 37.85368323326111, 'test_loss': 0.2879510103147241, 'acc@1': 92.97863924050633, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.68it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.02782191441911261, 'epoch_time': 37.820202112197876, 'test_loss': 0.2841554020024553, 'acc@1': 93.09731012658227, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.36it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.02697309303095045, 'epoch_time': 38.4120237827301, 'test_loss': 0.2892748922864093, 'acc@1': 92.87974683544304, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.24it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.02626397204883111, 'epoch_time': 38.63919520378113, 'test_loss': 0.29763409438767013, 'acc@1': 92.87974683544304, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.90it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.026684973796215053, 'epoch_time': 39.3054313659668, 'test_loss': 0.2863810823499402, 'acc@1': 92.98852848101266, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.51it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.02642264375296395, 'epoch_time': 38.13243556022644, 'test_loss': 0.29373643547296524, 'acc@1': 92.87974683544304, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.20it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.026813462883939067, 'epoch_time': 38.71113991737366, 'test_loss': 0.28625739299798314, 'acc@1': 93.05775316455696, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:37<00:00, 20.62it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.025754430719539333, 'epoch_time': 37.92336654663086, 'test_loss': 0.2933901181515259, 'acc@1': 92.88963607594937, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:38<00:00, 20.45it/s]
  0%|                                                                       | 0/782 [00:00<?, ?it/s]

{'train_loss': 0.025038552658675273, 'epoch_time': 38.24193525314331, 'test_loss': 0.29054723557414885, 'acc@1': 93.09731012658227, 'LR': 0.001}


100%|█████████████████████████████████████████████████████████████| 782/782 [00:39<00:00, 19.80it/s]


{'train_loss': 0.024549294720449106, 'epoch_time': 39.50026035308838, 'test_loss': 0.28829607669311236, 'acc@1': 93.15664556962025, 'LR': 0.001}


In [10]:
# Display the final network structure. In the printed output the model is
# wrapped in DataParallel, and several conv / BN layers have reduced channel
# counts (e.g. a 16 -> 1 convolution in layer1).
pack.net

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(1, 16, kernel_size=(