* ShuffleNet
* 논문 : https://arxiv.org/pdf/1707.01083
* MobileNetV1은 Depthwise Separable Conv로 연산량을 줄이고,
* 전체 파라미터 수에서 Conv1x1이 차지하는 비율을 FC보다 늘렸음
* ShuffleNet은 Conv1x1의 크기도 줄여보자는 취지의 논문

In [50]:
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.autograd import Variable
from torchinfo import summary

from collections import OrderedDict

<img src = 'https://raw.githubusercontent.com/ech97/save-image-repo/master/image/download.png'>

* 코드는 jaxony의 github를 참조
* https://github.com/jaxony/ShuffleNet 

In [42]:
def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups``.

    The channel axis is split into ``groups`` chunks of ``c // groups``
    channels each, the two resulting axes are swapped, and the result is
    flattened back to four dimensions.

    Args:
        x: Tensor that unpacks into four sizes ``(b, c, h, w)``.
        groups: Number of chunks the channel axis is split into.

    Returns:
        Tensor of shape ``(b, -1, h, w)`` with the channels reordered.
    """
    batch, channels, height, width = x.size()
    per_group = channels // groups

    # (b, c, h, w) -> (b, groups, per_group, h, w)
    grouped = x.view(batch, groups, per_group, height, width)
    # Swap the group axis with the per-group axis; contiguous() is needed
    # so that the final view() can merge the two axes again.
    swapped = grouped.transpose(1, 2).contiguous()
    return swapped.view(batch, -1, height, width)

def conv1x1(in_channels, out_channels, groups=1):
    """Build a stride-1, 1x1 ``nn.Conv2d`` (optionally grouped).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        groups: Value forwarded to ``nn.Conv2d``'s ``groups`` argument.

    Returns:
        The constructed ``nn.Conv2d`` layer.
    """
    pointwise = nn.Conv2d(in_channels, out_channels, 1, stride=1, groups=groups)
    return pointwise

def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1):
    """Build a 3x3 ``nn.Conv2d`` with configurable stride, padding, bias and groups.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        padding: Padding added to the spatial borders.
        bias: Whether the layer has a bias term.
        groups: Value forwarded to ``nn.Conv2d``'s ``groups`` argument.

    Returns:
        The constructed ``nn.Conv2d`` layer.
    """
    options = {
        'stride': stride,
        'padding': padding,
        'bias': bias,
        'groups': groups,
    }
    return nn.Conv2d(in_channels, out_channels, 3, **options)

* channel_shuffle()

In [20]:
# Compare shapes before/after swapping axes 1 and 2, and after contiguous().
x = torch.randn(1, 3, 1, 32, 32)
y = x.transpose(1, 2)
z = y.contiguous()

x.shape, y.shape, z.shape

(torch.Size([1, 3, 1, 32, 32]),
 torch.Size([1, 1, 3, 32, 32]),
 torch.Size([1, 1, 3, 32, 32]))

In [21]:
# Build a (1, 24, 32, 32) float tensor whose i-th channel is filled with the
# value i, so the channel order is easy to read off after shuffling.
# All channels are created first and concatenated once, instead of growing
# the tensor with a torch.cat call per channel.
x = torch.cat(
    [torch.full((1, 1, 32, 32), float(i)) for i in range(24)],
    dim=1,
)
print(x)
# Shuffle the 24 channels using 3 groups and show the reordered tensor.
out = channel_shuffle(x, 3)
print(out)

tensor([[[[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          ...,
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  ...,  0.,  0.,  0.]],

         [[ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          ...,
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  ...,  1.,  1.,  1.]],

         [[ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
          [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
          [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
          ...,
          [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
          [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
          [ 2.,  2.,  2.,  ...,  2.,  2.,  2.]],

         ...,

         [[21., 21., 21.,  ..., 21., 21., 21.],
          [21., 21., 2

In [62]:
class ShuffleUnit(nn.Module):
    """ShuffleNet unit: 1x1 conv -> channel shuffle -> depthwise 3x3 conv -> 1x1 conv.

    The bottleneck width (output of the first 1x1 conv) is a quarter of the
    requested ``out_channels``. The result of the main branch is merged with
    a shortcut branch and passed through ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels requested for the unit's output.
        groups: Group count for the 1x1 convolutions and the channel shuffle.
        grouped_conv: When False, the first (compressing) 1x1 conv uses a
            single group instead of ``groups``.
        combine: ``'add'`` sums the input with the main branch (depthwise
            stride 1). ``'concat'`` average-pools the input with stride 2 and
            concatenates it with the main branch (depthwise stride 2).

    Raises:
        ValueError: If ``combine`` is neither ``'add'`` nor ``'concat'``.
    """
    def __init__(self, in_channels, out_channels, groups = 3, grouped_conv = True, combine = 'add'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Computed from the requested out_channels, before the 'concat'
        # adjustment below.
        self.bottleneck_channels = self.out_channels // 4
        self.groups = groups
        self.grouped_conv = grouped_conv
        self.combine = combine

        if self.combine == 'add':
            self.depthwise_stride = 1
            self._combine_func = self._add
        elif self.combine == 'concat':
            self.depthwise_stride = 2
            self._combine_func = self._concat
            # The shortcut contributes in_channels channels to the
            # concatenation, so the main branch only produces the remainder.
            self.out_channels -= self.in_channels
        else:
            # Single message with a sentence break between the two parts.
            raise ValueError(
                f'Cannot combine tensors with "{self.combine}". '
                'Only "add" and "concat" are supported.'
            )

        self.first_1x1_groups = self.groups if grouped_conv else 1
        # 1x1 conv + BN + ReLU that compresses to the bottleneck width.
        self.g_conv_1x1_compress = self._make_grouped_conv1x1(
            self.in_channels,
            self.bottleneck_channels,
            self.first_1x1_groups,
            batch_norm = True,
            relu = True
        )
        # Depthwise conv: one group per bottleneck channel.
        self.depthwise_conv3x3 = conv3x3(
            self.bottleneck_channels,
            self.bottleneck_channels,
            stride = self.depthwise_stride,
            groups = self.bottleneck_channels
        )
        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
        # 1x1 conv + BN (no ReLU) that expands back to the output width.
        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
            self.bottleneck_channels,
            self.out_channels,
            self.groups,
            batch_norm = True,
            relu = False
        )

    @staticmethod
    def _add(x, out):
        """Merge shortcut ``x`` and main-branch ``out`` by element-wise sum."""
        return x + out

    @staticmethod
    def _concat(x, out):
        """Merge shortcut ``x`` and main-branch ``out`` along dimension 1."""
        return torch.cat((x, out), 1)

    def _make_grouped_conv1x1(self, in_channels, out_channels, groups, batch_norm = True, relu = False):
        """Build a 1x1 conv, optionally followed by BatchNorm2d and ReLU.

        Returns an ``nn.Sequential`` when at least one extra layer is
        requested, otherwise the bare convolution.
        """
        modules = OrderedDict()
        conv = conv1x1(in_channels, out_channels, groups = groups)
        modules['conv1x1'] = conv

        if batch_norm:
            modules['batch_norm'] = nn.BatchNorm2d(out_channels)
        if relu:
            modules['relu'] = nn.ReLU(inplace = True)
        if len(modules) > 1:
            return nn.Sequential(modules)
        else:
            return conv

    def forward(self, x):
        """Run the unit on ``x`` and return the merged, ReLU-activated output."""
        residual = x

        if self.combine == 'concat':
            # Downsample the shortcut so its spatial size matches the
            # stride-2 main branch.
            residual = F.avg_pool2d(
                residual,
                kernel_size = 3,
                stride = 2,
                padding = 1
            )
        out = self.g_conv_1x1_compress(x)
        out = channel_shuffle(out, self.groups)
        out = self.depthwise_conv3x3(out)
        out = self.bn_after_depthwise(out)
        out = self.g_conv_1x1_expand(out)
        out = self._combine_func(residual, out)
        return F.relu(out)

<img src = 'https://raw.githubusercontent.com/ech97/save-image-repo/master/image/download1.png'>

In [44]:
class ShuffleNet(nn.Module):
    """ShuffleNet classifier: stem conv + max-pool, three ShuffleUnit stages, global average pool and a linear head.

    Args:
        groups: Group count used by every ShuffleUnit; must be 1, 2, 3, 4 or 8.
        in_channels: Number of channels of the input image tensor.
        num_classes: Size of the output of the final linear layer.

    Raises:
        ValueError: If ``groups`` is not one of the supported values.
    """

    # Output channels per stage, indexed by stage number (index 0 is an
    # unused placeholder), keyed by group count. The groups == 1 row ends in
    # 576, i.e. double the previous stage like every other row.
    _STAGE_OUT_CHANNELS = {
        1: (-1, 24, 144, 288, 576),
        2: (-1, 24, 200, 400, 800),
        3: (-1, 24, 240, 480, 960),
        4: (-1, 24, 272, 544, 1088),
        8: (-1, 24, 384, 768, 1536),
    }

    def __init__(self, groups = 3, in_channels = 3, num_classes = 1000):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.groups = groups
        # Number of 'add' units that follow the first unit in stages 2, 3, 4.
        self.stage_repeats = [3,7,3]

        try:
            # Copy into a list so the instance never shares the class table.
            self.stage_out_channels = list(self._STAGE_OUT_CHANNELS[groups])
        except (KeyError, TypeError):
            raise ValueError("Only 1,2,3,4,8 is available!!!") from None

        self.conv1 = conv3x3(
            self.in_channels,
            self.stage_out_channels[1],
            stride = 2
        )
        self.maxpool1 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.stage2 = self._make_stage(2)
        self.stage3 = self._make_stage(3)
        self.stage4 = self._make_stage(4)

        num_inputs = self.stage_out_channels[-1]
        self.fc = nn.Linear(num_inputs, self.num_classes)
        self._init_layer()

    def _init_layer(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode = 'fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std = 0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_stage(self, stage):
        """Build stage ``stage`` (2, 3 or 4) as an ``nn.Sequential`` of ShuffleUnits.

        The first unit uses ``combine='concat'`` and maps the previous
        stage's channel count to this stage's; it is followed by
        ``stage_repeats[stage - 2]`` units with ``combine='add'``.
        """
        modules = OrderedDict()
        stage_name = f"ShuffleUnit_Stage{stage}"

        # The first unit of stage 2 does not group its first 1x1 conv.
        grouped_conv = stage > 2

        first_module = ShuffleUnit(
            self.stage_out_channels[stage - 1],
            self.stage_out_channels[stage],
            groups = self.groups,
            grouped_conv = grouped_conv,
            combine = 'concat'
        )
        modules[stage_name + '_0'] = first_module

        for i in range(self.stage_repeats[stage-2]):
            name = stage_name + f"_{i+1}"
            module = ShuffleUnit(
                self.stage_out_channels[stage],
                self.stage_out_channels[stage],
                groups = self.groups,
                grouped_conv = True,
                combine = 'add'
            )
            modules[name] = module
        return nn.Sequential(modules)

    def forward(self, x):
        """Return the linear head's output for input batch ``x``."""
        out = self.conv1(x)
        out = self.maxpool1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        # Global average pool: the kernel covers the full spatial extent.
        out = F.avg_pool2d(out, out.size()[-2:])
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [63]:
# Instantiate the network with groups=1 and print a per-layer summary for a
# single 3x224x224 input.
models = ShuffleNet(groups=1)
summary(models, input_size=(1, 3, 224, 224))

  init.kaiming_normal(m.weight, mode='fan_out')
  init.constant(m.bias, 0)
  init.constant(m.weight, 1)
  init.constant(m.bias, 0)
  init.normal(m.weight, std=0.001)
  init.constant(m.bias, 0)


