In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Подробности по ДЗ:
1. Нужно реализовать флаг make_downsample - увеличивает кол-во фильтровать вдвое, размер уменьшается вдвое по высоте и ширине - например (64, 16, 16) -> (128, 8, 8)
2. Нужно реализовать флаг use_skip_connection - если он включен, то на выходе блока добавляется X со входа - иначе блок работает как обычная сесть

Особенности downsample
1. Уменьшать размер входного изображения надо посредством conv3x3 со stride=2
2. Уменьшаться размер должен первой конволюцией 3x3 в блоке
3. В Bottleneck версии - кол-во фильтров меняется первым bottleneck слоем
4. В случае если нет флага use_skip_connection не нужно использовать слой для downsample исходного изображения, так как оно не будет дальше использоваться


Общие рекомендации по построению ResNet сетей:
1. После каждой конволюции идет BatchNorm и Relu слои
2. В конце ResNet блока после суммирования идет Relu слой
3. Конволюциооные слои, включая слои Bottleneck не используют bias (bias=False) - опциональное

Блоки строятся на основании статьи https://arxiv.org/abs/1512.03385

Tutorial по Pytorch https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

In [2]:
from operator import mul
from functools import reduce

# небольшой код по подсчету памяти и количества параметров в сети
MODULES_STAT=[]

def module_forward_hook(module, input, output):
    weight = module.weight.size() if not isinstance(module, torch.nn.modules.MaxPool2d) else (0, 0, 0, 0)
    print("Forward", module)
    MODULES_STAT.append((module, output.size(), weight))
    
def register_hook(module):
    for item in module.children(): 
        if type(item) in [nn.modules.conv.Conv2d, nn.modules.MaxPool2d, nn.modules.Linear]:
            print("RegisterHook", item)
            item.register_forward_hook(module_forward_hook)
                 
def features_mem_and_params(input_tenzor):
    input_size = input_tenzor.size()
    total_param = 0
    total_mem =  reduce(mul,(input_size))
    print( "%02d" % 0,
          'INPUT',
          "memory",
          "%dx%dx%d=%d" % (input_size[1], input_size[2], input_size[3], reduce(mul,(input_size))),
          "parameters", "%dx%dx%d=%d"%(0, 0, 0 , 0)
         ) 
    for i, stat in enumerate(MODULES_STAT):
        module_name = str(stat).split('(')[1]
        total_param += reduce(mul,(stat[2]))
        total_mem   += reduce(mul,(stat[1]))
        
        if 'Linear' in module_name:
            print( "%02d"%(i+1),'FC',"memory", "%dx%d=%d"%(stat[1][0], stat[1][1], reduce(mul,(stat[1]))),
               "parameters", "%dx%d=%d"%(stat[2][0], stat[2][1] , reduce(mul,(stat[2]))))
        else:    
            print( "%02d"%(i+1),module_name,"memory", "%dx%dx%d=%d"%(stat[1][1], stat[1][2], stat[1][3], reduce(mul,(stat[1]))),
               "parameters", "%dx%dx%dx%d=%d"%(stat[2][0], stat[2][1], stat[2][2], stat[2][3] , reduce(mul,(stat[2]))))
    print()
    print ("Total_mem: %d * 4 = %d" % (total_mem, total_mem * 4))
    print ("Total params: %d" % total_param, "Total_mem: %d" % total_mem)          
    return (total_param, total_mem)

In [28]:
DOWNSAMPLE_COEF = 2

def conv3x3(a_in_planes, a_out_planes, a_stride=1):
    """
    Основной строительный блок конволюций для ResNet
    Включает в себя padding=1 - чтобы размерность сохранялась после его применения
    """
    return nn.Conv2d(a_in_planes, a_out_planes,  stride=a_stride,
                     kernel_size=3, padding=1, bias=False)

def x_downsample(a_in_channels):
     return nn.Conv2d(a_in_channels, 
               a_in_channels*DOWNSAMPLE_COEF,
               kernel_size=1,
               stride=2,
               bias=False)

In [145]:
class CifarResidualBlock(nn.Module):
    def __init__(self, a_in_channels, make_downsample=False, use_skip_connection=True):
        super(CifarResidualBlock, self).__init__()
        self.use_skip_connection = use_skip_connection
        self.make_downsample=make_downsample
        
        if make_downsample: coef = DOWNSAMPLE_COEF
        else: coef = 1  
        self.a_in_channels=a_in_channels    
        ### TODO - нужно описать используемые блоки
        #a_in_channels*coef-количество каналов на выходе
        self.conv1=conv3x3(a_in_planes=a_in_channels,a_out_planes=a_in_channels*coef,a_stride=coef)
        self.bn1=nn.BatchNorm2d(a_in_channels*coef)
        self.relu=nn.ReLU()
        self.conv2=conv3x3(a_in_planes=a_in_channels*coef,a_out_planes=a_in_channels*coef)
        self.bn2=nn.BatchNorm2d(a_in_channels*coef)
        self.relu1=nn.ReLU()
        self.down=x_downsample(a_in_channels)

            
    def forward(self, x):
        ###TODO - описать forward блок с учетом флагов make_downsample и use_skip_connection
        res=self.conv1(x)
        res=self.bn1(res)
        res=self.relu(res)
        res=self.conv2(res)
        res=self.bn2(res)
        res=self.relu1(res)
        
        if self.use_skip_connection:
            if self.make_downsample:
#                 print(x.shape[0])
                res+=self.down(x)
                res=self.relu(res)
            else:
                res+=x
                res=self.relu(res)

        return res

In [146]:
### Test 9

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBottleneckBlock(256, make_downsample=False, use_skip_connection=False)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (69632, 40960))

RegisterHook Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
Forward Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Forward Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 64x8x8=4096 parameters 64x256x1x1=16384
02 Conv2d memory 64x8x8=4096 parameters 64x64x3x3=36864
03 Conv2d memory 256x8x8=16384 parameters 256x64x1x1=16384

Total_mem: 40960 * 4 = 163840
Total params: 69632 Total_mem: 40960
69632 40960


In [147]:
### Test 9

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBottleneckBlock(256, make_downsample=False, use_skip_connection=False)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (69632, 40960))

RegisterHook Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
Forward Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Forward Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 64x8x8=4096 parameters 64x256x1x1=16384
02 Conv2d memory 64x8x8=4096 parameters 64x64x3x3=36864
03 Conv2d memory 256x8x8=16384 parameters 256x64x1x1=16384

Total_mem: 40960 * 4 = 163840
Total params: 69632 Total_mem: 40960
69632 40960


In [148]:
### Test 1
x = torch.ones(1, 3, 32, 32)*100
print("Input size :\t\t", x.size())

first_conv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
x = first_conv(x)
print("After first layers:\t", x.size())

block = CifarResidualBlock(16, make_downsample=False, use_skip_connection=True)
x = block(x)
print("After ResBlock layers:\t", x.size())

assert(x.size() == torch.Size((1, 16, 32, 32)))

Input size :		 torch.Size([1, 3, 32, 32])
After first layers:	 torch.Size([1, 16, 32, 32])
After ResBlock layers:	 torch.Size([1, 16, 32, 32])


In [149]:
### Test 2
x = torch.ones(1, 3, 32, 32)*100
print("Input size :\t\t", x.size())

first_conv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
x = first_conv(x)
print("After first layers:\t", x.size())

block = CifarResidualBlock(16, make_downsample=True, use_skip_connection=True)
x = block(x)
print("After ResBlock layers:\t", x.size())

assert(x.size() == torch.Size((1, 32, 16, 16)))

Input size :		 torch.Size([1, 3, 32, 32])
After first layers:	 torch.Size([1, 16, 32, 32])
After ResBlock layers:	 torch.Size([1, 32, 16, 16])


In [150]:
### Test 2
x = torch.ones(1, 3, 32, 32)*100
print("Input size :\t\t", x.size())

first_conv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
x = first_conv(x)
print("After first layers:\t", x.size())

block = CifarResidualBlock(16, make_downsample=True, use_skip_connection=True)
x = block(x)
print("After ResBlock layers:\t", x.size())

assert(x.size() == torch.Size((1, 32, 16, 16)))

Input size :		 torch.Size([1, 3, 32, 32])
After first layers:	 torch.Size([1, 16, 32, 32])
After ResBlock layers:	 torch.Size([1, 32, 16, 16])


In [151]:
### Test 3
x = torch.ones(1, 16, 32, 32)
block = CifarResidualBlock(16, make_downsample=False, use_skip_connection=True)
x = block(x)
print(x.size()[1]*x.size()[2]*x.size()[3], x.sum())
assert(x.sum() > 10000)

16384 tensor(18376.6348, grad_fn=<SumBackward0>)


In [152]:
### Test 4
x = torch.ones(1, 16, 32, 32)
block = CifarResidualBlock(16, make_downsample=False, use_skip_connection=False)
x = block(x)
print(x.size()[1]*x.size()[2]*x.size()[3], x.sum())
assert(x.sum() < 5000)

16384 tensor(2405.5618, grad_fn=<SumBackward0>)


In [153]:
### Test 5

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBlock(256, make_downsample=True, use_skip_connection=True)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (3670016, 40960))

RegisterHook Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
RegisterHook Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Forward Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Forward Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 512x4x4=8192 parameters 512x256x3x3=1179648
02 Conv2d memory 512x4x4=8192 parameters 512x512x3x3=2359296
03 Conv2d memory 512x4x4=8192 parameters 512x256x1x1=131072

Total_mem: 40960 * 4 = 163840
Total params: 3670016 Total_mem: 40960
3670016 40960


In [154]:
### Test 6

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBlock(256, make_downsample=False, use_skip_connection=True)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (1179648, 49152))

RegisterHook Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Forward Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 256x8x8=16384 parameters 256x256x3x3=589824
02 Conv2d memory 256x8x8=16384 parameters 256x256x3x3=589824

