In [1]:
import numpy as np
import torch
from torch.nn.functional import conv2d as libConv2d

In [2]:
class conv2Self():
    def __init__(
        self,
        input_data,
        kernel_size: tuple | int,
        bias: float | None = None,
        stride: int = 1,
        padding: tuple[int, int] | int | str = (0, 0),
        dilation: int = 1,
    ):
        self.input_data = input_data[0,0].numpy()
        self.input_data_for_torch = input_data
        self.bias = bias

        if type(kernel_size) == tuple:
            self.kernel_size = kernel_size
        else:
            self.kernel_size = (kernel_size, kernel_size)

        self.stride = stride

        self.dilation = dilation

        if type(padding) == tuple:
            self.padding = padding[0]
        elif padding == "same":
            if self.stride != 1:
                raise ValueError("padding 'same' work only with stride=1")
            self.padding = self.kernel_size[0]-1
        elif padding == "valid":
            self.padding = 0
        else:
            self.padding = (padding, padding)

        self.weight_tensor_for_torch = torch.randn(1,1,self.kernel_size[0], self.kernel_size[1])
        self.weight_tensor = self.weight_tensor_for_torch[0,0].numpy()

    def conv2d(self):
        image_height, image_width = self.input_data.shape
        weight_height, weight_width = self.weight_tensor.shape

        H_out = int((image_height - self.dilation * (weight_height - 1) - 1 + 2* self.padding)/self.stride) + 1
        W_out = int((image_width - self.dilation * (weight_width - 1) - 1 + 2* self.padding)/self.stride) + 1


        if self.padding>0:
            self.input_data = np.pad(self.input_data, self.padding, mode='constant')

        result = np.zeros((H_out, W_out))

        for y in range(H_out):
            for x in range(W_out):
                input_slice = self.input_data[y * self.stride:y * self.stride + weight_height, x * self.stride:x * self.stride + weight_width]
                result[y, x] = np.sum(input_slice * self.weight_tensor)

        if self.bias:
            result+=self.bias

        return result

    def torch_conv2d(self):
        return libConv2d(self.input_data_for_torch, self.weight_tensor_for_torch, self.bias, self.stride, self.padding, self.dilation)

    def test(self, print_flg=False):
        my_conv2d = self.conv2d()
        torch_out = np.array(self.torch_conv2d())
        if print_flg:
            print(my_conv2d)
            print(torch_out[0,0])
        print(np.allclose(my_conv2d, torch_out[0, 0]))

In [3]:
# Baseline check: 1x1 kernel with the constructor defaults
# (stride=1, padding=(0, 0), dilation=1) on a random 1x1x5x5 input.
# test() prints whether the NumPy result matches torch's conv2d.
image = torch.randn(1,1,5,5)
c1 = conv2Self(image, kernel_size=1)
c1.test()

True


In [4]:
# padding='valid' means zero padding, so this uses the same settings as the
# baseline cell, only on a freshly drawn input and kernel.
image = torch.randn(1,1,5,5)
c2 = conv2Self(image, kernel_size=1, padding='valid')
c2.test()

True


In [5]:
# padding='same' with a 1x1 kernel: no zeros are needed to keep the input size,
# so no padding is actually applied in this check.
image = torch.randn(1,1,5,5)
c3 = conv2Self(image, kernel_size=1, padding='same')
c3.test()

True


In [6]:
# padding='same' with a 4x4 kernel: the only check in this notebook where the
# input is actually zero-padded and the kernel is larger than 1x1.
image = torch.randn(1,1,5,5)
c4 = conv2Self(image, kernel_size=4, padding='same')
c4.test()

True


In [7]:
# dilation=3 with a 1x1 kernel.
# NOTE(review): a single-tap kernel has no neighbouring taps to space out, so
# dilation cannot change the result here; a kernel larger than 1x1 is needed
# to really exercise dilation.
image = torch.randn(1,1,5,5)
c5 = conv2Self(image, kernel_size=1, dilation=3)
c5.test()

True


In [8]:
# stride=4 with a 1x1 kernel: every 4th pixel along each axis is sampled,
# giving a 2x2 output from the 5x5 input.
image = torch.randn(1,1,5,5)
c6 = conv2Self(image, kernel_size=1, stride=4)
c6.test()

True
