In [None]:
import copy
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('./mcunet')

from mcunet.gumbel_module.gumbel_net import GumbelMCUNet
from mcunet.gumbel_module.gumbel_layer import MBGumbelInvertedConvLayer, MobileGumbelInvertedResidualBlock, count_conv_gumbel_flops
from mcunet.tinynas.nn.modules import MBInvertedConvLayer
from mcunet.tinynas.nn.networks import MobileInvertedResidualBlock
from mcunet.model_zoo import build_model

from mcunet.utils import MyModule, MyNetwork, SEModule, build_activation, get_same_padding, sub_filter_start_end, rm_bn_from_net, get_deep_attr, has_deep_attr
from mcunet.tinynas.nn.modules import ZeroLayer, set_layer_from_config

from fvcore.nn import FlopCountAnalysis

In [None]:
# Build the pretrained MCUNet baseline and a Gumbel variant built from the same config.
NET_ID = 'mcunet-in4'
ori_model, img_size, desc = build_model(net_id=NET_ID, pretrained=True)

# Search-space options for the Gumbel variant. The semantics of each key are defined by
# GumbelMCUNet.build_from_config — NOTE(review): confirm there.
gumbel_config = {
    'global_expand_ratio_list': [1, 3, 4, 5, 6],
    'global_kernel_size_list': [3, 5, 7],
    'gumbel_feature_extract_block_idx': 2,
}
gubmel_config = gumbel_config  # legacy misspelled name, kept for any cell that still uses it

gumbel_model = GumbelMCUNet.build_from_config(ori_model.config, gumbel_config)
# Presumably copies the pretrained MCUNet weights into the matching Gumbel layers;
# the next cell verifies this parameter by parameter.
gumbel_model.load_pretrained_mcunet_param(ori_model)

In [None]:
# Sanity check: every parameter that exists under the same name in both models should be
# identical after load_pretrained_mcunet_param. A signed sum of differences can cancel out
# to ~0 even when weights differ, so report the maximum absolute difference instead.
with torch.no_grad():
    for n, p in ori_model.named_parameters():
        if not has_deep_attr(gumbel_model, n):
            continue
        gumbel_param = get_deep_attr(gumbel_model, n).data
        print(n)
        # Differently-shaped tensors would raise or silently broadcast; report them instead.
        if p.shape != gumbel_param.shape:
            print(f"shape mismatch : {tuple(p.shape)} vs {tuple(gumbel_param.shape)}")
            continue
        # Align devices so the check also works after only one model was moved to GPU.
        print((p - gumbel_param.to(p.device)).abs().max().item())

In [None]:
from src import distrib
from src import dataset
from src.trainer import Trainer
from easydict import EasyDict as edict

In [None]:
# Trainer configuration for evaluation only: epochs = 0 here and the optimizer passed to
# Trainer below is None.
args = edict()
args.db = edict()
args.db.name = 'imagenet'
# Dataset root. Set the IMAGENET_ROOT environment variable to point elsewhere instead of
# editing this absolute path; the default matches the original location.
args.db.root = os.environ.get('IMAGENET_ROOT', '/dataset/ImageNet/Classification/')
args.flops_penalty = 0.0   # presumably the FLOPs-regularisation weight (disabled) — confirm in Trainer
args.lr_sched = None
args.device = 0
args.epochs = 0
args.max_norm = 0.5
args.continue_from = False
args.checkpoint = None
args.history_file = None
args.restart = False
args.num_prints = 10
args.mixed = False

In [None]:
# Build the ImageNet splits at the model's native input resolution (img_size is returned by
# build_model, instead of repeating a hard-coded 160), and wrap the test split in a loader.
EVAL_BATCH_SIZE = 256
NUM_WORKERS = 4

train_dataset, test_dataset, num_class = dataset.get_loader(args, img_resize=img_size)
data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)
# NOTE(review): 'tr' holds a Dataset, not a DataLoader. With args.epochs = 0 this is
# presumably never iterated — confirm against src.trainer before enabling training.
data = {'tr': train_dataset, 'tt': data_loader}

In [None]:
# One Trainer per model, sharing data, loss and configuration; neither one trains (no optimizer).
criterion, optimizer = nn.CrossEntropyLoss(), None
# Placeholder FLOPs reference value; NOTE(review): with args.flops_penalty == 0.0 it is
# presumably unused — confirm in src.trainer.
original_flops = 100_000
trainer_args = (criterion, optimizer, args, original_flops)
# Module.cuda() moves parameters in place and returns the module itself.
ori_trainer = Trainer(data, ori_model.cuda(), *trainer_args)
gumbel_trainer = Trainer(data, gumbel_model.cuda(), *trainer_args)


In [None]:
# Evaluate the pretrained baseline, presumably on data['tt'] (the test loader).
# NOTE(review): the exact meaning of the ori_model flag is defined in src.trainer — confirm.
ori_trainer.test(ori_model=True)