Total_mem: 49152 * 4 = 196608
Total params: 1179648 Total_mem: 49152
1179648 49152


In [155]:
### Test 7

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBlock(256, make_downsample=True, use_skip_connection=False)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (3538944, 32768))

RegisterHook Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
RegisterHook Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Forward Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 512x4x4=8192 parameters 512x256x3x3=1179648
02 Conv2d memory 512x4x4=8192 parameters 512x512x3x3=2359296

Total_mem: 32768 * 4 = 131072
Total params: 3538944 Total_mem: 32768
3538944 32768


In [156]:
### Test 8

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBlock(256, make_downsample=False, use_skip_connection=False)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (1179648, 49152))

RegisterHook Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Forward Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 256x8x8=16384 parameters 256x256x3x3=589824
02 Conv2d memory 256x8x8=16384 parameters 256x256x3x3=589824

Total_mem: 49152 * 4 = 196608
Total params: 1179648 Total_mem: 49152
1179648 49152


In [158]:
class CifarResidualBottleneckBlock(nn.Module):
    
    BOTTLENECK_COEF = 4
    
    def __init__(self, a_in_channels, make_downsample=False, use_skip_connection=True):
        super(CifarResidualBottleneckBlock, self).__init__()
        self.use_skip_connection = use_skip_connection
        self.make_downsample=make_downsample
        if make_downsample: coef = DOWNSAMPLE_COEF
        else: coef = 1  
                        
        ### TODO - нужно описать используемые блоки
#         pass
        bottle_neck_chanels=a_in_channels*coef//self.BOTTLENECK_COEF
        self.conv1=nn.Conv2d(a_in_channels, bottle_neck_chanels, kernel_size=(1, 1),
                             stride=(1, 1),
                             padding=(0,0),
                             bias=False)

        self.bn1=nn.BatchNorm2d(bottle_neck_chanels)
        self.relu=nn.ReLU()
        self.conv2=conv3x3(a_in_planes=bottle_neck_chanels,
                           a_out_planes=bottle_neck_chanels,
                           a_stride=coef)
        self.bn2=nn.BatchNorm2d(bottle_neck_chanels)
        self.conv3=nn.Conv2d(bottle_neck_chanels,
                             a_in_channels*coef,
                             kernel_size=(1, 1),
                             stride=(1, 1),
                             bias=False)
        self.bn3=nn.BatchNorm2d(a_in_channels*coef)
        self.down=x_downsample(a_in_channels)
            
    def forward(self, x):
        ###TODO - описать forward блок с учетом флагов make_downsample и use_skip_connection
        res=self.conv1(x)
        res=self.bn1(res)
        res=self.relu(res)
        res=self.conv2(res)
        res=self.bn2(res)
        res=self.relu(res)
        res=self.conv3(res)
        res=self.bn3(res)
        res=self.relu(res)
        
        if self.use_skip_connection:
            if self.make_downsample:
#                 print(x.shape[0])
                res+=self.down(x)
                res=self.relu(res)
            else:
                res+=x
#                 res=self.relu(res)

        return res

In [159]:
### Test 10

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBottleneckBlock(256, make_downsample=False, use_skip_connection=True)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (69632, 40960))

RegisterHook Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
RegisterHook Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
Forward Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Forward Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 64x8x8=4096 parameters 64x256x1x1=16384
02 Conv2d memory 64x8x8=4096 parameters 64x64x3x3=36864
03 Conv2d memory 256x8x8=16384 parameters 256x64x1x1=16384

Total_mem: 40960 * 4 = 163840
Total params: 69632 Total_mem: 40960
69632 40960


In [160]:
### Test 11

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBottleneckBlock(256, make_downsample=True, use_skip_connection=False)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (245760, 34816))

RegisterHook Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
RegisterHook Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
Forward Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Forward Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 128x8x8=8192 parameters 128x256x1x1=32768
02 Conv2d memory 128x4x4=2048 parameters 128x128x3x3=147456
03 Conv2d memory 512x4x4=8192 parameters 512x128x1x1=65536

Total_mem: 34816 * 4 = 139264
Total params: 245760 Total_mem: 34816
245760 34816


In [161]:
### Test 12

MODULES_STAT=[]

input = torch.ones(1, 256, 8, 8)
block = CifarResidualBottleneckBlock(256, make_downsample=True, use_skip_connection=True)
register_hook(block)
out = block(input)

params, memory = features_mem_and_params(input)
print(params, memory)
assert((params, memory) == (376832, 43008))

RegisterHook Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
RegisterHook Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
RegisterHook Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
Forward Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
Forward Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Forward Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
Forward Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
00 INPUT memory 256x8x8=16384 parameters 0x0x0=0
01 Conv2d memory 128x8x8=8192 parameters 128x256x1x1=32768
02 Conv2d memory 128x4x4=2048 parameters 128x128x3x3=147456
03 Conv2d memory 512x4x4=8192 parameters 512x128x1x1=65536
04 Conv2d memory 512x4x4=8192 parameters 512x256x1x1=131072

Total_mem: 43008 * 4 = 172032
Total params: 376832 Total_mem: 43008
376832 43008
