In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import sys

sys.path.append('./mcunet')

from mcunet.gumbel_module.gumbel_net import GumbelMCUNet
from mcunet.gumbel_module.gumbel_layer import MBGumbelInvertedConvLayer, MobileGumbelInvertedResidualBlock, count_conv_gumbel_flops
from mcunet.tinynas.nn.modules import MBInvertedConvLayer
from mcunet.tinynas.nn.networks import MobileInvertedResidualBlock
from mcunet.model_zoo import build_model

from mcunet.utils import MyModule, MyNetwork, SEModule, build_activation, get_same_padding, sub_filter_start_end, rm_bn_from_net, get_deep_attr, has_deep_attr
from mcunet.tinynas.nn.modules import ZeroLayer, set_layer_from_config

from fvcore.nn import FlopCountAnalysis

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ori_model, img_size, desc = build_model(net_id='mcunet-in4', pretrained=True)
gubmel_config = {'global_expand_ratio_list':[1,3,4,5,6], 'global_kernel_size_list':[3,5,7], 'gumbel_feature_extract_block_idx':2}
gumbel_model = GumbelMCUNet.build_from_config(ori_model.config, gubmel_config)
gumbel_model.load_pretrained_mcunet_param(ori_model)

inputs = torch.randn(16, 3, 160, 160).cuda()
gumbel_model = gumbel_model.cuda()
output, gumbel_list = gumbel_model.forward(inputs.cuda())
#out = gumbel_model.forward_original(inputs)
#out2 = ori_model.forward(inputs)
#print((out - out2).sum())
gumbel_model.set_static_flops(inputs)

load pretrained mcumodel to gumbel net


Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::hardtanh_ encountered 1 time(s)
Unsupported operator aten::hardtanh_ encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mobile_inverted_conv.depth_conv.bn, mobile_inverted_conv.point_linear.bn
Unsupported operator aten::hardtanh_ encountered 2 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mobile_inverted_conv.depth_conv.bn, mobile_inverted_

Success Log Static & Dynamic Flops : 367427584, 1670131200


In [3]:
ori_model, img_size, desc = build_model(net_id='mcunet-in4', pretrained=True)

rm_bn_from_net(ori_model)
flops = FlopCountAnalysis(ori_model, torch.randn(1, 3, 160, 160))

In [19]:
for k, v in flops._analyze().counts.items():
    if k.endswith('mobile_inverted_conv'):
        if has_deep_attr(gumbel_model, k):
            m = get_deep_attr(gumbel_model, k)
            if isinstance(m, MBGumbelInvertedConvLayer):
                print("module : ", k)
                full_flops = m.compute_flops()
                print(f"fvcore vs custom flops : {v['conv']}, {full_flops}", v['conv']- full_flops)
                kernel_size_list, expand_ratio_list = m.kernel_size_list, m.expand_ratio_list
                if len(kernel_size_list) > 1 and len(expand_ratio_list) > 1:
                    print("max gumbel test")
                    mm = m
                    m.initialize_flops = False
                    gumbel = torch.zeros(len(expand_ratio_list)+len(kernel_size_list))
                    gumbel[len(expand_ratio_list)-1]=1
                    gumbel[len(expand_ratio_list)]=1
                    gumbel = gumbel.unsqueeze(0).to(m.kernel_size_list.device)
                    gumbel_flops = m.compute_gumbel_flops(gumbel)
                    print("fvcore vs custom flops : ", int(v['conv']- gumbel_flops))                    
                    for kk, vv in flops._analyze().counts.items():
                        if k in kk and '.conv' in kk:
                            print(kk, vv)
                    print(f"FLOPS\ninverted conv : {m.inverted_flops} dw conv : {m.dw_flops} pw conv : {m.pw_flops}")
                    print("expand ratio list : ", expand_ratio_list)
                    

                    for i in range(len(kernel_size_list)):
                        for j in range(len(expand_ratio_list)):
                            gumbel = torch.zeros(len(expand_ratio_list)+len(kernel_size_list))
                            gumbel[j]=1
                            gumbel[len(expand_ratio_list)+i]=1
                            gumbel = gumbel.unsqueeze(0).to(m.kernel_size_list.device)
                            gumbel_flops = m.compute_gumbel_flops(gumbel)
                            print("gumbel indices : ", gumbel)
                            print(f"fvcore {int(v['conv'])}  vs custom flops {int(gumbel_flops)} : ",  int(v['conv']- gumbel_flops))                    
                                                    
                print("=="*20)

