In [1]:
import os, sys
sys.path.insert(0, '..')

In [2]:
from pathlib import Path
import torch
from torchvision.io import read_image, write_png
from torch.utils.cpp_extension import load_inline
from utils import cdiv

In [None]:
from numba import cuda
from numba.cuda import as_cuda_array as ca

<img width="500" src="../images/image.png" id="jupyter"/>

## Basic convolution kernel (without shared memory)

In [10]:
@cuda.jit
def conv2d_k(m, f, out, r):
    # get row and column indices
    row,col = cuda.grid(2)
    if row < out.shape[0] and col < out.shape[1]:  # Ensure threads are within output shape
        val = 0
        for i in range(f.shape[0]):
            for j in range(f.shape[1]):
                in_row = row - r + i
                in_col = col - r +j
                if (m.shape[0]>in_row >=0 and m.shape[1]>in_col >=0):
                    val += m[in_row, in_col] * f[i, j]  # Convolution operation
        out[row, col] = val  # Store result in output array


def conv_2d(m, f):
    h,w  = m.shape
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    # TOTAL block size is limited by 1024 threads
    block_size = 32
    blocks = cdiv(h,block_size), cdiv(w,block_size)
    conv2d_k[blocks, (block_size, block_size)](ca(m), ca(f), ca(out), f.shape[0]//2) 
    return out

In [18]:
conv = torch.nn.Conv2d(1, 1, 3, bias=False, padding=1).cuda()
m1 = torch.rand(1000, 2000).contiguous().cuda()
f = conv.weight[0][0].detach().contiguous().cuda()

In [19]:
torch.isclose(conv(m1[None,]), conv_2d(m1,f), atol=1e-7).all()

tensor(True, device='cuda:0')

In [20]:
%timeit conv_2d(m1,f)

487 µs ± 183 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [22]:
%timeit with torch.no_grad(): conv(m1[None,])

114 µs ± 413 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## With constant memory

In [None]:
code = '''
#include <cuda_runtime.h>

__global__ void conv2d_k(float *m, float *f, float *out, int r, int m_h, int m_w, int f_h, int f_w, int out_h, int out_w) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < out_h && col < out_w) {
        float val = 0;
        for (int i = 0; i < f_h; ++i) {
            for (int j = 0; j < f_w; ++j) {
                int in_row = row - r + i;
                int in_col = col - r + j;
                if (in_row >= 0 && in_row < m_h && in_col >= 0 && in_col < m_w) {
                    val += m[in_row * m_w + in_col] * f[i * f_w + j];
                }
            }
        }
        out[row * out_w + col] = val;
    }
}

void conv_2d(float *m, float *f, float *out, int m_h, int m_w, int f_h, int f_w, int out_h, int out_w) {
    // Define block and grid dimensions
    dim3 blockDim(32, 32);
    dim3 gridDim((out_w + blockDim.x - 1) / blockDim.x, (out_h + blockDim.y - 1) / blockDim.y);

    // Launch kernel
    conv2d_k<<<gridDim, blockDim>>>(m, f, out, f_h / 2, m_h, m_w, f_h, f_w, out_h, out_w);
}
'''

In [36]:
@cuda.jit
def conv2d_const_k(m, out, r):
    # get row and column indices
    row,col = cuda.grid(2)
    if row < out.shape[0] and col < out.shape[1]:  # Ensure threads are within output shape
        val = 0
        for i in range(f.shape[0]):
            for j in range(f.shape[1]):
                in_row = row - r + i
                in_col = col - r +j
                if (m.shape[0]>in_row >=0 and m.shape[1]>in_col >=0):
                    val += m[in_row, in_col] * f[i, j]  # Convolution operation
        out[row, col] = val  # Store result in output array


def conv_const_2d(m, f):
    h,w  = m.shape
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    # TOTAL block size is limited by 1024 threads
    block_size = 32
    f = cuda.const.array_like(filter)
    blocks = cdiv(h,block_size), cdiv(w,block_size)
    conv2d_const_k[blocks, (block_size, block_size)](ca(m), ca(out), f.shape[0]//2) 
    return out

In [37]:
torch.isclose(conv(m1[None,]), conv_const_2d(m1,f), atol=1e-7).all()

NotImplementedError: <function const.array_like at 0x7f6115b818a0> cannot be called from host code

In [35]:
%timeit conv_const_2d(m1,f)

621 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
