In [1]:
import os, sys
sys.path.insert(0, '..')

In [2]:
from pathlib import Path
import torch
from torchvision.io import read_image, write_png
from torch.utils.cpp_extension import load_inline
from profiling.profiler import profile
from utils import *

In [4]:
from numba import cuda
from numba.cuda import as_cuda_array as ca

<img width="500" src="../images/image.png" id="jupyter"/>

In [38]:
k_size = 3
conv = torch.nn.Conv2d(1, 1, k_size, bias=False, padding=k_size//2).cuda()
m1 = torch.rand(1000, 2000).contiguous().cuda()
f = conv.weight[0][0].detach().contiguous().cuda()

## Basic convolution kernel (without shared memory)

Start from numba for debugging

In [5]:
@cuda.jit
def conv2d_k(m, f, out, r):
    # get row and column indices
    row,col = cuda.grid(2)
    if row < out.shape[0] and col < out.shape[1]:  # Ensure threads are within output shape
        val = 0
        for i in range(f.shape[0]):
            for j in range(f.shape[1]):
                in_row = row - r + i
                in_col = col - r +j
                if (m.shape[0]>in_row >=0 and m.shape[1]>in_col >=0):
                    val += m[in_row, in_col] * f[i, j]  # Convolution operation
        out[row, col] = val  # Store result in output array


def conv_2d(m, f):
    h,w  = m.shape
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    # TOTAL block size is limited by 1024 threads
    block_size = 32
    blocks = cdiv(h,block_size), cdiv(w,block_size)
    conv2d_k[blocks, (block_size, block_size)](ca(m), ca(f), ca(out), f.shape[0]//2) 
    return out

In [8]:
torch.isclose(conv(m1[None,]), conv_2d(m1,f), atol=1e-7).all()

tensor(True, device='cuda:0')

In [9]:
# %timeit conv_2d(m1,f)

In [46]:
%timeit with torch.no_grad(): conv(m1[None,])

113 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Now rewrite into CUDA kernel

In [28]:
cuda_src = cuda_begin + r'''
__global__ void conv2d_k(float* m, float* f, float* out, int f_size, int m_h, int m_w) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int r = f_size/2;

    if (row < m_h && col < m_w) {
        float val = 0;
        for (int i = 0; i < 2*r+1; i++) {
            for (int j = 0; j < 2*r+1; j++) {
                int in_row = row - r + i;
                int in_col = col - r + j;
                if (in_row >= 0 && in_row < m_h && in_col >= 0 && in_col < m_w) {
                    val += m[in_row * m_w + in_col] * f[i*f_size+j];
                }
            }
        }
        out[row * m_w + col] = val;
    }
}

torch::Tensor conv2d(torch::Tensor m, torch::Tensor f) {
    CHECK_INPUT(m); CHECK_INPUT(f);
    int h = m.size(0);
    int w = m.size(1);
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    conv2d_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), f.data_ptr<float>(), output.data_ptr<float>(), f.size(0), h, w);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''
fname = 'conv2d'

In [29]:
mod = load_cuda(cuda_src, get_sig(fname, cuda_src), [fname])

In [17]:
torch.isclose(conv(m1[None,]), mod.conv2d(m1,f)).all()

tensor(True, device='cuda:0')

We see that we're slightly slower than pytorch

In [18]:
%timeit mod.conv2d(m1,f)

156 µs ± 3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## With shared memory

Now we're going to use tiled convolution where we collaboratively store patches of matrix into shared memory and then reuse it later when computing convolution. Another way is to also load padding to cover whole edges but there are already high chances to hit L2 cache for big matrices

In [36]:
cuda_src = cuda_begin + r'''
#define TILE_DIM 16
#define FILTER_RADIUS 1
__constant__ float F_c[2*FILTER_RADIUS+1][2*FILTER_RADIUS+1];
__global__ void conv2d_k(float* m, float* out, int m_h, int m_w) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int r = FILTER_RADIUS;
    
    __shared__ float tile[TILE_DIM][TILE_DIM];
    if (row < m_h && col < m_w) {
        tile[threadIdx.y][threadIdx.x] = m[row * m_w + col];
    }
    else{
        tile[threadIdx.y][threadIdx.x] = 0.;
    }
    __syncthreads();
    if (row < m_h && col < m_w) {
        float val = 0;
        for (int i = 0; i < 2*r+1; i++) {
            for (int j = 0; j < 2*r+1; j++) {
                if(threadIdx.x-r+j>=0 && 
                   threadIdx.x-r+j<TILE_DIM && 
                   threadIdx.y-r+i>=0 && 
                   threadIdx.y-r+i<TILE_DIM){
                    val += tile[threadIdx.y+i-r][threadIdx.x+j-r] * F_c[i][j];
                }
                else{
                    int in_row = row - r + i;
                    int in_col = col - r + j;
                    if (in_row >= 0 && in_row < m_h && in_col >= 0 && in_col < m_w) {
                    val += m[in_row * m_w + in_col] * F_c[i][j];
                }
                }
            }
        }
        out[row * m_w + col] = val;
    }
}

torch::Tensor conv2d_shared(torch::Tensor m, torch::Tensor f) {
    CHECK_INPUT(m); CHECK_INPUT(f);
    TORCH_CHECK(f.size(0)==2*FILTER_RADIUS+1 && f.size(1)==2*FILTER_RADIUS+1, 
    "Filter size must be 2*FILTER_RADIUS+1 x 2*FILTER_RADIUS+1");
    int h = m.size(0);
    int w = m.size(1);
    auto output = torch::zeros({h, w}, m.options());
    cudaMemcpyToSymbol(F_c, f.data_ptr<float>(), (2*FILTER_RADIUS+1)*(2*FILTER_RADIUS+1)*sizeof(float));

    dim3 tpb(TILE_DIM,TILE_DIM);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    conv2d_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), output.data_ptr<float>(), h, w);
    CUDA_ERR(cudaGetLastError());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''
fname = 'conv2d_shared'

In [37]:
mod_s = load_cuda(cuda_src, get_sig(fname, cuda_src), [fname])

In [39]:
torch.isclose(conv(m1[None,]), mod_s.conv2d_shared(m1,f)).all()

tensor(True, device='cuda:0')

Unexpectedly tiled convolution works slower than a naive one

In [40]:
%timeit mod.conv2d(m1, f)

158 µs ± 243 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [43]:
%timeit mod_s.conv2d_shared(m1, f)

183 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [44]:
from functools import partial
profile(partial(mod.conv2d,m1[None,]), f)

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
            ProfilerStep*        46.95%     292.000us        67.68%     421.000us     210.500us             2  
              aten::zeros         2.89%      18.000us        18.97%     118.000us      59.000us             2  
              aten::empty         4.66%      29.000us         4.66%      29.000us      14.500us             2  
              aten::zero_         1.61%      10.000us        11.41%      71.000us      35.500us             2  
              aten::fill_         3.54%      22.000us         9.81%      61.000us      30.500us             2  
         cudaLaunchKernel         8.04%      50.000us         8.04%      50.000us      12.500us         

STAGE:2024-04-15 23:31:10 38918:38918 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-04-15 23:31:10 38918:38918 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-15 23:31:10 38918:38918 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