module :  blocks.2.mobile_inverted_conv
fvcore vs custom flops : 10944000, 10944000.0 tensor(0.)
module :  blocks.3.mobile_inverted_conv
fvcore vs custom flops : 11212800, 11212800.0 tensor(0.)
max gumbel test
initialize FLOPs table
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
expand ratio list > 1 (3), kernel size list > 1 (2)
module config
input h/w : 40 / 40
stride : 1
inverted conv input shape : torch.Size([96, 24, 1, 1])
depth conv weight shape : torch.Size([96, 1, 5, 5])
point conv weight shape : torch.Size([24, 96, 1, 1])
max expand size : 4
max kernel size : 5
expand ratio check : True
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
inverted blocks flops : tensor([ 921600., 2764800., 3686400.])
dw convert flops :,  tensor([   0., 7776.])
dw blocks flops : tensor([3840000., 1382400.])
pw blocks flops : tensor([ 921600., 2764800., 3686400.])
gumbel idx :  tensor([[0., 0., 1., 1., 0.]])
==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*
expand_ratio_list: tensor([1, 3, 4]), kernel_

In [5]:
gumbel_model.set_static_flops(torch.randn(1,3,160,160).cuda())

Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::hardtanh_ encountered 1 time(s)
Unsupported operator aten::hardtanh_ encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mobile_inverted_conv.depth_conv.bn, mobile_inverted_conv.point_linear.bn
Unsupported operator aten::hardtanh_ encountered 2 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mobile_inverted_conv.depth_conv.bn, mobile_inverted_

Unsupported operator aten::hardtanh_ encountered 2 time(s)
Unsupported operator aten::add encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mobile_inverted_conv.depth_conv.bn, mobile_inverted_conv.inverted_bottleneck.bn, mobile_inverted_conv.kernel_transform_linear_list.0, mobile_inverted_conv.point_linear.bn, shortcut
Unsupported operator aten::hardtanh_ encountered 2 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mobi

Success Log Static & Dynamic Flops : 22964224, 104383200


In [6]:
g =  torch.arange(20).reshape(4,5)

g[:, 0:3]

tensor([[ 0,  1,  2],
        [ 5,  6,  7],
        [10, 11, 12],
        [15, 16, 17]])

In [7]:
total_flops = 0
for k, v in flops._analyze().counts.items():
    if k.endswith('mobile_inverted_conv'):
        if has_deep_attr(gumbel_model, k):
            m = get_deep_attr(gumbel_model, k)
            if isinstance(m, MBGumbelInvertedConvLayer):
                full_flops = m.compute_flops()
                print(f"fvcore vs custom flops : {v['conv']}, {full_flops}", v['conv']- full_flops)
                m.initialize_flops = False
                kernel_size_list, expand_ratio_list = m.kernel_size_list, m.expand_ratio_list
                total_flops += int(full_flops)
                if len(kernel_size_list) > 1 and len(expand_ratio_list) > 1:
                    print("max gumbel test")
                    for kk, vv in flops._analyze().counts.items():
                        if k in kk and '.conv' in kk:
                            print(kk, vv)
                    print(f"FLOPS\ninverted conv : {m.inverted_flops} \ndw conv : {m.dw_flops} \npw conv : {m.pw_flops}")
                    print("expand ratio list : ", expand_ratio_list)
                    
                    print("gumbel_max flop check")
                    
                    mm = m
                    gumbel = torch.zeros(len(expand_ratio_list)+len(kernel_size_list))
                    gumbel[len(expand_ratio_list)-1]=1
                    gumbel[len(expand_ratio_list)]=1
                    gumbel = gumbel.unsqueeze(0).to(m.kernel_size_list.device)
                    gumbel_flops = m.compute_gumbel_flops(gumbel)
                    
                    print("fvcore vs gumbel_max flops : ", int(v['conv']- gumbel_flops))
                    
                    for i in range(len(expand_ratio_list)):    
                        gumbel = torch.zeros(len(expand_ratio_list)+len(kernel_size_list))
                        gumbel[len(expand_ratio_list)]=1
                        gumbel[i]=1
                        gumbel = gumbel.unsqueeze(0).to(m.kernel_size_list.device)
                        gumbel_flops = m.compute_gumbel_flops(gumbel)
                        print(f"full flops vs custom flops : {full_flops}, {int(gumbel_flops)}, {int(full_flops - gumbel_flops)}")

                print("=="*20)

