In [1]:
import mcunet
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy

from mcunet.tinynas.nn.modules import MBInvertedConvLayer
from mcunet.tinynas.nn.networks import MobileInvertedResidualBlock
from mcunet.model_zoo import build_model


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model, img_size, desc = build_model(net_id='mcunet-in4', pretrained=True)

In [74]:
count = 0
for n, m in model.named_modules():
    if isinstance(m, MobileInvertedResidualBlock):
        print(n)
        print(m)
        count += 2
        if count > 3: 
            break

blocks.0
MobileInvertedResidualBlock(
  (mobile_inverted_conv): MBInvertedConvLayer(
    (depth_conv): Sequential(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU6(inplace=True)
    )
    (point_linear): Sequential(
      (conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)
blocks.1
MobileInvertedResidualBlock(
  (mobile_inverted_conv): MBInvertedConvLayer(
    (inverted_bottleneck): Sequential(
      (conv): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU6(inplace=True)
    )
    (depth_conv): Sequential(
      (conv): Conv2d(48, 48, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)

In [75]:
m.config

{'name': 'MobileInvertedResidualBlock',
 'mobile_inverted_conv': {'name': 'MBInvertedConvLayer',
  'in_channels': 16,
  'out_channels': 24,
  'kernel_size': 7,
  'stride': 2,
  'expand_ratio': 3,
  'mid_channels': 48,
  'act_func': 'relu6',
  'use_se': False},
 'shortcut': None}

In [133]:
from mcunet.utils import MyModule, SEModule, build_activation, get_same_padding, sub_filter_start_end
from mcunet.tinynas.nn.modules import ZeroLayer, set_layer_from_config

class MobileGumbelInvertedResidualBlock(MyModule):

    def __init__(self, mobile_inverted_conv, shortcut):
        super(MobileGumbelInvertedResidualBlock, self).__init__()

        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut

    def forward(self, x, gumbel_idx=None):
        if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
            res = x
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer) and gumbel_idx == None:
            res = self.mobile_inverted_conv(x)
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer) and gumbel_idx != None:
            res = self.mobile_inverted_conv(x, gumbel_idx)
        elif self.shortcut is not None and gumbel_idx == None:
            res = self.mobile_inverted_conv(x) + self.shortcut(x)
        else:
            res = self.mobile_inverted_conv(x, gumbel_idx) + self.shortcut(x)
        return res

    @property
    def module_str(self):
        return '(%s, %s)' % (
            self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None
        )

    @property
    def config(self):
        return {
            'name': MobileGumbelInvertedResidualBlock.__name__,
            'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
            'shortcut': self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
        shortcut = set_layer_from_config(config['shortcut'])
        return MobileGumbelInvertedResidualBlock(mobile_inverted_conv, shortcut)

    @staticmethod
    def build_from_module(module):
        if isinstance(module, MobileGumbelInvertedResidualBlock):
            print("build from gumbel module")
            return module
        elif isinstance(module, MobileInvertedResidualBlock):
            print("build from normal MobileInvertedResidualBlock module")
            mobile_inverted_conv = module.mobile_inverted_conv
            shortcut = module.shortcut
            return MobileGumbelInvertedResidualBlock(module.mobile_inverted_conv, module.shortcut)

In [134]:
mm  = m.mobile_inverted_conv

In [135]:
mm.config

{'name': 'MBInvertedConvLayer',
 'in_channels': 16,
 'out_channels': 24,
 'kernel_size': 7,
 'stride': 2,
 'expand_ratio': 3,
 'mid_channels': 48,
 'act_func': 'relu6',
 'use_se': False}

In [168]:
from collections import OrderedDict


class MBGumbelInvertedConvLayer(MyModule):
    global_kernel_size_list = [3,5,7]
    global_expand_ratio_list = [3,4,5,6]
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, expand_ratio=6, mid_channels=None, act_func='relu6', use_se=False, **kwargs):
        super(MBGumbelInvertedConvLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.max_kernel_size = kernel_size
        self.kernel_size_list = []
        self.stride = stride
        self.max_expand_ratio = expand_ratio
        self.expand_ratio_list = []
        self.mid_channels = mid_channels
        self.act_func = act_func
        self.use_se = use_se
        
        for kernel in sorted(self.global_kernel_size_list):
            if kernel == self.max_kernel_size:
                self.kernel_size_list.append(kernel)
                break
            self.kernel_size_list.append(kernel)
        
        self.kernel_size_list.reverse() # sorted in descending order
        
        if self.max_expand_ratio == 1:
            self.expand_ratio_list = []
        else:            
            for expand in sorted(self.global_expand_ratio_list):
                if expand == self.max_expand_ratio:
                    self.expand_ratio_list.append(expand)
                    break
                self.expand_ratio_list.append(expand)

        if self.mid_channels is None:
            feature_dim = round(self.in_channels * self.max_expand_ratio)
        else:
            feature_dim = self.mid_channels

        if self.max_expand_ratio == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim)),
                ('act', build_activation(self.act_func, inplace=True)),
            ]))

        pad = get_same_padding(self.max_kernel_size)
        depth_conv_modules = [
            ('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=feature_dim, bias=False)),
            ('bn', nn.BatchNorm2d(feature_dim)),
            ('act', build_activation(self.act_func, inplace=True))
        ]
        if self.use_se:
            depth_conv_modules.append(('se', SEModule(feature_dim)))
        self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
            ('bn', nn.BatchNorm2d(out_channels)),
        ]))

        self.kernel_transform_linear_list = nn.ModuleList()
        
        for i, kernel in enumerate(self.kernel_size_list[1:]):
            kernel_linear = nn.Linear(kernel*kernel, kernel*kernel)
            self.kernel_transform_linear_list.append(kernel_linear)

    def forward(self, x, gumbel=None):
        """
        gumbel: [batch_size, len(self.expand_ratio_list) + len(self.kernel_size_list)]
        """
        print("test1")
        if gumbel==None:
            if self.inverted_bottleneck:
                x = self.inverted_bottleneck(x)
            x = self.depth_conv(x)
            x = self.point_linear(x)
            return x
        else:
            print("test2")
            print(len(gumbel[0]), len(self.expand_ratio_list), len(self.kernel_size_list))
            if len(gumbel[0]) == len(self.expand_ratio_list) + len(self.kernel_size_list):
                print("test output")
                if self.inverted_bottleneck:
                    # 1. inverted bottleneck weights (max_expand_ratio)
                    expand_weight = self.inverted_bottleneck.conv.weight
                    expand_max_out = F.conv2d(x, expand_weight, stride=1, padding=0)
                    expand_max_out = self.inverted_bottleneck.bn(expand_max_out)
                    expand_max_out = self.inverted_bottleneck.act(expand_max_out)
                    expand_max_out *= gumbel[:, len(self.expand_ratio_list)].unsqueeze(1).unsqueeze(2).unsqueeze(3)
                    for i, expand_ratio in enumerate(self.expand_ratio_list[:-1]):
                        out = F.conv2d(input, weight)(x, expand_weight[:expand_ratio*self.in_channels, :, :, :], stride=1, padding=0)
                        out = self.inverted_bottleneck.bn(out)
                        out = self.inverted_bottleneck.act(out)
                        out *= gumbel[:, i].unsqueeze(1).unsqueeze(2).unsqueeze(3)
                        out = F.pad(out, [0, 0, 0, 0, 0, expand_max_out.size(1) - out.size(1)], mode='constant', value=0) # zero pad
                        expand_max_out += out
                    x = expand_max_out
                # 2. depthwise convolution weights (max_kernel_size)
                depth_weight = self.depth_conv.conv.weight
                pad = get_same_padding(self.max_kernel_size)
                print(3)
                kernel_max_out = F.conv2d(x, depth_weight, stride=self.stride, padding=pad, groups=x.size(1))
                kernel_max_out = self.depth_conv.bn(kernel_max_out)
                kernel_max_out = self.depth_conv.act(kernel_max_out)
                kernel_max_out *= gumbel[:, len(self.expand_ratio_list)].unsqueeze(1).unsqueeze(2).unsqueeze(3)
                for i, active_kernel_size in enumerate(self.kernel_size_list[1:]):
                    start, end = sub_filter_start_end(self.kernel_size_list[i], active_kernel_size)
                    print(start, end, active_kernel_size, self.kernel_size_list[i], depth_weight.shape)
                    kernel_weight = depth_weight[:, :, start:end, start:end].contiguous()
                    kernel_weight = kernel_weight.view(kernel_weight.size(0), kernel_weight.size(1), -1)
                    kernel_weight = self.kernel_transform_linear_list[i](kernel_weight)
                    kernel_weight = kernel_weight.view(kernel_weight.size(0), kernel_weight.size(1), active_kernel_size, active_kernel_size)
                    pad = get_same_padding(active_kernel_size)
                    kernel_out = F.conv2d(x, kernel_weight, stride=self.stride, padding=pad, groups=x.size(1))
                    kernel_out = self.depth_conv.bn(kernel_out)
                    kernel_out = self.depth_conv.act(kernel_out)
                    kernel_out *= gumbel[:, len(self.expand_ratio_list) + i + 1].unsqueeze(1).unsqueeze(2).unsqueeze(3)
                    kernel_max_out += kernel_out
                x = kernel_max_out
                if self.use_se:
                    x = self.depth_conv.se(x)
                # 3. pointwise convolution weights (out_channels)
                x = self.point_linear(x)
                return x
            else:
                assert False, "gumbel size is not match with expand_ratio_list and kernel_size_list"
    @property
    def module_str(self):
        if self.mid_channels is None:
            expand_ratio = self.max_expand_ratio
        else:
            expand_ratio = self.mid_channels // self.in_channels
        layer_str = '%dx%d_GumbelMBConv%d_%s' % (self.max_kernel_size, self.max_kernel_size, expand_ratio, self.act_func.upper())
        if self.use_se:
            layer_str = 'SE_' + layer_str
        layer_str += '_O%d' % self.out_channels
        return layer_str

    @property
    def config(self):
        return {
            'name': MBGumbelInvertedConvLayer.__name__,
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'kernel_size': self.max_kernel_size,
            'kernel_size_list': self.kernel_size_list,
            'stride': self.stride,
            'expand_ratio': self.max_expand_ratio,
            'expand_ratio_list': self.expand_ratio_list,
            'mid_channels': self.mid_channels,
            'act_func': self.act_func,
            'use_se': self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return MBGumbelInvertedConvLayer(**config)
    
    #@staticmethod
    #def build_from_module(module: MBInvertedConvLayer):
    #    mbgumbel = MBGumbelInvertedConvLayer.build_from_config(module.config)
    #    for n, m in module.named_parameters():
            


In [169]:
mbconv_test = MBGumbelInvertedConvLayer.build_from_config(m.mobile_inverted_conv.config)
mbconv_test.config

{'name': 'MBGumbelInvertedConvLayer',
 'in_channels': 16,
 'out_channels': 24,
 'kernel_size': 7,
 'kernel_size_list': [7, 5, 3],
 'stride': 2,
 'expand_ratio': 3,
 'expand_ratio_list': [3],
 'mid_channels': 48,
 'act_func': 'relu6',
 'use_se': False}

In [170]:
gumbel[:, 0].unsqueeze(1).unsqueeze(2).unsqueeze(3)

tensor([[[[1]]],


        [[[1]]]])

In [171]:
gumbel = torch.tensor([[1,1,0,0], [1,0,1,0]])
print(gumbel.shape)
mbconv_test.forward(torch.randn(2, 16, 32, 32), gumbel)

torch.Size([2, 4])
test1
test2
4 1 3
test output
3
1 6 5 7 torch.Size([48, 1, 7, 7])
1 4 3 5 torch.Size([48, 1, 7, 7])


tensor([[[[ 0.8567,  0.4870, -1.4085,  ...,  0.7508,  0.0068,  0.6479],
          [-0.6933, -1.9861, -1.0149,  ..., -0.2331, -0.9099, -0.3847],
          [-1.2974, -2.8775,  2.7149,  ..., -0.9148, -0.3412, -0.8784],
          ...,
          [ 0.2214, -0.9403, -0.2270,  ...,  0.0754, -0.6988, -0.7416],
          [ 0.8793, -0.6963,  0.1890,  ..., -1.1185, -1.9249, -1.0328],
          [-1.2504, -2.4076, -0.9345,  ...,  0.2322, -1.7815,  0.0396]],

         [[-0.1566,  0.1252,  1.1607,  ...,  1.8710,  0.9073, -3.0951],
          [-0.6777, -1.0121,  0.3968,  ...,  0.1152,  0.5914, -0.8969],
          [-0.5909, -1.8111,  0.1072,  ...,  0.7545, -0.3387,  1.2284],
          ...,
          [-0.4356, -1.3845,  2.0243,  ...,  0.3273, -0.3519, -2.6793],
          [ 0.3031, -1.0035,  1.7024,  ...,  0.5229, -0.5049, -0.0960],
          [-0.5559, -0.4275, -1.9258,  ..., -0.7857, -0.3518, -0.9294]],

         [[ 0.9074,  0.6539,  1.3948,  ..., -1.1132,  0.0380,  0.4552],
          [ 0.0634, -0.6252, -

In [84]:
g = torch.tensor([[1,0,0]])
len(g[0])

3

In [None]:
mbconv_test(torch.randn(1, 32, 32, 32), gumbel=[1, 0, 0, 1, 0, 0, 0])

In [29]:
original_mbconv_test_weight = copy.deepcopy(mbconv_test.depth_conv.conv.weight)
print(original_mbconv_test_weight)

Parameter containing:
tensor([[[[-0.1094, -0.2991,  0.1165],
          [-0.2913, -0.1030, -0.1255],
          [-0.0076, -0.2175,  0.2985]]],


        [[[ 0.1619,  0.2223, -0.1550],
          [ 0.3198,  0.1555,  0.2526],
          [-0.2230, -0.0804, -0.0267]]],


        [[[ 0.0777,  0.3209, -0.2567],
          [ 0.1512, -0.1042,  0.2795],
          [ 0.0665,  0.0769,  0.2850]]],


        [[[-0.3126, -0.2617, -0.1305],
          [-0.2550, -0.1393,  0.2465],
          [-0.0648,  0.0539, -0.1078]]],


        [[[-0.0091,  0.0819,  0.1442],
          [ 0.0080, -0.3168,  0.2750],
          [ 0.0956, -0.0929,  0.0606]]],


        [[[ 0.0563, -0.2193,  0.0763],
          [-0.2161, -0.0026,  0.1221],
          [ 0.1091,  0.1047, -0.0895]]],


        [[[-0.0575, -0.3265,  0.0334],
          [ 0.0240, -0.1833,  0.3118],
          [ 0.1552,  0.1193,  0.0709]]],


        [[[-0.0372, -0.1944,  0.0643],
          [-0.2934,  0.0974, -0.2598],
          [ 0.0029,  0.2455,  0.2021]]],


        [[

In [30]:
def get_deep_attr(obj, attrs):
    for attr in attrs.split("."):
        obj = getattr(obj, attr)
    return obj

def has_deep_attr(obj, attrs):
    try:
        get_deep_attr(obj, attrs)
        return True
    except AttributeError:
        return False

def set_deep_attr(obj, attrs, value):
    for attr in attrs.split(".")[:-1]:
        print(attr)
        obj = getattr(obj, attr)
    print(attrs.split(".")[-1])
    setattr(obj, attrs.split(".")[-1], value)
    
print(has_deep_attr(mbconv_test, "depth_conv.conv.weight"))
print(get_deep_attr(mbconv_test, "depth_conv.conv.weight"))


True
Parameter containing:
tensor([[[[-0.1094, -0.2991,  0.1165],
          [-0.2913, -0.1030, -0.1255],
          [-0.0076, -0.2175,  0.2985]]],


        [[[ 0.1619,  0.2223, -0.1550],
          [ 0.3198,  0.1555,  0.2526],
          [-0.2230, -0.0804, -0.0267]]],


        [[[ 0.0777,  0.3209, -0.2567],
          [ 0.1512, -0.1042,  0.2795],
          [ 0.0665,  0.0769,  0.2850]]],


        [[[-0.3126, -0.2617, -0.1305],
          [-0.2550, -0.1393,  0.2465],
          [-0.0648,  0.0539, -0.1078]]],


        [[[-0.0091,  0.0819,  0.1442],
          [ 0.0080, -0.3168,  0.2750],
          [ 0.0956, -0.0929,  0.0606]]],


        [[[ 0.0563, -0.2193,  0.0763],
          [-0.2161, -0.0026,  0.1221],
          [ 0.1091,  0.1047, -0.0895]]],


        [[[-0.0575, -0.3265,  0.0334],
          [ 0.0240, -0.1833,  0.3118],
          [ 0.1552,  0.1193,  0.0709]]],


        [[[-0.0372, -0.1944,  0.0643],
          [-0.2934,  0.0974, -0.2598],
          [ 0.0029,  0.2455,  0.2021]]],


     

In [31]:
print(m.mobile_inverted_conv.depth_conv.conv.weight)

Parameter containing:
tensor([[[[-1.4096e+00,  7.9066e-02,  1.3074e+00],
          [-3.6257e+00, -2.9244e-01,  4.0522e+00],
          [-3.8254e-02, -2.7275e-02,  2.4427e-02]]],


        [[[-2.6560e-06, -2.0290e-06, -3.1524e-06],
          [ 1.3314e-06,  1.7855e-05, -5.3570e-06],
          [-3.0966e-06,  1.5279e-06, -2.9197e-06]]],


        [[[-5.7868e-02,  4.9424e-02,  4.9366e-02],
          [ 1.3904e+00, -1.5971e+00, -9.1940e-02],
          [-5.8704e-02,  4.8760e-03,  5.4387e-02]]],


        [[[-6.5161e-01, -5.4830e-01,  8.4651e-01],
          [-3.8956e-01, -8.7885e-01, -2.9362e-02],
          [ 9.1019e-01, -1.1248e-01,  7.0372e-01]]],


        [[[-4.6209e-01, -3.3035e-01,  4.4910e-01],
          [-4.8235e-01, -3.9753e-01,  7.0083e-02],
          [ 4.2282e-01,  1.5623e-01,  1.8547e-01]]],


        [[[-5.8936e-02,  1.4643e-01, -5.6524e-02],
          [-4.7611e-02,  1.6341e+00, -7.4953e-02],
          [-3.5694e-03, -4.7760e-01,  1.6030e-02]]],


        [[[-5.1927e-01, -4.3182e-01,

In [32]:
for n, p in m.mobile_inverted_conv.named_parameters():
    if has_deep_attr(mbconv_test, n):
        print(n, p)
        set_deep_attr(mbconv_test, n, p)
        print('------------------')

depth_conv.conv.weight Parameter containing:
tensor([[[[-1.4096e+00,  7.9066e-02,  1.3074e+00],
          [-3.6257e+00, -2.9244e-01,  4.0522e+00],
          [-3.8254e-02, -2.7275e-02,  2.4427e-02]]],


        [[[-2.6560e-06, -2.0290e-06, -3.1524e-06],
          [ 1.3314e-06,  1.7855e-05, -5.3570e-06],
          [-3.0966e-06,  1.5279e-06, -2.9197e-06]]],


        [[[-5.7868e-02,  4.9424e-02,  4.9366e-02],
          [ 1.3904e+00, -1.5971e+00, -9.1940e-02],
          [-5.8704e-02,  4.8760e-03,  5.4387e-02]]],


        [[[-6.5161e-01, -5.4830e-01,  8.4651e-01],
          [-3.8956e-01, -8.7885e-01, -2.9362e-02],
          [ 9.1019e-01, -1.1248e-01,  7.0372e-01]]],


        [[[-4.6209e-01, -3.3035e-01,  4.4910e-01],
          [-4.8235e-01, -3.9753e-01,  7.0083e-02],
          [ 4.2282e-01,  1.5623e-01,  1.8547e-01]]],


        [[[-5.8936e-02,  1.4643e-01, -5.6524e-02],
          [-4.7611e-02,  1.6341e+00, -7.4953e-02],
          [-3.5694e-03, -4.7760e-01,  1.6030e-02]]],


        [[[-5

In [33]:
for n, p in m.mobile_inverted_conv.named_parameters():
    if has_deep_attr(mbconv_test, n):
        print(n)
        print(get_deep_attr(mbconv_test, n) - p)

depth_conv.conv.weight
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0.

In [34]:
mbconv_test.forward(torch.randn(1,32,16,16), gumbel=1)

use gumbel
