In [10]:
import mcunet
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy

from mcunet.tinynas.nn.modules import MBInvertedConvLayer
from mcunet.tinynas.nn.networks import MobileInvertedResidualBlock
from mcunet.model_zoo import build_model


In [11]:
model, img_size, desc = build_model(net_id='mcunet-in4', pretrained=True)

In [48]:
count = 0
for n, m in model.named_modules():
    if isinstance(m, MobileInvertedResidualBlock):
        print(n)
        print(m)
        count += 4
        if count > 3: 
            break

blocks.0
MobileInvertedResidualBlock(
  (mobile_inverted_conv): MBInvertedConvLayer(
    (depth_conv): Sequential(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU6(inplace=True)
    )
    (point_linear): Sequential(
      (conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)


In [49]:
m.config

{'name': 'MobileInvertedResidualBlock',
 'mobile_inverted_conv': {'name': 'MBInvertedConvLayer',
  'in_channels': 32,
  'out_channels': 16,
  'kernel_size': 3,
  'stride': 1,
  'expand_ratio': 1,
  'mid_channels': None,
  'act_func': 'relu6',
  'use_se': False},
 'shortcut': None}

In [50]:
from mcunet.utils import MyModule, SEModule, build_activation, get_same_padding
from mcunet.tinynas.nn.modules import ZeroLayer, set_layer_from_config

class MobileGumbelInvertedResidualBlock(MyModule):

    def __init__(self, mobile_inverted_conv, shortcut):
        super(MobileGumbelInvertedResidualBlock, self).__init__()

        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut

    def forward(self, x, gumbel_idx=None):
        if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
            res = x
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer) and gumbel_idx == None:
            res = self.mobile_inverted_conv(x)
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer) and gumbel_idx != None:
            res = self.mobile_inverted_conv(x, gumbel_idx)
        elif self.shortcut is not None and gumbel_idx == None:
            res = self.mobile_inverted_conv(x) + self.shortcut(x)
        else:
            res = self.mobile_inverted_conv(x, gumbel_idx) + self.shortcut(x)
        return res

    @property
    def module_str(self):
        return '(%s, %s)' % (
            self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None
        )

    @property
    def config(self):
        return {
            'name': MobileGumbelInvertedResidualBlock.__name__,
            'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
            'shortcut': self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
        shortcut = set_layer_from_config(config['shortcut'])
        return MobileGumbelInvertedResidualBlock(mobile_inverted_conv, shortcut)

    @staticmethod
    def build_from_module(module):
        if isinstance(module, MobileGumbelInvertedResidualBlock):
            print("build from gumbel module")
            return module
        elif isinstance(module, MobileInvertedResidualBlock):
            print("build from normal MobileInvertedResidualBlock module")
            mobile_inverted_conv = module.mobile_inverted_conv
            shortcut = module.shortcut
            return MobileGumbelInvertedResidualBlock(module.mobile_inverted_conv, module.shortcut)

In [51]:
mm  = m.mobile_inverted_conv

In [52]:
mm.config

{'name': 'MBInvertedConvLayer',
 'in_channels': 32,
 'out_channels': 16,
 'kernel_size': 3,
 'stride': 1,
 'expand_ratio': 1,
 'mid_channels': None,
 'act_func': 'relu6',
 'use_se': False}

In [53]:
from collections import OrderedDict


class MBGumbelInvertedConvLayer(MyModule):
    global_kernel_size_list = [3,5,7]
    global_expand_ratio_list = [1,3,4,5,6]
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, expand_ratio=6, mid_channels=None, act_func='relu6', use_se=False, **kwargs):
        super(MBGumbelInvertedConvLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.max_kernel_size = kernel_size
        self.kernel_size_list = []
        self.stride = stride
        self.max_expand_ratio = expand_ratio
        self.expand_ratio_list = []
        self.mid_channels = mid_channels
        self.act_func = act_func
        self.use_se = use_se
        
        for kernel in sorted(self.global_kernel_size_list):
            if kernel == self.max_kernel_size:
                self.kernel_size_list.append(kernel)
                break
            self.kernel_size_list.append(kernel)
        
        for expand in sorted(self.global_expand_ratio_list):
            if expand == self.max_expand_ratio:
                self.expand_ratio_list.append(expand)
                break
            self.expand_ratio_list.append(expand)

        if self.mid_channels is None:
            feature_dim = round(self.in_channels * self.max_expand_ratio)
        else:
            feature_dim = self.mid_channels

        if self.max_expand_ratio == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim)),
                ('act', build_activation(self.act_func, inplace=True)),
            ]))

        pad = get_same_padding(self.max_kernel_size)
        depth_conv_modules = [
            ('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=feature_dim, bias=False)),
            ('bn', nn.BatchNorm2d(feature_dim)),
            ('act', build_activation(self.act_func, inplace=True))
        ]
        if self.use_se:
            depth_conv_modules.append(('se', SEModule(feature_dim)))
        self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
            ('bn', nn.BatchNorm2d(out_channels)),
        ]))

    def forward(self, x):
        if self.inverted_bottleneck:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.mid_channels is None:
            expand_ratio = self.max_expand_ratio
        else:
            expand_ratio = self.mid_channels // self.in_channels
        layer_str = '%dx%d_GumbelMBConv%d_%s' % (self.max_kernel_size, self.max_kernel_size, expand_ratio, self.act_func.upper())
        if self.use_se:
            layer_str = 'SE_' + layer_str
        layer_str += '_O%d' % self.out_channels
        return layer_str

    @property
    def config(self):
        return {
            'name': MBGumbelInvertedConvLayer.__name__,
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'kernel_size': self.max_kernel_size,
            'kernel_size_list': self.kernel_size_list,
            'stride': self.stride,
            'expand_ratio': self.max_expand_ratio,
            'expand_ratio_list': self.expand_ratio_list,
            'mid_channels': self.mid_channels,
            'act_func': self.act_func,
            'use_se': self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return MBGumbelInvertedConvLayer(**config)
    
    #@staticmethod
    #def build_from_module(module: MBInvertedConvLayer):
    #    mbgumbel = MBGumbelInvertedConvLayer.build_from_config(module.config)
    #    for n, m in module.named_parameters():
            


In [54]:
mbconv_test = MBGumbelInvertedConvLayer.build_from_config(m.mobile_inverted_conv.config)

In [55]:
mbconv_test.config

{'name': 'MBGumbelInvertedConvLayer',
 'in_channels': 32,
 'out_channels': 16,
 'kernel_size': 3,
 'kernel_size_list': [3],
 'stride': 1,
 'expand_ratio': 1,
 'expand_ratio_list': [1],
 'mid_channels': None,
 'act_func': 'relu6',
 'use_se': False}

In [56]:
original_mbconv_test_weight = copy.deepcopy(mbconv_test.depth_conv.conv.weight)
print(original_mbconv_test_weight)

Parameter containing:
tensor([[[[-0.2992, -0.1177,  0.3093],
          [-0.2990,  0.2471, -0.3293],
          [ 0.2489, -0.2503, -0.0731]]],


        [[[ 0.0895,  0.0370,  0.2446],
          [-0.0377, -0.1051, -0.0799],
          [-0.0628, -0.2726, -0.2245]]],


        [[[-0.1293,  0.3082, -0.1127],
          [-0.0178, -0.3019,  0.0067],
          [-0.0973, -0.0609,  0.0280]]],


        [[[ 0.2632, -0.0664,  0.0830],
          [ 0.1477, -0.2770,  0.0151],
          [-0.2820, -0.0734, -0.0909]]],


        [[[ 0.0200, -0.0904,  0.2705],
          [-0.1194,  0.1854, -0.2181],
          [-0.2449, -0.1923, -0.1018]]],


        [[[ 0.3010,  0.0608,  0.0703],
          [-0.2762,  0.2270, -0.0403],
          [ 0.2385,  0.2463,  0.0149]]],


        [[[ 0.2830,  0.2023,  0.0777],
          [ 0.1273, -0.0853,  0.0067],
          [-0.2430,  0.2285, -0.2602]]],


        [[[ 0.2286, -0.1018,  0.0032],
          [-0.3095,  0.2056,  0.2769],
          [-0.1081, -0.1240,  0.3281]]],


        [[

In [57]:
def get_deep_attr(obj, attrs):
    for attr in attrs.split("."):
        obj = getattr(obj, attr)
    return obj

def has_deep_attr(obj, attrs):
    try:
        get_deep_attr(obj, attrs)
        return True
    except AttributeError:
        return False

def set_deep_attr(obj, attrs, value):
    for attr in attrs.split(".")[:-1]:
        print(attr)
        obj = getattr(obj, attr)
    print(attrs.split(".")[-1])
    setattr(obj, attrs.split(".")[-1], value)
    
print(has_deep_attr(mbconv_test, "depth_conv.conv.weight"))
print(get_deep_attr(mbconv_test, "depth_conv.conv.weight"))


True
Parameter containing:
tensor([[[[-0.2992, -0.1177,  0.3093],
          [-0.2990,  0.2471, -0.3293],
          [ 0.2489, -0.2503, -0.0731]]],


        [[[ 0.0895,  0.0370,  0.2446],
          [-0.0377, -0.1051, -0.0799],
          [-0.0628, -0.2726, -0.2245]]],


        [[[-0.1293,  0.3082, -0.1127],
          [-0.0178, -0.3019,  0.0067],
          [-0.0973, -0.0609,  0.0280]]],


        [[[ 0.2632, -0.0664,  0.0830],
          [ 0.1477, -0.2770,  0.0151],
          [-0.2820, -0.0734, -0.0909]]],


        [[[ 0.0200, -0.0904,  0.2705],
          [-0.1194,  0.1854, -0.2181],
          [-0.2449, -0.1923, -0.1018]]],


        [[[ 0.3010,  0.0608,  0.0703],
          [-0.2762,  0.2270, -0.0403],
          [ 0.2385,  0.2463,  0.0149]]],


        [[[ 0.2830,  0.2023,  0.0777],
          [ 0.1273, -0.0853,  0.0067],
          [-0.2430,  0.2285, -0.2602]]],


        [[[ 0.2286, -0.1018,  0.0032],
          [-0.3095,  0.2056,  0.2769],
          [-0.1081, -0.1240,  0.3281]]],


     

In [58]:
print(m.mobile_inverted_conv.depth_conv.conv.weight)

Parameter containing:
tensor([[[[-1.4096e+00,  7.9066e-02,  1.3074e+00],
          [-3.6257e+00, -2.9244e-01,  4.0522e+00],
          [-3.8254e-02, -2.7275e-02,  2.4427e-02]]],


        [[[-2.6560e-06, -2.0290e-06, -3.1524e-06],
          [ 1.3314e-06,  1.7855e-05, -5.3570e-06],
          [-3.0966e-06,  1.5279e-06, -2.9197e-06]]],


        [[[-5.7868e-02,  4.9424e-02,  4.9366e-02],
          [ 1.3904e+00, -1.5971e+00, -9.1940e-02],
          [-5.8704e-02,  4.8760e-03,  5.4387e-02]]],


        [[[-6.5161e-01, -5.4830e-01,  8.4651e-01],
          [-3.8956e-01, -8.7885e-01, -2.9362e-02],
          [ 9.1019e-01, -1.1248e-01,  7.0372e-01]]],


        [[[-4.6209e-01, -3.3035e-01,  4.4910e-01],
          [-4.8235e-01, -3.9753e-01,  7.0083e-02],
          [ 4.2282e-01,  1.5623e-01,  1.8547e-01]]],


        [[[-5.8936e-02,  1.4643e-01, -5.6524e-02],
          [-4.7611e-02,  1.6341e+00, -7.4953e-02],
          [-3.5694e-03, -4.7760e-01,  1.6030e-02]]],


        [[[-5.1927e-01, -4.3182e-01,

In [59]:
for n, p in m.mobile_inverted_conv.named_parameters():
    if has_deep_attr(mbconv_test, n):
        print(n, p)
        set_deep_attr(mbconv_test, n, p)
        print('------------------')

depth_conv.conv.weight Parameter containing:
tensor([[[[-1.4096e+00,  7.9066e-02,  1.3074e+00],
          [-3.6257e+00, -2.9244e-01,  4.0522e+00],
          [-3.8254e-02, -2.7275e-02,  2.4427e-02]]],


        [[[-2.6560e-06, -2.0290e-06, -3.1524e-06],
          [ 1.3314e-06,  1.7855e-05, -5.3570e-06],
          [-3.0966e-06,  1.5279e-06, -2.9197e-06]]],


        [[[-5.7868e-02,  4.9424e-02,  4.9366e-02],
          [ 1.3904e+00, -1.5971e+00, -9.1940e-02],
          [-5.8704e-02,  4.8760e-03,  5.4387e-02]]],


        [[[-6.5161e-01, -5.4830e-01,  8.4651e-01],
          [-3.8956e-01, -8.7885e-01, -2.9362e-02],
          [ 9.1019e-01, -1.1248e-01,  7.0372e-01]]],


        [[[-4.6209e-01, -3.3035e-01,  4.4910e-01],
          [-4.8235e-01, -3.9753e-01,  7.0083e-02],
          [ 4.2282e-01,  1.5623e-01,  1.8547e-01]]],


        [[[-5.8936e-02,  1.4643e-01, -5.6524e-02],
          [-4.7611e-02,  1.6341e+00, -7.4953e-02],
          [-3.5694e-03, -4.7760e-01,  1.6030e-02]]],


        [[[-5

In [64]:
for n, p in m.mobile_inverted_conv.named_parameters():
    if has_deep_attr(mbconv_test, n):
        print(n)
        print(get_deep_attr(mbconv_test, n) - p)

depth_conv.conv.weight
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0.