In [114]:
import torch
from abc import ABC, abstractmethod


def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding):
    batch_size, channels_count, input_height, input_width = input_matrix_shape
    output_height = (input_height + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    output_width = (input_width + 2 * padding - (kernel_size - 1) - 1) // stride + 1

    return batch_size, out_channels, output_height, output_width


class ABCConv2d(ABC):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

    def set_kernel(self, kernel):
        self.kernel = kernel

    @abstractmethod
    def __call__(self, input_tensor):
        pass


class Conv2d(ABCConv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding=0, bias=False)

    def set_kernel(self, kernel):
        self.conv2d.weight.data = kernel

    def __call__(self, input_tensor):
        return self.conv2d(input_tensor)


def create_and_call_conv2d_layer(conv2d_layer_class, stride, kernel, input_matrix):
    out_channels = kernel.shape[0]
    in_channels = kernel.shape[1]
    kernel_size = kernel.shape[2]

    layer = conv2d_layer_class(in_channels, out_channels, kernel_size, stride)
    layer.set_kernel(kernel)

    return layer(input_matrix)


def test_conv2d_layer(conv2d_layer_class, batch_size=2,
                      input_height=4, input_width=4, stride=2):
    kernel = torch.tensor(
                      [[[[0., 1, 0],
                         [1,  2, 1],
                         [0,  1, 0]],

                        [[1, 2, 1],
                         [0, 3, 3],
                         [0, 1, 10]],

                        [[10, 11, 12],
                         [13, 14, 15],
                         [16, 17, 18]]]])

    in_channels = kernel.shape[1]

    input_tensor = torch.arange(0, batch_size * in_channels *
                                input_height * input_width,
                                out=torch.FloatTensor()) \
        .reshape(batch_size, in_channels, input_height, input_width)

    custom_conv2d_out = create_and_call_conv2d_layer(
        conv2d_layer_class, stride, kernel, input_tensor)
    conv2d_out = create_and_call_conv2d_layer(
        Conv2d, stride, kernel, input_tensor)
    print("conv2d_out.shape", conv2d_out)

    # return torch.allclose(custom_conv2d_out, conv2d_out) \
    #          and (custom_conv2d_out.shape == conv2d_out.shape)


class Conv2dMatrix(ABCConv2d):
    # Функция преобразование кернела в матрицу нужного вида.
    def _unsqueeze_kernel(self, torch_input, output_height, output_width):
        kernel_unsqueezed = torch.nn.functional.pad(torch_input, (0, 0), 
                                                    "constant", 0).flatten(start_dim=1)
        print("kernel_unsqueezed.shape:", kernel_unsqueezed.shape)
        print("kernel_unsqueezed: \n", kernel_unsqueezed)
        return kernel_unsqueezed

    def __call__(self, torch_input):
        batch_size, out_channels, output_height, output_width\
            = calc_out_shape(
                input_matrix_shape=torch_input.shape,
                out_channels=self.kernel.shape[0],
                kernel_size=self.kernel.shape[2],
                stride=self.stride,
                padding=0)

        kernel_unsqueezed = self._unsqueeze_kernel(torch_input, output_height, output_width)
        result = kernel_unsqueezed @ torch_input.view((batch_size, -1)).permute(1, 0)
        print("torch_input.view((batch_size, -1)).shape \n", torch_input.view((batch_size, -1)).shape)
        print("torch_input.view((batch_size, -1)) \n", torch_input.view((batch_size, -1)))
        print("torch_input.view((batch_size, -1)).permute(1, 0) \n", torch_input.view((batch_size, -1)).permute(1, 0))
        print("result.shape:", result.shape)
        print("result: \n", result)
        print("result.permute(1, 0).shape:", result.permute(1, 0).shape)
        return result.permute(1, 0).view((batch_size, self.out_channels,
                                           output_height, output_width))

# Проверка происходит автоматически вызовом следующего кода
# (раскомментируйте для самостоятельной проверки,
#  в коде для сдачи задания должно быть закомментировано):
print(test_conv2d_layer(Conv2dMatrix))

kernel_unsqueezed.shape: torch.Size([2, 48])
kernel_unsqueezed: 
 tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
         28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
         42., 43., 44., 45., 46., 47.],
        [48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61.,
         62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75.,
         76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89.,
         90., 91., 92., 93., 94., 95.]])
torch_input.view((batch_size, -1)).shape 
 torch.Size([2, 48])
torch_input.view((batch_size, -1)) 
 tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
         28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
         42., 43., 44., 45., 46., 47.],
  

RuntimeError: shape '[2, 1, 1, 1]' is invalid for input of size 4

In [78]:
kernel = torch.tensor(
                      [[[[0., 1, 0],
                         [1,  2, 1],
                         [0,  1, 0]],

                        [[1, 2, 1],
                         [0, 3, 3],
                         [0, 1, 10]],

                        [[10, 11, 12],
                         [13, 14, 15],
                         [16, 17, 18]]]])

in_channels = kernel.shape[1]

# input_tensor = torch.arange(0, 2 * in_channels *
#                                 4 * 4,
#                                 out=torch.FloatTensor()) \
#         .reshape(2, in_channels, 4, 4)
print(kernel.shape)
# print(input_tensor.shape)

torch.Size([1, 3, 3, 3])


In [88]:
# t4d = torch.ones(3, 3, 4, 2)
# p1d = (1, 1) # pad last dim by 1 on each side
out = torch.nn.functional.pad(kernel, (0, 0), "constant", 0)  # effectively zero padding
print(out.size())
print(out)

torch.Size([1, 3, 3, 3])
tensor([[[[ 0.,  1.,  0.],
          [ 1.,  2.,  1.],
          [ 0.,  1.,  0.]],

         [[ 1.,  2.,  1.],
          [ 0.,  3.,  3.],
          [ 0.,  1., 10.]],

         [[10., 11., 12.],
          [13., 14., 15.],
          [16., 17., 18.]]]])


In [None]:

print(torch.flatten(out).shape)
print(torch.flatten(out))
print(torch.flatten(out, start_dim=1).shape)
print(torch.flatten(out, start_dim=1))
print(torch.allclose(torch.flatten(out), 
                    torch.flatten(out, start_dim=1)))

torch.Size([27])
tensor([ 0.,  1.,  0.,  1.,  2.,  1.,  0.,  1.,  0.,  1.,  2.,  1.,  0.,  3.,
         3.,  0.,  1., 10., 10., 11., 12., 13., 14., 15., 16., 17., 18.])
torch.Size([1, 27])
tensor([[ 0.,  1.,  0.,  1.,  2.,  1.,  0.,  1.,  0.,  1.,  2.,  1.,  0.,  3.,
          3.,  0.,  1., 10., 10., 11., 12., 13., 14., 15., 16., 17., 18.]])
True