Layer (type:depth-idx)                   Output Shape              Param #
ShuffleNet                               [1, 1000]                 --
├─Conv2d: 1-1                            [1, 24, 112, 112]         672
├─MaxPool2d: 1-2                         [1, 24, 56, 56]           --
├─Sequential: 1-3                        [1, 144, 28, 28]          --
│    └─ShuffleUnit: 2-1                  [1, 144, 28, 28]          --
│    │    └─Sequential: 3-1              [1, 36, 56, 56]           972
│    │    └─Conv2d: 3-2                  [1, 36, 28, 28]           360
│    │    └─BatchNorm2d: 3-3             [1, 36, 28, 28]           72
│    │    └─Sequential: 3-4              [1, 120, 28, 28]          4,680
│    └─ShuffleUnit: 2-2                  [1, 144, 28, 28]          --
│    │    └─Sequential: 3-5              [1, 36, 28, 28]           5,292
│    │    └─Conv2d: 3-6                  [1, 36, 28, 28]           360
│    │    └─BatchNorm2d: 3-7             [1, 36, 28, 28]           72
│    

In [64]:
# Instantiate the network with groups=8 and print a per-layer summary for a
# single 3x224x224 input.
models = ShuffleNet(groups=8)
summary(models, input_size=(1, 3, 224, 224))

  init.kaiming_normal(m.weight, mode='fan_out')
  init.constant(m.bias, 0)
  init.constant(m.weight, 1)
  init.constant(m.bias, 0)
  init.normal(m.weight, std=0.001)
  init.constant(m.bias, 0)


Layer (type:depth-idx)                   Output Shape              Param #
ShuffleNet                               [1, 1000]                 --
├─Conv2d: 1-1                            [1, 24, 112, 112]         672
├─MaxPool2d: 1-2                         [1, 24, 56, 56]           --
├─Sequential: 1-3                        [1, 384, 28, 28]          --
│    └─ShuffleUnit: 2-1                  [1, 384, 28, 28]          --
│    │    └─Sequential: 3-1              [1, 96, 56, 56]           2,592
│    │    └─Conv2d: 3-2                  [1, 96, 28, 28]           960
│    │    └─BatchNorm2d: 3-3             [1, 96, 28, 28]           192
│    │    └─Sequential: 3-4              [1, 360, 28, 28]          5,400
│    └─ShuffleUnit: 2-2                  [1, 384, 28, 28]          --
│    │    └─Sequential: 3-5              [1, 96, 28, 28]           4,896
│    │    └─Conv2d: 3-6                  [1, 96, 28, 28]           960
│    │    └─BatchNorm2d: 3-7             [1, 96, 28, 28]           192
│