In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn.parameter import Parameter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be odd number'
    return kernel_size // 2


def sub_filter_start_end(kernel_size, sub_kernel_size):
    center = kernel_size // 2
    dev = sub_kernel_size // 2
    start, end = center - dev, center + dev + 1
    assert end - start == sub_kernel_size
    return start, end

class DynamicConv2d(nn.Module):
    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
            groups=self.max_in_channels, bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = '%dto%d' % (ks_larger, ks_small)
                scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            start_filter = self.conv.weight[:out_channel, :in_channel, :, :]  # start with max kernel
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)

        filters = self.get_active_filter(in_channel, kernel_size).contiguous()

        padding = get_same_padding(kernel_size)
        y = F.conv2d(
            x, filters, None, self.stride, padding, self.dilation, in_channel
        )
        return y


## 기존 Subnet select 방식은 gumbel_softmax의 gradient에 영향을 주지 못함

In [3]:
ks_list = [3,5,7]

dconv = DynamicConv2d(64, ks_list, stride=1, dilation=1)

gumbel_input = torch.randn(1, len(ks_list), requires_grad=True)

print("gumbel input: ", gumbel_input)
print("gumbel grad: ", gumbel_input.grad)
hard_gumbel = F.gumbel_softmax(gumbel_input, tau=1, hard=True)
print(hard_gumbel)

out = dconv.forward(torch.randn(1, 64, 32, 32), kernel_size=ks_list[torch.argmax(hard_gumbel)])
out.sum().backward()
print("dconv.weight.grad: ", dconv.conv.weight.grad[0,0])
print("dconv.convert.grad: ", dconv.get_parameter('7to5_matrix').shape, dconv.get_parameter('7to5_matrix').grad)
print("gumbel grad: ", gumbel_input.grad)

gumbel input:  tensor([[0.3260, 0.3021, 0.7876]], requires_grad=True)
gumbel grad:  None
tensor([[0., 0., 1.]], grad_fn=<AddBackward0>)
dconv.weight.grad:  tensor([[-39.0841, -38.3947, -35.4702, -32.7624, -27.6951, -30.1235, -19.3788],
        [-47.4388, -45.1476, -42.0656, -38.7817, -34.3943, -34.8257, -24.2159],
        [-46.3515, -44.9403, -43.1118, -40.8910, -36.9794, -38.1342, -27.0899],
        [-51.3097, -48.8564, -45.5509, -43.1278, -37.4699, -37.4104, -26.0948],
        [-43.1267, -42.3219, -38.0915, -35.4915, -32.3780, -32.4965, -22.2510],
        [-37.7014, -37.8889, -34.3843, -32.3811, -28.6648, -28.7608, -18.9653],
        [-35.2465, -36.4609, -34.5478, -31.2372, -27.7822, -27.5038, -18.4004]])
dconv.convert.grad:  torch.Size([25, 25]) None
gumbel grad:  None


In [4]:
gumbel_input = torch.randn(len(ks_list), requires_grad=True)
print(gumbel_input)
test_weight = torch.randn(len(ks_list), 64)
hard_gumbel = F.gumbel_softmax(gumbel_input, tau=1, hard=True)
print(hard_gumbel)
print(hard_gumbel.shape)
out = hard_gumbel.unsqueeze(1) * test_weight
print(out)
out.sum().backward()
print(gumbel_input.grad)