In [12]:
# Evaluate the Gumbel model with the same flag; the captured output ("forward original")
# suggests it runs the original (non-Gumbel) path, so accuracy should match the baseline.
gumbel_trainer.test(ori_model=True)

forward original
loss : 0.70, acc : 79.69
forward original
loss : 0.76, acc : 79.49
forward original
loss : 0.64, acc : 82.81
forward original
loss : 0.62, acc : 83.98
forward original
loss : 0.62, acc : 84.06
forward original
loss : 0.68, acc : 82.88
forward original
loss : 0.82, acc : 79.63
forward original
loss : 0.86, acc : 77.78
forward original
loss : 0.91, acc : 77.17
forward original
loss : 0.94, acc : 76.17
forward original
loss : 0.97, acc : 75.36
forward original
loss : 1.02, acc : 74.22
forward original
loss : 1.07, acc : 73.26
forward original
loss : 1.07, acc : 73.07
forward original
loss : 1.07, acc : 72.99
forward original
loss : 1.08, acc : 73.12
forward original
loss : 1.05, acc : 73.74
forward original
loss : 1.02, acc : 74.59
forward original
loss : 0.99, acc : 75.29
forward original
loss : 0.97, acc : 75.55
forward original
loss : 0.98, acc : 75.73
forward original
loss : 0.97, acc : 75.87
forward original
loss : 0.97, acc : 75.99
forward original
loss : 0.98, acc 

In [None]:
# Block-level equivalence check: feed one random input through each Gumbel inverted-residual
# block and the original block at the same module path, then report how far the outputs differ.
SPATIAL_SIZE = 16  # arbitrary spatial size for the random probe input
torch.manual_seed(0)  # reproducible probe inputs

gumbel_model.eval()
ori_model.eval()
with torch.no_grad():
    for n, m in gumbel_model.named_modules():
        # Cheap type check first, so has_deep_attr is only queried for candidate blocks.
        if not isinstance(m, MobileGumbelInvertedResidualBlock) or not has_deep_attr(ori_model, n):
            continue
        ori_m = get_deep_attr(ori_model, n).cuda()
        m = m.cuda()
        in_c = ori_m.mobile_inverted_conv.in_channels
        input_rand_tensor = torch.randn(1, in_c, SPATIAL_SIZE, SPATIAL_SIZE).cuda()
        ori_out = ori_m(input_rand_tensor)
        out = m(input_rand_tensor)
        print(f"module name : {n}")
        # A single max-abs number is readable; dumping the whole difference tensor is not.
        print("distance : ", (ori_out - out).abs().max().item())
        print("==" * 20)

In [None]:
# End-to-end trace: run the same random image through the stem and every block of both
# networks side by side, printing the max absolute difference after each stage.
# eval() is set here so the result does not depend on an earlier cell having run.
ori_model = ori_model.cuda().eval()
gumbel_model = gumbel_model.cuda().eval()
assert len(ori_model.blocks) == len(gumbel_model.blocks), "block counts differ between models"

# img_size comes from build_model (the original hard-coded 160) — assumes a square input.
inputs_test = torch.randn(1, 3, img_size, img_size).cuda()
with torch.no_grad():
    ori_output = ori_model.first_conv(inputs_test)
    gumbel_output = gumbel_model.first_conv(inputs_test)
    print("difference : ", (ori_output - gumbel_output).abs().max().item())

    for i, (ori_block, gumbel_block) in enumerate(zip(ori_model.blocks, gumbel_model.blocks)):
        ori_output = ori_block(ori_output)
        gumbel_output = gumbel_block(gumbel_output)
        print(f"{i}'s difference : ", (ori_output - gumbel_output).abs().max().item())

In [None]:
# Inspect the baseline's last-block feature maps from the stage-wise trace above.
ori_output

In [None]:
# Inspect the Gumbel model's last-block feature maps for comparison with the baseline's.
gumbel_output

In [None]:
# Finish the manual forward pass: global-average-pool the final feature maps, apply each
# classifier and compare the logits. Results go into new names, so the feature maps in
# ori_output / gumbel_output are not overwritten and this cell can be re-run safely.
with torch.no_grad():
    ori_features, gumbel_features = ori_output, gumbel_output
    # NOTE(review): mirrors a network-level feature_mix_layer if the model defines one;
    # it is skipped when absent or None — confirm against the network's forward().
    if getattr(ori_model, 'feature_mix_layer', None) is not None:
        ori_features = ori_model.feature_mix_layer(ori_features)
    if getattr(gumbel_model, 'feature_mix_layer', None) is not None:
        gumbel_features = gumbel_model.feature_mix_layer(gumbel_features)

    ori_logits = ori_model.classifier(ori_features.mean(3).mean(2))
    gumbel_logits = gumbel_model.classifier(gumbel_features.mean(3).mean(2))

print("difference : ", (ori_logits - gumbel_logits).abs().max().item())
print("same top-1 : ", bool((ori_logits.argmax(1) == gumbel_logits.argmax(1)).all()))