In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import sys

sys.path.append('./mcunet')

from mcunet.gumbel_module.gumbel_net import GumbelMCUNet
from mcunet.gumbel_module.gumbel_layer import MBGumbelInvertedConvLayer, MobileGumbelInvertedResidualBlock, count_conv_gumbel_flops
from mcunet.tinynas.nn.modules import MBInvertedConvLayer
from mcunet.tinynas.nn.networks import MobileInvertedResidualBlock
from mcunet.model_zoo import build_model

from mcunet.utils import MyModule, MyNetwork, SEModule, build_activation, get_same_padding, sub_filter_start_end, rm_bn_from_net, set_deep_attr, get_deep_attr, has_deep_attr
from mcunet.tinynas.nn.modules import ZeroLayer, set_layer_from_config

from fvcore.nn import FlopCountAnalysis

In [37]:
# Build the pretrained reference network (net_id 'mcunet-in4') via the model zoo.
ori_model, img_size, desc = build_model(net_id='mcunet-in4', pretrained=True)

# Extra settings handed to GumbelMCUNet.build_from_config on top of the
# reference model's own config.
gubmel_config = dict(
    global_expand_ratio_list=[1, 3, 4, 5, 6],
    global_kernel_size_list=[3, 5, 7],
    gumbel_feature_extract_block_idx=2,
)
gumbel_model = GumbelMCUNet.build_from_config(ori_model.config, gubmel_config)

# Transfer the reference model's parameters into the Gumbel network.
gumbel_model.load_pretrained_mcunet_param(ori_model)

load pretrained mcu model parameter to gumbel net
before num_batches_tracked  first_conv.bn.num_batches_tracked tensor(774245)
after : first_conv.bn.num_batches_tracked tensor(0)
before num_batches_tracked  blocks.0.mobile_inverted_conv.depth_conv.bn.num_batches_tracked tensor(774245)
after : blocks.0.mobile_inverted_conv.depth_conv.bn.num_batches_tracked tensor(0)
before num_batches_tracked  blocks.0.mobile_inverted_conv.point_linear.bn.num_batches_tracked tensor(774245)
after : blocks.0.mobile_inverted_conv.point_linear.bn.num_batches_tracked tensor(0)
before num_batches_tracked  blocks.1.mobile_inverted_conv.inverted_bottleneck.bn.num_batches_tracked tensor(60095)
after : blocks.1.mobile_inverted_conv.inverted_bottleneck.bn.num_batches_tracked tensor(0)
before num_batches_tracked  blocks.1.mobile_inverted_conv.depth_conv.bn.num_batches_tracked tensor(60095)
after : blocks.1.mobile_inverted_conv.depth_conv.bn.num_batches_tracked tensor(0)
before num_batches_tracked  blocks.1.mobile_i

In [16]:
from src import distrib
from src import dataset
from src.trainer import Trainer
from easydict import EasyDict as edict

In [17]:
# Argument namespace passed to dataset.get_loader and Trainer in later cells.
# A nested dict is converted to a nested EasyDict, so args.db.name etc. work
# exactly as with attribute-by-attribute assignment.
args = edict({
    'db': {
        'name': 'imagenet',
        'root': '/dataset/ImageNet/Classification/',
    },
    'flops_penalty': 0.0,
    'lr_sched': None,
    'device': 0,
    'epochs': 0,
    'max_norm': 0.5,
    'continue_from': True,
    'checkpoint': None,
    'history_file': None,
    'restart': True,
    'num_prints': 10,
    'mixed': True,
})

In [18]:
# Datasets come from the project loader; only batching is configured here.
train_dataset, test_dataset, num_class = dataset.get_loader(args, img_resize=160)

# Settings common to both loaders; only shuffling differs.
loader_kwargs = dict(batch_size=512, num_workers=4)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Train / test loaders bundled under the 'tr' / 'tt' keys.
data = {'tr': train_loader, 'tt': test_loader}

In [24]:
from mcunet.gumbel_module.gumbel_layer import DynamicGumbelBatchNorm2d