tensor([-0.8234, -0.3526, -0.5259], requires_grad=True)
tensor([0., 1., 0.], grad_fn=<AddBackward0>)
torch.Size([3])
tensor([[ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,  0.0000, -0.0000,  0.0000,
         -0.0000, -0.0000,  0.0000, -0.0000, -0.0000,  0.0000, -0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000, -0.0000,
          0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000,  0.0000,  0.0000,
         -0.0000, -0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000,
          0.0000,  0.0000, -0.0000, -0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
        [ 0.5743, -2.1588,  1.8553,  0.0300, -0.2108, -0.5413,  0.3284, -0.0719,
          1.2868, -1.3926, -0.1215,  0.5518, -0.6554, -0.1868, -0.5725, -1.9607,
          1.0544,  0.9271, -1.1055, -0.1890,  1.3631,  0.1581,  0.6159, 

In [5]:
for n, p in dconv.named_parameters():
    print(n)

5to3_matrix
7to5_matrix
conv.weight


In [6]:
class DynamicGumbelConv2d(nn.Module):
    KERNEL_TRANSFORM_MODE = None  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicGumbelConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
            groups=self.max_in_channels, bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = '%dto%d' % (ks_larger, ks_small)
                scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            start_filter = self.conv.weight[:out_channel, :in_channel, :, :]  # start with max kernel
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)

        filters = self.get_active_filter(in_channel, kernel_size).contiguous()

        padding = get_same_padding(kernel_size)
        y = F.conv2d(
            x, filters, None, self.stride, padding, self.dilation, in_channel
        )
        return y


In [7]:
g = torch.randn(4, 20)

In [8]:
empty_g = torch.empty_like(g)
print(empty_g)
gumbel = empty_g.exponential_().log()
out = g.log_softmax(dim=1) + gumbel
print(out)
print(torch.softmax(out, dim=1))


tensor([[5.4480e-38, 4.5860e-41, 4.9810e-38, 4.5860e-41, 5.6306e-38, 4.5860e-41,
         4.9810e-38, 4.5860e-41, 5.7646e-38, 4.5860e-41, 5.7680e-38, 4.5860e-41,
         5.6462e-38, 4.5860e-41, 5.7679e-38, 4.5860e-41, 5.4165e-38, 4.5860e-41,
         5.8494e-38, 4.5860e-41],
        [5.6407e-38, 4.5860e-41, 5.8585e-38, 4.5860e-41, 5.9143e-38, 4.5860e-41,
         5.7646e-38, 4.5860e-41, 4.6185e-38, 4.5860e-41, 5.9194e-38, 4.5860e-41,
         5.3741e-38, 4.5860e-41, 5.9150e-38, 4.5860e-41, 5.5330e-38, 4.5860e-41,
         5.6407e-38, 4.5860e-41],
        [4.9811e-38, 4.5860e-41, 5.4480e-38, 4.5860e-41, 5.8696e-38, 4.5860e-41,
         5.6912e-38, 4.5860e-41, 5.9143e-38, 4.5860e-41, 5.6462e-38, 4.5860e-41,
         5.9150e-38, 4.5860e-41, 5.6205e-38, 4.5860e-41, 5.7646e-38, 4.5860e-41,
         4.9811e-38, 4.5860e-41],
        [5.6407e-38, 4.5860e-41, 5.8585e-38, 4.5860e-41, 5.9143e-38, 4.5860e-41,
         5.7646e-38, 4.5860e-41, 4.6185e-38, 4.5860e-41, 5.9194e-38, 4.5860e-41,
       

In [9]:
n = 4
x = torch.randn((2, 10, 4, 4), requires_grad=True)
bit_mask_module = nn.Linear(10*4*4, n) # 10 channel
conv = nn.Conv2d(10, n, 3, 1, 0)

In [10]:
bit_mask = F.gumbel_softmax(bit_mask_module(x.view(x.size(0), -1)).view(-1, n).contiguous(), tau=1, hard=True)
bit_mask = bit_mask.view(-1, n)
bit_mask_l = bit_mask
print(bit_mask)
#conv_out_list = [F.conv2d(x, conv.weight, None, 1, 1, 1, 0) * bit_mask[:, i:i+1] for i in range(n)]

tensor([[0., 0., 1., 0.],
        [1., 0., 0., 0.]], grad_fn=<ViewBackward0>)


In [11]:
bit_mask_module.weight.grad

In [12]:
bit_mask.sum().backward()
print(bit_mask.grad)
print(x.grad)

None
tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],

  print(bit_mask.grad)


In [13]:
out_list = []
for i in range(n):
    print(bit_mask[:, i:i+1])
    out = F.conv2d(x, conv.weight, None, 1, 0, 1, 1)[:,i,:,:] * bit_mask[:, i].unsqueeze(1).unsqueeze(2).unsqueeze(3)
    print(out)
    out_list.append(out)

total_out = torch.stack(out_list, dim=0)
print(total_out)

tensor([[0.],
        [1.]], grad_fn=<SliceBackward0>)
tensor([[[[-0.0000, -0.0000],
          [-0.0000,  0.0000]],

         [[ 0.0000, -0.0000],
          [-0.0000,  0.0000]]],


        [[[-0.0854, -0.1723],
          [-1.4785,  0.3930]],

         [[ 0.3277, -0.0673],
          [-0.6838,  0.0519]]]], grad_fn=<MulBackward0>)
tensor([[0.],
        [0.]], grad_fn=<SliceBackward0>)
tensor([[[[-0., 0.],
          [0., -0.]],

         [[-0., -0.],
          [0., 0.]]],


        [[[-0., 0.],
          [0., -0.]],

         [[-0., -0.],
          [0., 0.]]]], grad_fn=<MulBackward0>)
tensor([[1.],
        [0.]], grad_fn=<SliceBackward0>)
tensor([[[[-0.6014, -0.0896],
          [ 0.4359,  0.4022]],

         [[ 0.6617,  0.1109],
          [-0.3586,  0.8877]]],


        [[[-0.0000, -0.0000],
          [ 0.0000,  0.0000]],

         [[ 0.0000,  0.0000],
          [-0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
tensor([[0.],
        [0.]], grad_fn=<SliceBackward0>)
tensor([[[[0., 0.],
       

In [14]:
loss = total_out.sum()
print(loss)

tensor(-0.2658, grad_fn=<SumBackward0>)


In [16]:
bit_mask.grad

  bit_mask.grad


In [17]:
s = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False, groups=64)

In [18]:
s.weight.shape

torch.Size([64, 1, 3, 3])

In [35]:
class DynamicSeparableConv2d(nn.Module):
    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicSeparableConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
            groups=self.max_in_channels, bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = '%dto%d' % (ks_larger, ks_small)
                scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            start_filter = self.conv.weight[:out_channel, :in_channel, :, :]  # start with max kernel
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, gumbel_input=None):
        
        #if kernel_size is None:
        #    kernel_size = self.active_kernel_size
        in_channel = x.size(1)
        
        out_list = []
        for i in range(len(self.kernel_size_list)):
            kernel_size = self.kernel_size_list[i]
            filters = self.get_active_filter(in_channel, kernel_size).contiguous()
            padding = get_same_padding(kernel_size)
            y = F.conv2d(
                x, filters, None, self.stride, padding, self.dilation, in_channel
            )
            out = y * gumbel_input[:, i].unsqueeze(1).unsqueeze(2).unsqueeze(3)
            print(gumbel_input[:, i])
            print(out)
            out_list.append(out)
        y = torch.sum(torch.stack(out_list, dim=0), dim=0)
        return y


In [36]:
in_c = 4
k = 4
module = DynamicSeparableConv2d(in_c, [3,5,7])
inputs = torch.randn(1, in_c, k, k)
gumbel_select = nn.Linear(k*k*in_c, 3) # self.kernel_list

gumbel_input = gumbel_select(inputs.view(inputs.size(0), -1))
print(gumbel_input.size())
bit_mask = F.gumbel_softmax(gumbel_input, tau=1, hard=True, dim=-1)
print(bit_mask)

torch.Size([1, 3])
tensor([[1., 0., 0.]], grad_fn=<AddBackward0>)


In [43]:
out = module(inputs, bit_mask)
out.sum().backward()

tensor([1.], grad_fn=<SelectBackward0>)
tensor([[[[ 0.1409,  0.0017, -0.1980, -0.2792],
          [ 0.1436, -0.1170,  0.2453,  0.0059],
          [-0.3516,  0.2223, -0.0338,  0.1091],
          [-0.1531,  0.0832, -0.2318,  0.1527]],

         [[ 0.1355, -0.0392,  0.0702,  0.1168],
          [-0.0683,  0.1784,  0.3267, -0.1909],
          [-0.1062,  0.1773, -0.0282, -0.0032],
          [ 0.2182,  0.0144, -0.1456,  0.0809]],

         [[-0.0575, -0.2015,  0.0263, -0.0784],
          [ 0.1602,  0.1498,  0.0932, -0.0224],
          [ 0.0466,  0.0282,  0.0581,  0.0196],
          [-0.0937,  0.0789,  0.0055,  0.0410]],

         [[ 0.0316, -0.2561, -0.4396, -0.2923],
          [ 0.0712, -0.0032, -0.0657, -0.3710],
          [-0.2168, -0.3756, -0.4276,  0.4448],
          [ 0.0349, -0.1824,  0.0534,  0.0140]]]], grad_fn=<MulBackward0>)
tensor([0.], grad_fn=<SelectBackward0>)
tensor([[[[0., 0., -0., -0.],
          [0., -0., 0., -0.],
          [-0., 0., -0., 0.],
          [-0., 0., -0., 0.]]

In [44]:
module

DynamicSeparableConv2d(
  (conv): Conv2d(4, 4, kernel_size=(7, 7), stride=(1, 1), groups=4, bias=False)
)

In [45]:
gumbel_select.weight.grad

tensor([[-3.8496e-03,  6.7819e-03,  2.9089e-03, -6.6991e-03,  3.7747e-03,
         -7.8763e-03,  1.0054e-02,  3.2461e-03,  4.3712e-03, -8.5085e-03,
          5.2003e-03, -1.8398e-03,  1.0541e-02, -2.6015e-03, -4.7269e-03,
          2.3385e-05, -1.8948e-03,  4.7095e-03,  4.5827e-03,  5.8954e-03,
          5.8511e-03, -7.8198e-04, -3.6422e-05,  6.3703e-03, -6.4089e-03,
          5.8479e-03,  6.4663e-03, -1.1468e-02, -7.1738e-03,  7.1592e-03,
          2.2607e-03, -3.4356e-03,  6.4294e-03,  7.5568e-03,  4.6905e-03,
         -8.2747e-03,  7.3102e-03, -2.1480e-03,  4.4802e-03,  1.4163e-03,
         -1.4185e-02, -7.3711e-03, -5.7679e-03, -3.7398e-03,  3.4935e-03,
          3.5776e-03, -1.9053e-03,  1.2204e-03, -1.3589e-03, -1.0452e-02,
         -1.1415e-03, -6.8882e-03,  2.2557e-03,  1.5920e-02,  3.4395e-03,
          1.6593e-02, -7.3449e-03, -4.2015e-03,  1.8793e-02,  4.7848e-03,
          5.5519e-03, -2.2328e-03,  9.7191e-04, -5.8595e-03],
        [-9.9698e-03,  1.7564e-02,  7.5335e-03, -1

In [46]:
for n, m in module.named_parameters():
    print(n)
    print(m.grad)

5to3_matrix
tensor([[ 0.2090, -0.1782,  0.2345,  0.3628,  0.0185,  0.2426, -0.6996, -0.2479,
         -0.5354],
        [ 0.1860, -0.0771, -0.0367,  0.7314, -0.0662,  0.0943, -0.7062, -0.3619,
         -0.4093],
        [ 0.1132,  0.0156, -0.1465,  0.9133, -0.1156, -0.0082, -0.7651, -0.3789,
         -0.4070],
        [ 0.2113, -0.1941,  0.2523,  0.4045,  0.0205,  0.2236, -0.9038, -0.3480,
         -0.6915],
        [ 0.1801, -0.1052,  0.0112,  0.6236, -0.0426,  0.0857, -0.7950, -0.4085,
         -0.5068],
        [ 0.0237,  0.0977, -0.0388,  0.9219, -0.1299,  0.0492, -0.6579, -0.1234,
         -0.3739],
        [ 0.2191, -0.1045, -0.1405,  0.6949, -0.0622,  0.0299, -0.6748, -0.4879,
         -0.3502],
        [ 0.1782, -0.0413, -0.4880,  0.8495, -0.1154, -0.2612, -0.7898, -0.8179,
         -0.3098],
        [-0.0906,  0.2495, -0.5967,  1.1062, -0.2216, -0.4134, -0.6181, -0.4926,
         -0.1532]])
7to5_matrix
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
 

In [2]:
from mcunet.mcunet.tinynas.nn.modules import MBInvertedConvLayer, ConvLayer, LinearLayer
from mcunet.mcunet.tinynas.elastic_nn.modules.dynamic_op import *
from mcunet.mcunet.utils import adjust_bn_according_to_idx, copy_bn, make_divisible, SEModule, MyModule, val2list, \
    get_net_device, build_activation

from collections import OrderedDict
import copy
import torch.nn as nn


class MyModule(nn.Module):

    def forward(self, x):
        raise NotImplementedError

    @property
    def module_str(self):
        raise NotImplementedError

    @property
    def config(self):
        raise NotImplementedError

    @staticmethod
    def build_from_config(config):
        raise NotImplementedError
    
def val2list(val, repeat_time=1):
    if isinstance(val, list) or isinstance(val, np.ndarray):
        return val
    elif isinstance(val, tuple):
        return list(val)
    else:
        return [val for _ in range(repeat_time)]


class DynamicMBConvLayer(MyModule):

    def __init__(self, in_channel_list, out_channel_list,
                 kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False):
        super(DynamicMBConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list

        self.kernel_size_list = val2list(kernel_size_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)

        self.stride = stride
        self.act_func = act_func
        self.use_se = use_se

        # build modules
        max_middle_channel = round(max(self.in_channel_list) * max(self.expand_ratio_list))
        if max(self.expand_ratio_list) == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', DynamicPointConv2d(max(self.in_channel_list), max_middle_channel)),
                ('bn', DynamicBatchNorm2d(max_middle_channel)),
                ('act', build_activation(self.act_func, inplace=True)),
            ]))

        self.depth_conv = nn.Sequential(OrderedDict([
            ('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, self.stride)),
            ('bn', DynamicBatchNorm2d(max_middle_channel)),
            ('act', build_activation(self.act_func, inplace=True))
        ]))
        if self.use_se:
            self.depth_conv.add_module('se', DynamicSE(max_middle_channel))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', DynamicPointConv2d(max_middle_channel, max(self.out_channel_list))),
            ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
        ]))

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        in_channel = x.size(1)

        if self.inverted_bottleneck is not None:
            self.inverted_bottleneck.conv.active_out_channel = \
                make_divisible(round(in_channel * self.active_expand_ratio), 8)

        self.depth_conv.conv.active_kernel_size = self.active_kernel_size
        self.point_linear.conv.active_out_channel = self.active_out_channel

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.use_se:
            return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
        else:
            return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)

    @property
    def config(self):
        return {
            'name': DynamicMBConvLayer.__name__,
            'in_channel_list': self.in_channel_list,
            'out_channel_list': self.out_channel_list,
            'kernel_size_list': self.kernel_size_list,
            'expand_ratio_list': self.expand_ratio_list,
            'stride': self.stride,
            'act_func': self.act_func,
            'use_se': self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicMBConvLayer(**config)

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        middle_channel = make_divisible(round(in_channel * self.active_expand_ratio), 8)

        # build the new layer
        sub_layer = MBInvertedConvLayer(
            in_channel, self.active_out_channel, self.active_kernel_size, self.stride, self.active_expand_ratio,
            act_func=self.act_func, mid_channels=middle_channel, use_se=self.use_se,
        )
        sub_layer = sub_layer.to(get_net_device(self))

        if not preserve_weight:
            return sub_layer

        # copy weight from current layer
        if sub_layer.inverted_bottleneck is not None:
            sub_layer.inverted_bottleneck.conv.weight.data.copy_(
                self.inverted_bottleneck.conv.conv.weight.data[:middle_channel, :in_channel, :, :]
            )
            copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

        sub_layer.depth_conv.conv.weight.data.copy_(
            self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data
        )
        copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

        if self.use_se:
            se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=8)
            sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
                self.depth_conv.se.fc.reduce.weight.data[:se_mid, :middle_channel, :, :]
            )
            sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(self.depth_conv.se.fc.reduce.bias.data[:se_mid])

            sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
                self.depth_conv.se.fc.expand.weight.data[:middle_channel, :se_mid, :, :]
            )
            sub_layer.depth_conv.se.fc.expand.bias.data.copy_(self.depth_conv.se.fc.expand.bias.data[:middle_channel])

        sub_layer.point_linear.conv.weight.data.copy_(
            self.point_linear.conv.conv.weight.data[:self.active_out_channel, :middle_channel, :, :]
        )
        copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

        return sub_layer

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3))  # over input ch
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width = sorted_expand_list[expand_ratio_stage]
            target_width = round(max(self.in_channel_list) * target_width)
            importance[target_width:] = torch.arange(0, target_width - importance.size(0), -1)

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.point_linear.conv.conv.weight.data = torch.index_select(
            self.point_linear.conv.conv.weight.data, 1, sorted_idx
        )

        adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
        self.depth_conv.conv.conv.weight.data = torch.index_select(
            self.depth_conv.conv.conv.weight.data, 0, sorted_idx
        )

        if self.use_se:
            # se expand: output dim 0 reorganize
            se_expand = self.depth_conv.se.fc.expand
            se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx)
            se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
            # se reduce: input dim 1 reorganize
            se_reduce = self.depth_conv.se.fc.reduce
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx)
            # middle weight reorganize
            se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
            se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

            se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
            se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)

        # TODO if inverted_bottleneck is None, the previous layer should be reorganized accordingly
        if self.inverted_bottleneck is not None:
            adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
            self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
                self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
            )
            return None
        else:
            return sorted_idx


NameError: name 'nn' is not defined