<a href="https://colab.research.google.com/github/volkov-maxim/nn_cv_course_samsung_stepik/blob/main/5_2_Implementation_of_convolution_layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import torch
from abc import ABC, abstractmethod

In [None]:
# Input batch of two 3-channel (RGB) 3x3 images, laid out as
# (batch, channels, height, width); the pixel values are the consecutive
# integers 0..53 in row-major order (int64 tensor).
input_images = torch.arange(2 * 3 * 3 * 3).reshape(2, 3, 3, 3)

In [None]:
# Expected result of zero-padding ``input_images`` by one pixel on every
# side: a float32 tensor of shape (2, 3, 5, 5) whose 3x3 interior holds the
# values 0..53 and whose one-pixel border is all zeros.
correct_padded_images = torch.zeros(2, 3, 5, 5)
correct_padded_images[:, :, 1:4, 1:4] = torch.arange(54.).reshape(2, 3, 3, 3)

In [None]:
def get_padding2d(input_images, padding=1):
    """Zero-pad every image of a batch on all four sides.

    Args:
        input_images: tensor whose last two dimensions are (height, width),
            e.g. (batch, channels, height, width). It is converted to
            ``torch.float`` (float32) before padding.
        padding: number of zero pixels added on each side. The default of 1
            reproduces the original fixed behaviour.

    Returns:
        float32 tensor with height and width each grown by ``2 * padding``.
    """
    # F.pad reads the tuple as (left, right, top, bottom) for the last two
    # dimensions, so a uniform pad repeats the same width four times.
    pad = (padding, padding, padding, padding)
    padded_images = F.pad(input=input_images.to(torch.float), pad=pad,
                          mode='constant', value=0.)
    return padded_images

In [None]:
# Verification is run automatically by the following code
# (uncomment it to check your work yourself;
#  in the code submitted for grading it must be commented out):
print(get_padding2d(input_images).allclose(correct_padded_images))

True


In [None]:
def calc_out_shape_my(input_matrix_shape, out_channels, kernel_size, stride, padding):
    """Compute the output shape of a 2-D convolution.

    Height and width both follow
    ``out = (in + 2 * padding - kernel_size) // stride + 1``.

    Args:
        input_matrix_shape: sequence indexed as (batch, channels, height, width).
        out_channels: number of output channels.
        kernel_size: side length of the square kernel.
        stride: step between neighbouring windows; must be positive.
        padding: zero pixels added on each side of height and width.

    Returns:
        list ``[batch, out_channels, out_height, out_width]``.

    Raises:
        ValueError: if ``stride`` is not positive. ValueError subclasses
            Exception, so callers catching the old generic Exception still work.
    """
    if stride <= 0:
        raise ValueError("Stride must be > 0")

    def out_size(size):
        return (size + 2 * padding - kernel_size) // stride + 1

    return [input_matrix_shape[0],
            out_channels,
            out_size(input_matrix_shape[2]),
            out_size(input_matrix_shape[3])]

In [None]:
# Self-check: a 10x10 input with a 3x3 kernel, stride 1 and no padding
# should give an 8x8 output with 10 channels.
print(np.array_equal(calc_out_shape_my([2, 3, 10, 10], 10, 3, 1, 0),
                     [2, 10, 8, 8]))

True


In [None]:
def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding):
    """Return ``(batch, out_channels, out_height, out_width)`` for a 2-D convolution.

    ``input_matrix_shape`` is unpacked as (batch, channels, height, width);
    each spatial size follows ``(size + 2*padding - kernel_size) // stride + 1``.
    """
    batch_size, _, input_height, input_width = input_matrix_shape

    def spatial(size):
        return (size + 2 * padding - kernel_size) // stride + 1

    return batch_size, out_channels, spatial(input_height), spatial(input_width)


# Abstract base class for a 2-D convolution layer: it stores the layer
# hyper-parameters and the kernel; subclasses implement the forward pass.
class ABCConv2d(ABC):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        (self.in_channels, self.out_channels,
         self.kernel_size, self.stride) = (in_channels, out_channels,
                                           kernel_size, stride)

    def set_kernel(self, kernel):
        """Remember ``kernel`` so that ``__call__`` can use it."""
        self.kernel = kernel

    @abstractmethod
    def __call__(self, input_tensor):
        """Apply the layer to ``input_tensor``; must be overridden."""


# Wrapper around torch.nn.Conv2d (no padding, no bias) that exposes the same
# interface as the hand-written layers, so it can serve as the reference.
class Conv2d(ABCConv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        # Record the hyper-parameters like every other ABCConv2d subclass;
        # previously they were never set on wrapper instances.
        super().__init__(in_channels, out_channels, kernel_size, stride)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding=0, bias=False)

    def set_kernel(self, kernel):
        # Keep self.kernel in sync with the base-class contract, then make
        # the wrapped module use the very same tensor as its weight.
        super().set_kernel(kernel)
        self.conv2d.weight.data = kernel

    def __call__(self, input_tensor):
        return self.conv2d(input_tensor)