# Untouched copy of the freshly built layer, used as the reference for the
# running-statistics differences printed below.
inputs = torch.randn(4, 8, 4, 4)
bn = DynamicGumbelBatchNorm2d(8)
backup_bn = copy.deepcopy(bn)
print(backup_bn.bn.running_mean, backup_bn.bn.running_var)

# Experiment 1: all 8 channels. One forward pass in eval mode, then one in
# train mode, printing (backup - current) running stats after each.
for set_mode in (bn.eval, bn.train):
    set_mode()
    bn(inputs)
    print(backup_bn.bn.running_mean - bn.bn.running_mean, backup_bn.bn.running_var - bn.bn.running_var)

# Experiment 2: restore the layer and repeat with a 4-channel input.
print("re setup bn")
bn = copy.deepcopy(backup_bn)
print(bn.bn.running_mean, bn.bn.running_var)
inputs = torch.randn(4, 4, 4, 4)
for set_mode in (bn.eval, bn.train):
    set_mode()
    bn(inputs)
    print(backup_bn.bn.running_mean - bn.bn.running_mean, backup_bn.bn.running_var - bn.bn.running_var)

# Experiment 3: restore again and, in train mode, feed a constant input whose
# channel count grows from 1 to 8, printing the running stats after each pass.
print("re setup bn")
bn = copy.deepcopy(backup_bn)
print(bn.bn.running_mean, bn.bn.running_var)

bn.train()
inputs = torch.ones(4, 8, 4, 4).to(torch.float32)
for i in range(8):
    leading_channels = inputs[:, : i + 1, :, :]
    bn(leading_channels)
    print(bn.bn.running_mean, bn.bn.running_var)

