In [1]:
import torch

In [2]:
from itertools import product

def unroll(X: torch.Tensor, k: int, h_out: int, w_out: int):
    """
    Unroll input values into the shape required by weight dims and output dims.
    Assume X is only a single sample.
    """
    C = X.shape[0]
    X_unrolled = torch.empty((C*k*k, h_out*w_out), dtype=int)

    # think of the iteration as the stamping out of the conv mask
    for c, p, q, h, w in product(*(range(C), range(k), range(k), range(h_out), range(w_out))):
        
        # the c*k*k tells you which channel you're currently on in the column
        # p*k + q tells you which value you're in for a single weight mas
        h_unroll = c*k*k + p*k + q
        # w_unroll tells you which column you're operating on, which corresponds to 
        # the unrolled index into the output matrix
        w_unroll = h * w_out + w
        X_unrolled[h_unroll, w_unroll] = X[c, h+p, w+q]

    return X_unrolled







def conv2d(X: torch.Tensor, W_unroll: torch.Tensor, k: int):
    # X is (N, C, H, W)
    N, C, H, W = X.shape

    # W is (M, C*k*k)
    M = W_unroll.shape[0]

    h_out = H - k + 1
    w_out = W - k + 1
    # don't forget the extra first batch size dim
    Y = torch.empty((N, M, h_out*w_out))

    for n in range(N):
        X_unrolled = unroll(X[n], k, h_out, w_out)
        Y[n] = W_unroll @ X_unrolled

    return Y



In [29]:
w_unroll = torch.tensor([[1, 1, 2, 2, 1, 1, 1, 1, 0, 1, 1, 0],
        [1, 0, 0, 1, 2, 1, 2, 1, 1, 2, 2, 0]], dtype=torch.float32)

X = torch.tensor([[[1, 2, 0],
         [1, 1, 3],
         [0, 2, 2]],

        [[0, 2, 1],
         [0, 3, 2],
         [1, 1, 0]],

        [[1, 2, 1],
         [0, 1, 3],
         [3, 3, 2]]], dtype=torch.float32, device='cuda')

unroll(X, 2, 2, 2)

# unsqueeze for the batch size of 1
# conv2d(X.unsqueeze(0), w_unroll, 2)

tensor([[1, 2, 1, 1],
        [2, 0, 1, 3],
        [1, 1, 0, 2],
        [1, 3, 2, 2],
        [0, 2, 0, 3],
        [2, 1, 3, 2],
        [0, 3, 1, 1],
        [3, 2, 1, 0],
        [1, 2, 0, 1],
        [2, 1, 1, 3],
        [0, 1, 3, 3],
        [1, 3, 3, 2]])

In [35]:
from torch.utils.cpp_extension import load
module = load(
    name='m',
    sources=['main.cpp', 'cnn2d.cu'],
    verbose=True
)

Using /home/seb/.cache/torch_extensions/py312_cu121 as PyTorch extensions root...
The input conditions for extension module m have changed. Bumping to version 10 and re-building as m_v10...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/seb/.cache/torch_extensions/py312_cu121/m/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module m_v10...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=m_v10 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/TH -isystem /home/seb/CUDA/cudaenv/lib/python3.12/site-packages/torch/include/THC -isystem /usr/local/cuda-12.3/include -isystem /home/seb/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/seb/CUDA/pmpp/cnn/main.cpp -o main.o 
[2/3] /usr/local/cuda-12.3/bin/nvcc --generate-dependencies-with-compile --dependency-output cnn2d.cuda.o.d -DTORCH_EXTENSION_NAME=m_v10 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/seb/CUDA/cu

Loading extension module m_v10...


In [36]:
out = module.unroll(X, 2)
out

12 threads, 1 blocks, 


tensor([[1., 2., 1., 1.],
        [2., 0., 1., 3.],
        [1., 1., 0., 2.],
        [1., 3., 2., 2.],
        [0., 2., 0., 3.],
        [2., 1., 3., 2.],
        [0., 3., 1., 1.],
        [3., 2., 1., 0.],
        [1., 2., 0., 1.],
        [2., 1., 1., 3.],
        [0., 1., 3., 3.],
        [1., 3., 3., 2.]], device='cuda:0')

In [31]:
out.shape

torch.Size([12, 4])