# Build a conv2d_layer_class instance sized from ``kernel`` and return its
# output on ``input_matrix``.
def create_and_call_conv2d_layer(conv2d_layer_class, stride, kernel, input_matrix):
    """Instantiate ``conv2d_layer_class`` for ``kernel`` and apply it.

    ``kernel.shape`` is read as (out_channels, in_channels, kernel_size, ...).
    """
    shape = kernel.shape
    out_channels, in_channels, kernel_size = shape[0], shape[1], shape[2]

    layer = conv2d_layer_class(in_channels, out_channels, kernel_size, stride)
    layer.set_kernel(kernel)
    return layer(input_matrix)


# Checks conv2d_layer_class against the torch.nn.Conv2d reference.
def test_conv2d_layer(conv2d_layer_class, batch_size=2,
                      input_height=4, input_width=4, stride=2):
    """Return True if ``conv2d_layer_class`` matches ``torch.nn.Conv2d``.

    The layer is built from a fixed 3x3 kernel with 3 input channels and
    1 output channel. The input is ``arange`` float32 data of shape
    (batch_size, 3, input_height, input_width).

    NOTE(review): with the defaults (4x4 input, stride 2) the output is 1x1
    and there is a single output channel. Stride handling and multi-channel
    output layout are therefore not exercised; pass larger sizes or stride=1
    for a stronger check.
    """
    kernel = torch.tensor(
                      [[[[0., 1, 0],
                         [1,  2, 1],
                         [0,  1, 0]],

                        [[1, 2, 1],
                         [0, 3, 3],
                         [0, 1, 10]],

                        [[10, 11, 12],
                         [13, 14, 15],
                         [16, 17, 18]]]])

    in_channels = kernel.shape[1]

    input_tensor = torch.arange(
        batch_size * in_channels * input_height * input_width,
        dtype=torch.float32,
    ).reshape(batch_size, in_channels, input_height, input_width)

    custom_conv2d_out = create_and_call_conv2d_layer(
        conv2d_layer_class, stride, kernel, input_tensor)
    conv2d_out = create_and_call_conv2d_layer(
        Conv2d, stride, kernel, input_tensor)
    # Compare shapes first: torch.allclose broadcasts (or raises) when the
    # shapes differ, which would mask or crash on a wrong-shaped result.
    return (custom_conv2d_out.shape == conv2d_out.shape
            and torch.allclose(custom_conv2d_out, conv2d_out))

In [None]:
print(test_conv2d_layer(conv2d_layer_class=Conv2d))

True


In [None]:
# Convolution layer implemented with explicit loops.
class Conv2dLoop(ABCConv2d):
    """2-D convolution (no padding, no bias) computed with Python loops.

    Each output element is the sum of the element-wise product of one
    output channel's kernel (all input channels) with the input window
    under it.
    """

    def __call__(self, input_tensor):
        """Convolve a 4-D (batch, channels, height, width) tensor with ``self.kernel``.

        Returns:
            float32 tensor of shape (batch, out_channels, out_height, out_width).
        """
        batch_size, out_channels, output_height, output_width = calc_out_shape(
            input_tensor.shape, self.out_channels, self.kernel_size,
            self.stride, 0)

        output_tensor = torch.zeros(batch_size, out_channels,
                                    output_height, output_width)
        size = self.kernel_size

        for batch_index in range(batch_size):
            for out_row in range(output_height):
                row = out_row * self.stride
                for out_col in range(output_width):
                    col = out_col * self.stride
                    # The window is shared by every output channel, so it
                    # is sliced once per position.
                    window = input_tensor[batch_index, :,
                                          row:row + size, col:col + size]
                    for out_channel in range(out_channels):
                        output_tensor[batch_index, out_channel, out_row, out_col] = \
                            torch.sum(window * self.kernel[out_channel])

        return output_tensor



In [None]:
# Correctness is judged by comparison with the standard PyTorch layer.
# Verification is run automatically by the following code
# (uncomment it to check your work yourself;
#  in the code submitted for grading it must be commented out):
print(test_conv2d_layer(conv2d_layer_class=Conv2dLoop))

True