tensor([0., 0., 0., 0., 0., 0., 0., 0.]) tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([0., 0., 0., 0., 0., 0., 0., 0.]) tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([-0.0148, -0.0012,  0.0120, -0.0059,  0.0142, -0.0135, -0.0068,  0.0145]) tensor([ 0.0278,  0.0173, -0.0243,  0.0166,  0.0347,  0.0055,  0.0184, -0.0235])
re setup bn
tensor([0., 0., 0., 0., 0., 0., 0., 0.]) tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([0., 0., 0., 0., 0., 0., 0., 0.]) tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([-0.0002,  0.0002,  0.0012, -0.0038,  0.0000,  0.0000,  0.0000,  0.0000]) tensor([ 0.0112, -0.0026,  0.0066,  0.0081,  0.0000,  0.0000,  0.0000,  0.0000])
re setup bn
tensor([0., 0., 0., 0., 0., 0., 0., 0.]) tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]) tensor([0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
tensor([0.1900, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]) tensor([0.8100, 0.9000,

In [27]:
from mcunet.gumbel_module.gumbel_layer import DynamicGumbelBatchNorm2d

inputs = torch.randn(4, 8, 4, 4)
bn = DynamicGumbelBatchNorm2d(8)
backup_bn = copy.deepcopy(bn)

print("re setup bn")
bn = copy.deepcopy(backup_bn)
print(bn.bn.running_mean, bn.bn.running_var)

bn.train()
inputs = torch.ones(4, 8, 4, 4).to(torch.float32)

# Call F.batch_norm directly on the first `width` channels. The sliced running
# buffers are views, so the in-place statistic updates land in bn.bn's buffers
# (see the printed values).
for i in range(8):
    width = i + 1
    out = F.batch_norm(
        inputs[:, :width, :, :],
        bn.bn.running_mean[:width],
        bn.bn.running_var[:width],
        weight=bn.bn.weight[:width],
        bias=bn.bn.bias[:width],
        training=True,
        momentum=0.1,
        eps=1e-5,
    )
    print(bn.bn.running_mean, bn.bn.running_var)

re setup bn
tensor([0., 0., 0., 0., 0., 0., 0., 0.]) tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]) tensor([0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
tensor([0.1900, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]) tensor([0.8100, 0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
tensor([0.2710, 0.1900, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]) tensor([0.7290, 0.8100, 0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
tensor([0.3439, 0.2710, 0.1900, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000]) tensor([0.6561, 0.7290, 0.8100, 0.9000, 1.0000, 1.0000, 1.0000, 1.0000])
tensor([0.4095, 0.3439, 0.2710, 0.1900, 0.1000, 0.0000, 0.0000, 0.0000]) tensor([0.5905, 0.6561, 0.7290, 0.8100, 0.9000, 1.0000, 1.0000, 1.0000])
tensor([0.4686, 0.4095, 0.3439, 0.2710, 0.1900, 0.1000, 0.0000, 0.0000]) tensor([0.5314, 0.5905, 0.6561, 0.7290, 0.8100, 0.9000, 1.0000, 1.0000])
tensor([0.5217, 0.4686, 0.4095

In [23]:
# Locate the first MBGumbelInvertedConvLayer in the network. Without the
# `else` branch a model with no such layer would silently leave `m` bound to
# the last module visited, and the attribute access below would fail (or
# inspect the wrong module) with a confusing error.
for n, m in gumbel_model.named_modules():
    if isinstance(m, MBGumbelInvertedConvLayer):
        break
else:
    raise LookupError("gumbel_model contains no MBGumbelInvertedConvLayer")

# BatchNorm of that layer's inverted-bottleneck branch.
bnn = m.inverted_bottleneck.bn

# Displayed as the cell output.
bnn.num_batches_tracked

tensor(0)

In [None]:
# Dump the wrapped BatchNorm's mode flag, stat-tracking flag and batch counter.
for attr in ("training", "track_running_stats", "num_batches_tracked"):
    print(getattr(bn.bn, attr))

True
True
tensor(1)


### MBGumbelInvertedConvLayer BatchNorm check

In [None]:
# Collect [name, module] pairs for every MBGumbelInvertedConvLayer in the net.
mbgumbel_list = []
for n, m in gumbel_model.named_modules():
    if not isinstance(m, MBGumbelInvertedConvLayer):
        continue
    mbgumbel_list.append([n, m])
    

In [None]:
# For each Gumbel conv layer, compare the plain forward pass with a forward
# pass driven by an explicit one-hot selection, first in train mode and then
# in eval mode. Work happens on a deep copy so the model itself is untouched.
for n, m in mbgumbel_list:
    print(f"{n} module test")
    test_module = copy.deepcopy(m)
    test_module.cuda()
    test_module.train()

    inputs = torch.randn(4, test_module.config['in_channels'], 32, 32).cuda()

    expand_index, kernel_index = len(test_module.expand_ratio_list), len(test_module.kernel_size_list)

    # Lay out the one-hot columns: expand choices first (if searchable), then
    # kernel choices (if searchable). The last expand entry and the first
    # kernel entry are the ones switched on.
    one_hot_width = 0
    hot_columns = []
    if expand_index > 1:
        hot_columns.append(expand_index - 1)
        one_hot_width += expand_index
    if kernel_index > 1:
        hot_columns.append(one_hot_width)
        one_hot_width += kernel_index

    if one_hot_width:
        gumbel_one_hot = torch.zeros((inputs.shape[0], one_hot_width), device=inputs.device)
        gumbel_one_hot[:, hot_columns] = 1
    else:
        # Nothing to choose from: the layer is called without a selection.
        gumbel_one_hot = None

    for mode_label, set_mode in (("train mode", test_module.train), ("eval mode", test_module.eval)):
        set_mode()
        out = test_module.forward(inputs)  # original output
        out_gumbel = test_module.forward(inputs, gumbel_one_hot)
        print(mode_label)
        print(out - out_gumbel)

In [None]:
# Same comparison as the train-mode half of the previous cell, but also runs a
# backward pass through the one-hot-driven output.
for n, m in mbgumbel_list:
    print(f"{n} module test")
    test_module = copy.deepcopy(m)
    test_module.cuda()
    test_module.train()

    inputs = torch.randn(4, test_module.config['in_channels'], 32, 32).cuda()
    out = test_module.forward(inputs)  # original output

    expand_index, kernel_index = len(test_module.expand_ratio_list), len(test_module.kernel_size_list)
    batch = inputs.shape[0]
    searchable_expand = expand_index > 1
    searchable_kernel = kernel_index > 1

    if not (searchable_expand or searchable_kernel):
        gumbel_one_hot = None
    else:
        # Expand columns (if any) come first, kernel columns (if any) after.
        kernel_offset = expand_index if searchable_expand else 0
        total_columns = kernel_offset + (kernel_index if searchable_kernel else 0)
        gumbel_one_hot = torch.zeros((batch, total_columns), device=inputs.device)
        if searchable_expand:
            gumbel_one_hot[:, expand_index - 1] = 1  # last expand choice
        if searchable_kernel:
            gumbel_one_hot[:, kernel_offset] = 1  # first kernel choice

    out_gumbel = test_module.forward(inputs, gumbel_one_hot)
    out_gumbel.sum().backward()

    print("train mode")
    print(out - out_gumbel)

In [None]:
# Clone the (GPU-resident) Gumbel net and overwrite the clone's weights with
# the checkpoint stored under the package's 'state' key.
gumbel_model.cuda()
gumbel_train_model = copy.deepcopy(gumbel_model)

load_from = './outputs/exp_bn_check/log=bn_check/checkpoint.th'
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
package = torch.load(load_from, map_location='cpu')
gumbel_train_model.load_state_dict(package['state'])

In [None]:
# Two identical train-mode copies: one goes through forward_original, the
# other through forward_gumbel_approx, on the same random input.
gumbel_train_model.train()
gumbel_train_model2 = copy.deepcopy(gumbel_train_model)

for net in (gumbel_train_model, gumbel_train_model2):
    net.cuda()
inputs = torch.randn(1, 3, 160, 160).cuda()

g_1 = gumbel_train_model.forward_original(inputs)
g_2 = gumbel_train_model2.forward_gumbel_approx(inputs)

In [None]:
# Back-propagate the summed first output of each forward variant.
for first_output in (g_1[0], g_2[0]):
    first_output.sum().backward()

In [None]:
# Compare every BatchNorm2d of the two model copies: affine parameters, their
# gradients and the running statistics.
#
# The largest *absolute* difference is reported. A signed max would print 0
# whenever all non-zero differences are negative, falsely suggesting the two
# layers agree.
for (n, m), (tn, tm) in zip(gumbel_train_model.named_modules(), gumbel_train_model2.named_modules()):
    if not isinstance(tm, nn.BatchNorm2d):
        continue
    print(n)
    print((m.weight.data - tm.weight.data).abs().max())
    print((m.bias.data - tm.bias.data).abs().max())

    # Gradients only exist for layers the backward pass actually reached.
    for param, other_param in ((m.weight, tm.weight), (m.bias, tm.bias)):
        if param.grad is None or other_param.grad is None:
            print("grad not available")
        else:
            print((param.grad.data - other_param.grad.data).abs().max())

    print((m.running_mean.data - tm.running_mean.data).abs().max())
    print((m.running_var.data - tm.running_var.data).abs().max())
    print("===="*10)

In [None]:
# Eval-mode sanity check: nn.BatchNorm2d and the functional API must agree
# when given the same buffers and affine parameters.
test_input = torch.randn(3, 4, 2, 2)
test_bn = nn.BatchNorm2d(4)
test_bn.eval()

out1 = test_bn(test_input)
out2 = F.batch_norm(
    test_input,
    test_bn.running_mean,
    test_bn.running_var,
    weight=test_bn.weight,
    bias=test_bn.bias,
    training=False,
    momentum=0.0,
    eps=1e-5,
)

print(out1 - out2)

In [47]:
# Module vs. functional BatchNorm with momentum=0.0, across eval and train
# mode, including a backward pass.
test_input = torch.randn(3, 4, 2, 2)
test_bn = nn.BatchNorm2d(4, momentum=0.0, eps=1e-5)
test_bn.eval()

# Positional arguments shared by every functional call. The buffers and
# parameters are passed by reference, so in-place updates remain visible.
functional_args = (test_input, test_bn.running_mean, test_bn.running_var, test_bn.weight, test_bn.bias)

out1 = test_bn(test_input)
out2 = F.batch_norm(*functional_args, training=False, momentum=0.0, eps=1e-5)
print((out1 - out2).sum())

# test_bn2 is driven through the module API; test_bn's tensors are driven
# through the functional API.
test_bn2 = copy.deepcopy(test_bn)
test_bn2.train()
test_bn.train()

for stage in ("train mode", "train mode 1 iter after", "train mode 2 iter after"):
    out1 = test_bn2(test_input)
    out2 = F.batch_norm(*functional_args, training=True, momentum=0.0, eps=1e-5)
    print(f"BN and F.batchnorm output check ({stage})")
    print((out1 - out2).sum())

# Backward through the outputs of the last train-mode iteration.
out1.sum().backward()
out2.sum().backward()
print("bn weight's gradinet check")
print(test_bn.weight.grad, test_bn2.weight.grad)

out1 = test_bn2(test_input)
out2 = F.batch_norm(*functional_args, training=True, momentum=0.0, eps=1e-5)
print("BN and F.batchnorm output check (train mode 2 iter after and weight update)")
print((out1 - out2).sum())

test_bn2.eval()
test_bn.eval()
out1 = test_bn2(test_input)
out2 = F.batch_norm(*functional_args, training=False, momentum=0.0, eps=1e-5)
print("BN and F.batchnorm output check (eval mode 2 iter after and weight update)")
print((out1 - out2).sum())
print(test_bn.running_mean, test_bn.running_var)
print(test_bn2.running_mean, test_bn2.running_var)

tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (train mode)
tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (train mode 1 iter after)
tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (train mode 2 iter after)
tensor(0., grad_fn=<SumBackward0>)
bn weight's gradinet check
tensor([ 3.2685e-08,  4.1659e-07, -6.5098e-08, -1.5111e-07]) tensor([ 3.2685e-08,  4.1659e-07, -6.5098e-08, -1.5111e-07])
BN and F.batchnorm output check (train mode 2 iter after and weight update)
tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (eval mode 2 iter after and weight update)
tensor(0., grad_fn=<SumBackward0>)
tensor([0., 0., 0., 0.]) tensor([1., 1., 1., 1.])
tensor([0., 0., 0., 0.]) tensor([1., 1., 1., 1.])


In [44]:
# Module vs. functional BatchNorm with momentum=0.1: outputs, gradients and
# the running statistics accumulated in train mode are compared.
test_input = torch.randn(3, 4, 2, 2)
test_bn = nn.BatchNorm2d(4, momentum=0.1, eps=1e-5)
test_bn.eval()

# Tensors handed to every functional call (shared by reference with test_bn).
functional_args = (test_input, test_bn.running_mean, test_bn.running_var, test_bn.weight, test_bn.bias)

# Eval-mode check.
out1 = test_bn(test_input)
out2 = F.batch_norm(*functional_args, training=False, momentum=0.0, eps=1e-5)
print((out1 - out2).sum())

# test_bn2 goes through the module API, test_bn's tensors through F.batch_norm.
test_bn2 = copy.deepcopy(test_bn)
test_bn2.train()
test_bn.train()

for stage in ("train mode", "train mode 1 iter after", "train mode 2 iter after"):
    out1 = test_bn2(test_input)
    out2 = F.batch_norm(*functional_args, training=True, momentum=0.1, eps=1e-5)
    print(f"BN and F.batchnorm output check ({stage})")
    print((out1 - out2).sum())

# Backward through the outputs of the last train-mode iteration.
out1.sum().backward()
out2.sum().backward()
print("bn weight's gradinet check")
print(test_bn.weight.grad, test_bn2.weight.grad)

out1 = test_bn2(test_input)
out2 = F.batch_norm(*functional_args, training=True, momentum=0.1, eps=1e-5)
print("BN and F.batchnorm output check (train mode 2 iter after and weight update)")
print((out1 - out2).sum())

test_bn2.eval()
test_bn.eval()
out1 = test_bn2(test_input)
out2 = F.batch_norm(*functional_args, training=False, momentum=0.1, eps=1e-5)
print("BN and F.batchnorm output check (eval mode 2 iter after and weight update)")
print((out1 - out2).sum())
print(test_bn.running_mean, test_bn.running_var)
print(test_bn2.running_mean, test_bn2.running_var)

tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (train mode)
tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (train mode 1 iter after)
tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (train mode 2 iter after)
tensor(0., grad_fn=<SumBackward0>)
bn weight's gradinet check
tensor([-1.7993e-07,  1.6251e-07, -1.3427e-07, -1.2154e-07]) tensor([-1.7993e-07,  1.6251e-07, -1.3427e-07, -1.2154e-07])
BN and F.batchnorm output check (train mode 2 iter after and weight update)
tensor(0., grad_fn=<SumBackward0>)
BN and F.batchnorm output check (eval mode 2 iter after and weight update)
tensor(0., grad_fn=<SumBackward0>)
tensor([ 0.0113, -0.0004, -0.0224, -0.1097]) tensor([0.9964, 0.9134, 1.0304, 0.8591])
tensor([ 0.0113, -0.0004, -0.0224, -0.1097]) tensor([0.9964, 0.9134, 1.0304, 0.8591])


In [None]:
# Running statistics of test_bn2, the copy exercised through the module API.
print(test_bn2.running_mean, test_bn2.running_var)

In [None]:
# Running statistics of test_bn, whose buffers were passed to F.batch_norm.
print(test_bn.running_mean, test_bn.running_var)

In [None]:
# Compare every BatchNorm2d of the pretrained-initialised Gumbel net with the
# checkpoint-loaded copy: affine parameters and running statistics.
#
# The largest *absolute* difference is reported. A signed max would print 0
# whenever all non-zero differences are negative, hiding a real mismatch.
for (n, m), (tn, tm) in zip(gumbel_model.named_modules(), gumbel_train_model.named_modules()):
    if not isinstance(tm, nn.BatchNorm2d):
        continue
    print(n)
    print((m.weight.data - tm.weight.data).abs().max())
    print((m.bias.data - tm.bias.data).abs().max())
    print((m.running_mean.data - tm.running_mean.data).abs().max())
    print((m.running_var.data - tm.running_var.data).abs().max())
    print("===="*10)

In [None]:
# Trainer wiring for evaluation only: no optimizer is supplied.
optimizer = None
criterion = nn.CrossEntropyLoss()

# NOTE(review): placeholder FLOPs figure passed through to Trainer — confirm
# how src.trainer uses it.
original_flops = 100000

gumbel_trainer = Trainer(
    data,
    gumbel_model.cuda(),
    criterion,
    optimizer,
    args,
    original_flops,
)


In [None]:
# Run the trainer's test routine with ori_model=True.
# NOTE(review): presumably evaluates through the original (non-Gumbel) forward
# path — confirm against src.trainer.Trainer.test.
gumbel_trainer.test(ori_model=True)

In [None]:
# Take the first 32 samples of the first test batch and move them to the GPU.
# NOTE(review): this rebinds `data`, which earlier held the {'tr', 'tt'} loader
# dict; re-running the Trainer construction cell afterwards would receive a
# tensor instead.
for data, label in test_loader:
    data, label = data[:32].cuda(), label[:32].cuda()
    break

In [None]:
class Hook:
    """Forward hook that remembers the most recent input and output of a module.

    Attributes:
        name: Label supplied by the caller (the module's qualified name in
            this notebook).
        module: The module the hook is attached to; refreshed on every call.
        input: Tuple of positional inputs from the latest forward pass, or
            None if the module has not been called yet.
        output: Output of the latest forward pass, or None if the module has
            not been called yet.
        hook: Handle returned by register_forward_hook; used by close().
    """

    def __init__(self, name, module):
        """Attach the hook to `module` and label it with `name`."""
        self.name = name
        self.module = module
        # Initialise both capture slots so reading them before the first
        # forward pass yields None instead of raising AttributeError.
        self.input = None
        self.output = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Record the module, its input tuple and its output; returns None so the output is not altered."""
        self.module = module
        self.input = input
        self.output = output

    def close(self):
        """Detach the hook from the module."""
        self.hook.remove()

In [None]:
# Attach a recording hook to every BatchNorm2d of the reference model, then
# run one forward pass so the hooks capture their inputs and outputs.
ori_hook_list = []
ori_model.cuda()
for n, m in ori_model.named_modules():
    if not isinstance(m, nn.BatchNorm2d):
        continue
    ori_hook_list.append(Hook(n, m))

# Per-sample maximum over dim 1 of the model output, shown as the cell output.
ori_model(data).max(dim=1)

In [None]:
# Attach a recording hook to every BatchNorm2d of the Gumbel model, then run
# one forward pass so the hooks capture their inputs and outputs.
gumbel_hook_list = []
gumbel_model.cuda()
for n, m in gumbel_model.named_modules():
    if not isinstance(m, nn.BatchNorm2d):
        continue
    gumbel_hook_list.append(Hook(n, m))

# The forward result is indexed with [0] before taking the max over dim 1.
gumbel_model.forward(data)[0].max(1)

In [None]:
# For the first seven hooked BatchNorm layers of the reference model, compare
# per-sample, per-channel statistics of the captured input (first 10 samples)
# with the layer's running mean / variance.
for i, hook in enumerate(ori_hook_list[:7]):
    print(hook.name)
    input_data = hook.input[0][:10]
    n, c, h, w = input_data.shape

    # Collapse the two spatial axes so statistics are taken per (sample, channel).
    flattened = input_data.reshape(n, c, -1)
    diff_input_mean = flattened.mean(dim=-1) - hook.module.running_mean
    diff_input_std = flattened.var(dim=-1) - hook.module.running_var

    print("bn and inputs difference mean statistics : ")
    for label, diff in (("mean", diff_input_mean), ("var", diff_input_std)):
        print(f"diff {label}'s statistics :\n avg : {diff.mean():.6f}, std : {diff.std():.6f}, max : {diff.max():.6f}, min : {diff.min():.6f}")
    print("=="*30)
    

In [None]:
# Same report as for the reference model, for the first seven hooked
# BatchNorm layers of the Gumbel model.
for i, hook in enumerate(gumbel_hook_list[:7]):
    print("gumbel model statistics")
    print(hook.name)
    input_data = hook.input[0][:10]
    n, c, h, w = input_data.shape

    # Collapse the two spatial axes so statistics are taken per (sample, channel).
    flattened = input_data.reshape(n, c, -1)
    diff_input_mean = flattened.mean(dim=-1) - hook.module.running_mean
    diff_input_std = flattened.var(dim=-1) - hook.module.running_var

    print("bn and inputs difference mean statistics : ")
    for label, diff in (("mean", diff_input_mean), ("var", diff_input_std)):
        print(f"diff {label}'s statistics :\n avg : {diff.mean():.6f}, std : {diff.std():.6f}, max : {diff.max():.6f}, min : {diff.min():.6f}")
    print("=="*30)
    

In [None]:
# List every parameter of the Gumbel model with its shape.
# Note: this rebinds `n` and `m`; here `m` is a parameter, not a module.
for n, m in gumbel_model.named_parameters():
    print(n, m.shape)

In [None]:
# Re-run the trainer's test routine with ori_model=True.
# NOTE(review): presumably evaluates through the original (non-Gumbel) forward
# path — confirm against src.trainer.Trainer.test.
gumbel_trainer.test(ori_model=True)

In [None]:
# Show the config of whatever `m` currently refers to.
# NOTE(review): `m` is left over from an earlier loop cell, so the result
# depends on which cell ran last — confirm the intended module.
m.config

In [None]:
# Scratch check: tile a 1-D vector into h identical rows.
g = torch.randn(10)
h = 5
print(g)
g[None, :].repeat(h, 1)

In [None]:
# Block-by-block equivalence check in eval mode: every Gumbel residual block
# that has a same-named counterpart in the reference model receives the same
# random input, with the selection set to the last expand entry and the first
# kernel entry.
gumbel_model.eval()
ori_model.eval()
for n, m in gumbel_model.named_modules():
    if not (has_deep_attr(ori_model, n) and isinstance(m, MobileGumbelInvertedResidualBlock)):
        continue

    ori_m = get_deep_attr(ori_model, n)
    m = m.cuda()
    ori_m = ori_m.cuda()
    in_c = ori_m.mobile_inverted_conv.in_channels
    input_rand_tensor = torch.randn(1, in_c, 16, 16).cuda()
    ori_out = ori_m(input_rand_tensor)

    expand_ratio_list, kernel_size_list = m.mobile_inverted_conv.expand_ratio_list, m.mobile_inverted_conv.kernel_size_list
    if len(expand_ratio_list) == 1 and len(kernel_size_list) == 1:
        # Nothing is searchable in this block: no selection vector.
        gumbel_idx = None
    elif len(expand_ratio_list) > 1 and len(kernel_size_list) == 1:
        gumbel_idx = torch.zeros(len(expand_ratio_list)).long().to(input_rand_tensor.device)
        gumbel_idx[len(expand_ratio_list)-1] = 1
    elif len(expand_ratio_list) == 1 and len(kernel_size_list) > 1:
        gumbel_idx = torch.zeros(len(kernel_size_list)).long().to(input_rand_tensor.device)
        gumbel_idx[0] = 1
    else:
        gumbel_idx = torch.zeros(len(expand_ratio_list)+len(kernel_size_list)).long().to(input_rand_tensor.device)
        gumbel_idx[len(expand_ratio_list)-1] = 1
        gumbel_idx[len(expand_ratio_list)] = 1

    # Add the batch dimension only when a selection exists; calling
    # .unsqueeze on None would raise AttributeError for fixed blocks.
    if gumbel_idx is not None:
        gumbel_idx = gumbel_idx.unsqueeze(0).repeat(input_rand_tensor.shape[0], 1)

    out = m(input_rand_tensor, gumbel_idx)
    print(f"module name : {n}")
    # Absolute differences, so positive and negative errors cannot cancel out.
    print("distance : ", (ori_out-out).abs().sum())
    print("=="*20)
            

In [None]:
# Feed one random image through both networks stage by stage and print the
# element-wise difference after the stem and after every block.
inputs_test = torch.randn(1, 3, 160, 160).cuda()
ori_model = ori_model.cuda()
gumbel_model = gumbel_model.cuda()

ori_output = ori_model.first_conv(inputs_test)
gumbel_output = gumbel_model.first_conv(inputs_test)
print("difference : ", ori_output - gumbel_output)

# Block i of the Gumbel net is paired with block i of the reference net.
for i, ori_block in enumerate(ori_model.blocks):
    ori_output = ori_block(ori_output)
    gumbel_output = gumbel_model.blocks[i](gumbel_output)
    print(f"{i}'s difference : ", ori_output - gumbel_output)

In [None]:
# Display the reference model's current intermediate output.
ori_output

In [None]:
# Display the Gumbel model's current intermediate output.
gumbel_output

In [None]:
# Average over the last two axes (dim 3, then dim 2) and apply each model's
# classifier, then print the difference of the results.
# Note: this cell rebinds ori_output / gumbel_output, so it is not re-runnable
# without re-running the block loop first.
ori_output = ori_model.classifier(ori_output.mean(3).mean(2))
gumbel_output = gumbel_model.classifier(gumbel_output.mean(3).mean(2))
print("difference : ", (ori_output - gumbel_output))