In [12]:
import torch
import numpy as np
from torch.nn.functional import conv3d as libConv3d
import unittest

class Conv3D:
    def __init__(self, input_data, in_channels: int, out_channels: int, kernel_size: tuple | int,
                 bias: float | None = None, stride: int = 1, padding: tuple[int, int, int] | int | str = (0, 0, 0),
                 dilation: int = 1):
        self.input_data_numpy = input_data.numpy()
        self.input_data_torch = input_data
        self.bias = bias

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = (stride, stride, stride) if isinstance(stride, int) else stride
        self.dilation = (dilation, dilation, dilation) if isinstance(dilation, int) else dilation

        self.padding = self._parse_padding(padding)
        
        self.weight_tensor_torch = torch.randn(1, 1, *self.kernel_size)
        self.weight_tensor_numpy = self.weight_tensor_torch.numpy()

    def _parse_padding(self, padding):
        if isinstance(padding, tuple):
            return padding
        elif padding == "same":
            if any(s != 1 for s in self.stride):
                raise ValueError("padding 'same' works only with stride=1")
            return tuple(size - 1 for size in self.kernel_size)
        elif padding == "valid":
            return (0, 0, 0)
        else:
            return (padding, padding, padding)
        
    def conv3d(self):
        batches, _, d_in, h_in, w_in = self.input_data_numpy.shape
        out = []

        for b in range(batches):
            d_out = self._output_size(d_in, self.kernel_size[0], self.padding[0], self.stride[0], self.dilation[0])
            h_out = self._output_size(h_in, self.kernel_size[1], self.padding[1], self.stride[1], self.dilation[1])
            w_out = self._output_size(w_in, self.kernel_size[2], self.padding[2], self.stride[2], self.dilation[2])

            out.append(np.zeros((self.out_channels, d_out, h_out, w_out)))

            for c_out in range(self.out_channels):
                for z_out in range(d_out):
                    for y_out in range(h_out):
                        for x_out in range(w_out):
                            sum = 0
                            for c_in in range(self.in_channels):
                                for kernel_z in range(self.kernel_size[0]):
                                    for kernel_y in range(self.kernel_size[1]):
                                        for kernel_x in range(self.kernel_size[2]):
                                            z_in = z_out * self.stride[0] + kernel_z * self.dilation[0] - self.padding[0]
                                            y_in = y_out * self.stride[1] + kernel_y * self.dilation[1] - self.padding[1]
                                            x_in = x_out * self.stride[2] + kernel_x * self.dilation[2] - self.padding[2]
                                            if 0 <= z_in < d_in and 0 <= y_in < h_in and 0 <= x_in < w_in:
                                                sum += self.input_data_numpy[b, c_in, z_in, y_in, x_in] * \
                                                       self.weight_tensor_numpy[0, c_out, kernel_z, kernel_y, kernel_x]

                            out[b][c_out][z_out][y_out][x_out] = sum + (self.bias if self.bias else 0)

        return np.array(out)
        
    def _output_size(self, input_size, kernel_size, padding, stride, dilation):
        return int((input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

    def torch_conv3d(self):
        return libConv3d(self.input_data_torch, self.weight_tensor_torch, bias=torch.tensor([self.bias]),
                         stride=self.stride, padding=self.padding, dilation=self.dilation)
    def test(self, print_flag=False):
        my_conv3d = self.conv3d()
        torch_out = self.torch_conv3d().squeeze().detach().numpy()
        if print_flag:
            print(my_conv3d[0][0])
            print(torch_out)
        print(np.allclose(my_conv3d[0][0], torch_out))

#тесты
class TestConv3D(unittest.TestCase):
    def test_case_1(self):
        input_data = torch.randn(1, 1, 5, 5, 5)
        conv3d_layer = Conv3D(input_data, in_channels=1, out_channels=1, kernel_size=4, bias=0.5, stride=1, padding=2, dilation=1)
        conv3d_layer.test(print_flag=False)

    def test_case_2(self):
        input_data = torch.randn(1, 1, 5, 5, 5)
        conv3d_layer = Conv3D(input_data, in_channels=1, out_channels=1, kernel_size=4, bias=0.5, stride=1, padding=1, dilation=2)
        conv3d_layer.test(print_flag=False)

    def test_case_3(self):
        input_data = torch.randn(1, 1, 5, 5, 5)
        conv3d_layer = Conv3D(input_data, in_channels=1, out_channels=1, kernel_size=4, bias=0.5, stride=4, padding=1, dilation=1)
        conv3d_layer.test(print_flag=False)

    def test_case_4(self):
        input_data = torch.randn(1, 1, 5, 5, 5)
        conv3d_layer = Conv3D(input_data, in_channels=1, out_channels=1, kernel_size=4, bias=0.5, stride=1, padding=1, dilation=1)
        conv3d_layer.test(print_flag=False)

suite = unittest.TestLoader().loadTestsFromTestCase(TestConv3D)
unittest.TextTestRunner(verbosity=2).run(suite)


test_case_1 (__main__.TestConv3D) ... ok
test_case_2 (__main__.TestConv3D) ... ok
test_case_3 (__main__.TestConv3D) ... ok
test_case_4 (__main__.TestConv3D) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.050s

OK


True
True
True
True


<unittest.runner.TextTestResult run=4 errors=0 failures=0>