In [None]:
# Convolution layer via matrix multiplication (approach 1): the kernel is
# expanded into a large sparse matrix that multiplies the flattened input.
class Conv2dMatrix(ABCConv2d):
    """2-D convolution (no padding, no bias) computed as one matrix product.

    The kernel is "unsqueezed" into a matrix with one row per
    (out_channel, out_row, out_col) and one column per
    (in_channel, in_row, in_col).
    """

    def _unsqueeze_kernel(self, torch_input, output_height, output_width):
        """Build the convolution matrix for inputs shaped like ``torch_input``.

        Row ``o*OH*OW + r*OW + c`` holds ``self.kernel[o]`` placed at input
        offset (r*stride, c*stride), flattened channel by channel.
        Unlike the previous version, this works for any stride, output size
        and non-square input.

        Returns:
            float32 tensor of shape
            (out_channels * output_height * output_width,
             in_channels * input_height * input_width).
        """
        in_channels, input_height, input_width = torch_input.shape[1:]
        out_channels, _, kernel_height, kernel_width = self.kernel.shape

        # Lay the matrix out 6-D first so each output position is a single
        # slice assignment; torch.zeros' default dtype (float32) as before.
        unsqueezed = torch.zeros(out_channels, output_height, output_width,
                                 in_channels, input_height, input_width)
        for out_row in range(output_height):
            row = out_row * self.stride
            for out_col in range(output_width):
                col = out_col * self.stride
                unsqueezed[:, out_row, out_col, :,
                           row:row + kernel_height,
                           col:col + kernel_width] = self.kernel

        return unsqueezed.reshape(out_channels * output_height * output_width,
                                  in_channels * input_height * input_width)

    def __call__(self, torch_input):
        """Convolve a (batch, channels, height, width) tensor with ``self.kernel``."""
        batch_size, out_channels, output_height, output_width\
            = calc_out_shape(
                input_matrix_shape=torch_input.shape,
                out_channels=self.kernel.shape[0],
                kernel_size=self.kernel.shape[2],
                stride=self.stride,
                padding=0)

        kernel_unsqueezed = self._unsqueeze_kernel(torch_input, output_height,
                                                   output_width)
        # (out_ch*OH*OW, C*H*W) @ (C*H*W, batch) -> (out_ch*OH*OW, batch);
        # reshape (not view) also accepts non-contiguous inputs.
        result = kernel_unsqueezed @ torch_input.reshape(batch_size, -1).t()

        return result.t().reshape(batch_size, out_channels,
                                  output_height, output_width)

In [None]:
# Verification is run automatically by the following code
# (uncomment it to check your work yourself;
#  in the code submitted for grading it must be commented out):
print(test_conv2d_layer(conv2d_layer_class=Conv2dMatrix))

True


In [None]:
# Convolution layer via matrix multiplication (approach 2, "im2col"): every
# input window becomes a column, every output channel's kernel becomes a
# row, and one matrix product computes the whole convolution.
class Conv2dMatrixV2(ABCConv2d):
    """2-D convolution (no padding, no bias) computed via im2col + matmul."""

    def _convert_kernel(self):
        """Flatten the kernel to shape (out_channels, in_channels * kh * kw).

        Row ``o`` is ``self.kernel[o]`` flattened channel by channel, which
        matches the element order of each column from ``_convert_input``.
        """
        return self.kernel.reshape(self.kernel.shape[0], -1).clone()

    def _convert_input(self, torch_input, output_height, output_width):
        """Unfold every input window into a column.

        Column ``b*OH*OW + r*OW + c`` holds the window of image ``b`` whose
        top-left corner is at (r*stride, c*stride), flattened channel by
        channel. The previous version ignored ``stride`` here.

        Returns:
            tensor of shape (in_channels * kh * kw, batch * OH * OW).
        """
        batch_size, in_channels = torch_input.shape[:2]
        kernel_height, kernel_width = self.kernel.shape[2:]

        columns = []
        for batch_index in range(batch_size):
            for out_row in range(output_height):
                row = out_row * self.stride
                for out_col in range(output_width):
                    col = out_col * self.stride
                    window = torch_input[batch_index, :,
                                         row:row + kernel_height,
                                         col:col + kernel_width]
                    columns.append(window.reshape(-1))

        if not columns:
            # torch.stack rejects an empty list; return an empty matrix instead.
            return torch_input.new_zeros(
                (in_channels * kernel_height * kernel_width, 0))
        return torch.stack(columns, dim=1)

    def __call__(self, torch_input):
        """Convolve a (batch, channels, height, width) tensor with ``self.kernel``."""
        batch_size, out_channels, output_height, output_width\
            = calc_out_shape(
                input_matrix_shape=torch_input.shape,
                out_channels=self.kernel.shape[0],
                kernel_size=self.kernel.shape[2],
                stride=self.stride,
                padding=0)

        converted_kernel = self._convert_kernel()
        converted_input = self._convert_input(torch_input, output_height,
                                              output_width)

        # (out_ch, C*kh*kw) @ (C*kh*kw, batch*OH*OW) -> (out_ch, batch*OH*OW)
        product = converted_kernel @ converted_input
        # Columns are ordered batch-major: split them back out and move the
        # batch axis to the front. The former transpose + view mixed output
        # channels with spatial positions whenever out_channels > 1.
        return (product.reshape(out_channels, batch_size,
                                output_height, output_width)
                .permute(1, 0, 2, 3)
                .contiguous())

In [None]:
# Verification is run automatically by the following code
# (uncomment it to check your work yourself;
#  in the code submitted for grading it must be commented out):
print(test_conv2d_layer(conv2d_layer_class=Conv2dMatrixV2))

True
