In [66]:
import torch
from abc import ABC, abstractmethod


def calc_out_shape(input_matrix_shape, out_channels, kernel_size, stride, padding):
    batch_size, channels_count, input_height, input_width = input_matrix_shape
    output_height = (input_height + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    output_width = (input_width + 2 * padding - (kernel_size - 1) - 1) // stride + 1

    return batch_size, out_channels, output_height, output_width


class ABCConv2d(ABC):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

    def set_kernel(self, kernel):
        self.kernel = kernel

    @abstractmethod
    def __call__(self, input_tensor):
        pass


def create_and_call_conv2d_layer(conv2d_layer_class, stride, kernel, input_matrix):
    out_channels = kernel.shape[0]
    in_channels = kernel.shape[1]
    kernel_size = kernel.shape[2]

    layer = conv2d_layer_class(in_channels, out_channels, kernel_size, stride)
    layer.set_kernel(kernel)

    return layer(input_matrix)


class Conv2d(ABCConv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding=0, bias=False)

    def set_kernel(self, kernel):
        self.conv2d.weight.data = kernel

    def __call__(self, input_tensor):
        return self.conv2d(input_tensor)


def test_conv2d_layer(conv2d_layer_class, batch_size=2,
                      input_height=4, input_width=4, stride=2):
    kernel = torch.tensor(
                      [[[[0., 1, 0],
                         [1,  2, 1],
                         [0,  1, 0]],

                        [[1, 2, 1],
                         [0, 3, 3],
                         [0, 1, 10]],

                        [[10, 11, 12],
                         [13, 14, 15],
                         [16, 17, 18]]]])

    in_channels = kernel.shape[1]

    input_tensor = torch.arange(0, batch_size * in_channels *
                                input_height * input_width,
                                out=torch.FloatTensor()) \
        .reshape(batch_size, in_channels, input_height, input_width)

    custom_conv2d_out = create_and_call_conv2d_layer(
        conv2d_layer_class, stride, kernel, input_tensor)
    conv2d_out = create_and_call_conv2d_layer(
        Conv2d, stride, kernel, input_tensor)

    return torch.allclose(custom_conv2d_out, conv2d_out) \
             and (custom_conv2d_out.shape == conv2d_out.shape)


class Conv2dMatrixV2(ABCConv2d):
    # Функция преобразования кернела в нужный формат.
    def _convert_kernel(self):
        converted_kernel = self.kernel.flatten(1, 2).T # Реализуйте преобразование кернела.
        return converted_kernel

    # Функция преобразования входа в нужный формат.
    def _convert_input(self, torch_input, output_height, output_width):
        converted_input = torch_input.permute(1, 2, 0, 3).contiguous().view(-1) # Реализуйте преобразование входа.
        return converted_input

    def __call__(self, torch_input):
        batch_size, out_channels, output_height, output_width\
            = calc_out_shape(
                input_matrix_shape=torch_input.shape,
                out_channels=self.kernel.shape[0],
                kernel_size=self.kernel.shape[2],
                stride=self.stride,
                padding=0)

        converted_kernel = self._convert_kernel()
        converted_input = self._convert_input(torch_input, output_height, output_width)

        conv2d_out_alternative_matrix_v2 = converted_kernel @ converted_input
        return conv2d_out_alternative_matrix_v2.transpose(1,0).view(torch_input.shape[0],
                                                     self.out_channels, 
                                                     output_height,
                                                     output_width)

# Проверка происходит автоматически вызовом следующего кода
# (раскомментируйте для самостоятельной проверки,
#  в коде для сдачи задания должно быть закомментировано):
print(test_conv2d_layer(Conv2dMatrixV2))

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [3, 1] but got: [3, 96].

In [64]:
kernel = torch.tensor(
                      [[[[0., 1, 0],
                         [1,  2, 1],
                         [0,  1, 0]],

                        [[1, 2, 1],
                         [0, 3, 3],
                         [0, 1, 10]],

                        [[10, 11, 12],
                         [13, 14, 15],
                         [16, 17, 18]]]])

converted_kernel = kernel.flatten(0, 2)
print("converted_kernel.shape:", converted_kernel.shape)
print(converted_kernel)

converted_kernel.shape: torch.Size([9, 3])
tensor([[ 0.,  1.,  0.],
        [ 1.,  2.,  1.],
        [ 0.,  1.,  0.],
        [ 1.,  2.,  1.],
        [ 0.,  3.,  3.],
        [ 0.,  1., 10.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])


In [43]:
batch_size=2
in_channels=3
input_height=4
input_width=4

out_channels=3

def calc_out_shape(input_matrix_shape=4, out_channels=3,\
                   kernel_size=3, stride=2, padding=0):  
#     batch_size, channels_count, input_height, input_width = input_matrix_shape
    output_height = (input_height + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    output_width = (input_width + 2 * padding - (kernel_size - 1) - 1) // stride + 1

    return batch_size, out_channels, output_height, output_width

print(calc_out_shape(batch_size, in_channels, input_height, input_width))

input_tensor = torch.arange(0, batch_size * in_channels *
                                input_height * input_width,
                                out=torch.FloatTensor()) \
        .reshape(batch_size, in_channels, input_height, input_width)

print("input_tensor.shape:", input_tensor.shape)
# print(input_tensor)

converted_input = input_tensor.permute(1, 2, 0, 3) # Реализуйте преобразование входа.
print("converted_input.shape:", converted_input.shape).contiguous().view(-1)
print(converted_input)

(2, 3, 1, 1)
input_tensor.shape: torch.Size([2, 3, 4, 4])
converted_input.shape: torch.Size([3, 4, 2, 4])
tensor([[[[ 0.,  1.,  2.,  3.],
          [48., 49., 50., 51.]],

         [[ 4.,  5.,  6.,  7.],
          [52., 53., 54., 55.]],

         [[ 8.,  9., 10., 11.],
          [56., 57., 58., 59.]],

         [[12., 13., 14., 15.],
          [60., 61., 62., 63.]]],


        [[[16., 17., 18., 19.],
          [64., 65., 66., 67.]],

         [[20., 21., 22., 23.],
          [68., 69., 70., 71.]],

         [[24., 25., 26., 27.],
          [72., 73., 74., 75.]],

         [[28., 29., 30., 31.],
          [76., 77., 78., 79.]]],


        [[[32., 33., 34., 35.],
          [80., 81., 82., 83.]],

         [[36., 37., 38., 39.],
          [84., 85., 86., 87.]],

         [[40., 41., 42., 43.],
          [88., 89., 90., 91.]],

         [[44., 45., 46., 47.],
          [92., 93., 94., 95.]]]])