fvcore vs custom flops : 10944000, 10944000.0 0.0
fvcore vs custom flops : 11212800, 11212800.0 0.0
max gumbel test
blocks.3.mobile_inverted_conv.inverted_bottleneck.conv Counter({'conv': 3686400})
blocks.3.mobile_inverted_conv.depth_conv.conv Counter({'conv': 3840000})
blocks.3.mobile_inverted_conv.point_linear.conv Counter({'conv': 3686400})
FLOPS
inverted conv : 3686400.0 
dw conv : 3840000.0 
pw conv : 3686400.0
expand ratio list :  tensor([1, 3, 4], device='cuda:0')
gumbel_max flop check
initialize FLOPs table
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
expand ratio list > 1 (3), kernel size list > 1 (2)
module config
input h/w : 40 / 40
stride : 1
inverted conv input shape : torch.Size([96, 24, 1, 1])
depth conv weight shape : torch.Size([96, 1, 5, 5])
point conv weight shape : torch.Size([24, 96, 1, 1])
max expand size : 4
max kernel size : 5
expand ratio check : True
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
inverted blocks flops : tensor([ 921600., 2764800., 3686400.], device='cud

In [8]:
total_flops

104383200

In [9]:
in_c = 16
expand = 1
inputs = torch.randn(1, in_c, 32, 32)
conv1 = nn.Conv2d(in_c, in_c * expand, kernel_size=3, stride=1, padding=1, bias=False)
expand = 4
conv2 = nn.Conv2d(in_c, in_c * expand, kernel_size=3, stride=1, padding=1, bias=False)

flops1 = FlopCountAnalysis(conv1, inputs).total()
flops2 = FlopCountAnalysis(conv2, inputs).total()

print(flops1, flops2)

print(count_conv_gumbel_flops(conv1.weight.shape, inputs.size()[2], inputs.size()[3]))
print(count_conv_gumbel_flops(conv2.weight.shape, inputs.size()[2], inputs.size()[3]))

2359296 9437184
2359296.0
9437184.0


In [10]:
gumbel_model = GumbelMCUNet.build_from_config(ori_model.config, gubmel_config)
#rm_bn_from_net(gumbel_model)
gumbel_model_flops = FlopCountAnalysis(gumbel_model, torch.randn(1, 3, 160, 160))
gumbel_model_flops.by_module_and_operator()

Unsupported operator aten::add_ encountered 110 time(s)
Unsupported operator aten::hardtanh_ encountered 76 time(s)
Unsupported operator aten::empty_like encountered 26 time(s)
Unsupported operator aten::exponential_ encountered 26 time(s)
Unsupported operator aten::log encountered 26 time(s)
Unsupported operator aten::neg encountered 26 time(s)
Unsupported operator aten::add encountered 96 time(s)
Unsupported operator aten::div encountered 26 time(s)
Unsupported operator aten::softmax encountered 26 time(s)
Unsupported operator aten::scatter_ encountered 26 time(s)
Unsupported operator aten::sub encountered 68 time(s)
Unsupported operator aten::mul_ encountered 68 time(s)
Unsupported operator aten::mul encountered 125 time(s)
Unsupported operator aten::pad encountered 25 time(s)
Unsupported operator aten::mean encountered 2 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forwar

{'': Counter({'conv': 172696800,
          'batch_norm': 22180000,
          'adaptive_avg_pool2d': 38400,
          'linear': 2023520}),
 'first_conv': Counter({'conv': 5529600, 'batch_norm': 1024000}),
 'first_conv.conv': Counter({'conv': 5529600}),
 'first_conv.bn': Counter({'batch_norm': 1024000}),
 'first_conv.act': Counter(),
 'blocks': Counter({'conv': 167167200,
          'batch_norm': 21156000,
          'linear': 1292896}),
 'blocks.0': Counter({'conv': 5120000, 'batch_norm': 1536000}),
 'blocks.0.mobile_inverted_conv': Counter({'conv': 5120000,
          'batch_norm': 1536000}),
 'blocks.0.mobile_inverted_conv.depth_conv': Counter({'conv': 1843200,
          'batch_norm': 1024000}),
 'blocks.0.mobile_inverted_conv.depth_conv.conv': Counter({'conv': 1843200}),
 'blocks.0.mobile_inverted_conv.depth_conv.bn': Counter({'batch_norm': 1024000}),
 'blocks.0.mobile_inverted_conv.depth_conv.act': Counter(),
 'blocks.0.mobile_inverted_conv.point_linear': Counter({'conv': 3276800,
    

In [11]:
ori_model.blocks[16].config

{'name': 'MobileInvertedResidualBlock',
 'mobile_inverted_conv': {'name': 'MBInvertedConvLayer',
  'in_channels': 192,
  'out_channels': 320,
  'kernel_size': 5,
  'stride': 1,
  'expand_ratio': 4,
  'mid_channels': 768,
  'act_func': 'relu6',
  'use_se': False},
 'shortcut': None}

In [12]:
gumbel_model.blocks[16]

MobileGumbelInvertedResidualBlock(
  (mobile_inverted_conv): MBGumbelInvertedConvLayer(
    (kernel_transform_linear_list): ModuleList(
      (0): Linear(in_features=9, out_features=9, bias=True)
    )
    (inverted_bottleneck): Sequential(
      (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU6(inplace=True)
    )
    (depth_conv): Sequential(
      (conv): Conv2d(768, 768, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=768, bias=False)
      (bn): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU6(inplace=True)
    )
    (point_linear): Sequential(
      (conv): Conv2d(768, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)

In [13]:
for n, m in ori_model.named_modules():
    if hasattr(gumbel_model, n):
        print(n)
        print("ori model parameter :", sum([p.numel() for p in m.parameters()]))
        print("gumbel model parameter:", sum([p.numel() for p in getattr(gumbel_model, n).parameters()]))

first_conv
ori model parameter : 864
gumbel model parameter: 928
blocks
ori model parameter : 1389368
gumbel model parameter: 1415602
classifier
ori model parameter : 321000
gumbel model parameter: 321000


In [14]:
next(m.parameters())

Parameter containing:
tensor([[ 0.1297, -0.0050,  0.0143,  ..., -0.0158, -0.1057,  0.0024],
        [ 0.1476, -0.1173, -0.0844,  ..., -0.0407, -0.0013,  0.1079],
        [ 0.0353, -0.0408,  0.0923,  ..., -0.0025,  0.0013, -0.0519],
        ...,
        [-0.0136,  0.0784, -0.0716,  ..., -0.1794, -0.0974,  0.0618],
        [ 0.1510,  0.0757,  0.1213,  ...,  0.0973, -0.0941, -0.1362],
        [ 0.0133,  0.0498,  0.0714,  ..., -0.2346,  0.0590,  0.0044]],
       requires_grad=True)

In [15]:
gumbel_model.set_static_flops(inputs.cuda())
gumbel_model.compute_flops(inputs, gumbel_list)
gumbel_model.compute_flops(inputs, gumbel_list)

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 16, 32, 32] to have 3 channels, but got 16 channels instead

In [None]:
gumbel_model.compute_flops(inputs, gumbel_list)

In [None]:
torch.tensor([1,2,3,4])* torch.tensor(block.mobile_inverted_conv.expand_ratio_list)

In [None]:
[1,2,3,4] * [2,6,8,12]

In [None]:
for i, block in enumerate(gumbel_model.blocks):
    if isinstance(block, MobileGumbelInvertedResidualBlock):
        print(block.shortcut, block.mobile_inverted_conv.depth_conv.conv.stride)
        break

In [None]:
gumbel_model.forward(inputs)

gumbel_model.eval()
print(gumbel_model.training)
out_origin = gumbel_model.forward(inputs)
for i in range(10):
    print(out_origin - gumbel_model.forward(inputs))


In [None]:
class test_gumbel(nn.Module):
    def __init__(self):
        super().__init__()
        self.test_layer = nn.Linear(20, 4)
    
    def forward(self, x):
        gumbel_input = self.test_layer(x)
        if self.training:
            gumbel_out = F.gumbel_softmax(gumbel_input, tau=1, hard=True, eps=1e-10, dim=-1)
        else:
            index = gumbel_input.max(dim=-1, keepdim=True)[1]
            gumbel_out = torch.zeros_like(gumbel_input, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
        
        return gumbel_out
inputs = torch.randn(5, 20)
gumbel = test_gumbel()

# train
gumbel.train()
for i in range(5):
    out = gumbel(inputs)
    print(f"{i} iter -> ", out)

# test
gumbel.eval()
for i in range(5):
    out = gumbel(inputs)
    print(f"{i} iter -> ", out)


In [None]:
from torchprofile import profile_macs

total_mac = profile_macs(gumbel_model.cuda(), torch.randn(2, 3,160, 160).cuda())
print(total_mac)

In [None]:
from torchprofile.utils.flatten import Flatten
import warnings
with warnings.catch_warnings(record=True):
    graph, _ = torch.jit._get_trace_graph(Flatten(gumbel_model.cuda()), torch.randn(2,3,160,160).cuda(), None)

In [None]:
variables = dict()
for x in graph.nodes():
    for v in list(x.inputs()):
        if 'tensor' in v.type().kind().lower():
            print(v.debugName(), v.type().scalarType(), v.type().sizes())

In [None]:
for x in graph.nodes():
    if 'mul_' in x.kind().lower():
        print(x)

In [None]:
ori_model_size = sum([p.numel() for p in ori_model.parameters()]) * 4 / 2**20
gumbel_model_size = sum([p.numel() for p in gumbel_model.parameters()]) * 4 / 2**20
print("Ori model size : %.1f MB" % ori_model_size)
print("Gumbel model size : %.1f MB" % gumbel_model_size)

from torchinfo import summary

summary(gumbel_model, input_size=(4, 3, 160, 160), col_width=16, col_names=['kernel_size', 'output_size', 'num_params', 'mult_adds', 'params_percent'], depth=2)


In [None]:
print("before forward grad : ", gumbel_model.gumbel_fc1.weight.grad)
out = gumbel_model(torch.randn(32, 3, 160, 160))

out.sum().backward()

print("after forward grad : \n", gumbel_model.gumbel_fc1.weight.grad)

In [None]:
for n, p in net.named_parameters():
    if has_deep_attr(model, n):
        print(n)

In [None]:
model

In [None]:
mbconv_test = MBGumbelInvertedConvLayer.build_from_config(m.mobile_inverted_conv.config)
mbconv_test.config

In [None]:
inputs = torch.randn(2, 16, 32, 32)
gumbel_inputs = torch.randn(2, 4, 8, 8)
gumbel_inputs.requires_grad = True
gumbel_layer = nn.Linear(4*8*8, 5)
gumbel_output = gumbel_layer(gumbel_inputs.view(2, -1))
gumbel_index = F.gumbel_softmax(gumbel_output, tau=1, hard=True)
print(gumbel_index)
out = mbconv_test.forward(torch.randn(2, 16, 32, 32))

In [None]:
inputs = torch.randn(2, 16, 32, 32)
gumbel_inputs = torch.randn(2, 4, 8, 8)
gumbel_inputs.requires_grad = True
gumbel_layer = nn.Linear(4*8*8, 5)
gumbel_output = gumbel_layer(gumbel_inputs.view(2, -1))
gumbel_index = F.gumbel_softmax(gumbel_output, tau=1, hard=True)
print(gumbel_index)
out = mbconv_test.forward(torch.randn(2, 16, 32, 32), gumbel_index)
out.sum().backward()

In [None]:
gumbel_layer.weight.grad

In [None]:
original_mbconv_test_weight = copy.deepcopy(mbconv_test.depth_conv.conv.weight)
print(original_mbconv_test_weight)

In [None]:
print(m.mobile_inverted_conv.depth_conv.conv.weight)

In [None]:
for n, p in m.mobile_inverted_conv.named_parameters():
    if has_deep_attr(mbconv_test, n):
        print(n, p)
        set_deep_attr(mbconv_test, n, p)
        print('------------------')

In [None]:
for n, p in m.mobile_inverted_conv.named_parameters():
    if has_deep_attr(mbconv_test, n):
        print(n)
        print(get_deep_attr(mbconv_test, n) - p)

In [None]:
mbconv_test.forward(torch.randn(1,32,16,16), gumbel=1)

In [None]:
bn_layer = nn.BatchNorm2d(16)

In [None]:
x = torch.randn(1, 12, 32, 32)

In [None]:
feature_dim = 12
out = F.batch_norm(x, bn_layer.running_mean[:feature_dim], bn_layer.running_var[:feature_dim], bn_layer.weight[:feature_dim], bn_layer.bias[:feature_dim])

In [None]:
out.sum().backward()

In [None]:
bn_layer.weight.grad

In [None]:
model, img_size, desc = build_model(net_id='mcunet-in4', pretrained=True)

backup_model = copy.deepcopy(model)
model_copy = build_model(net_id='mcunet-in4', pretrained=False)[0]

for (n1, p1), (n2, p2) in zip(backup_model.named_parameters(), model_copy.named_parameters()):
    if n1 == n2:
        print((p1 - p2).sum())

In [None]:
for n, p in model.named_parameters():
    if has_deep_attr(model_copy, n):
        print(n)
        set_deep_attr(model_copy, n, p)

In [None]:
for (n1, p1), (n2, p2) in zip(backup_model.named_parameters(), model_copy.named_parameters()):
    if n1 == n2:
        print((p1-p2).sum())