In [79]:
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
from triton.runtime import driver

In [2]:
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')

In [4]:
properties = driver.active.utils.get_device_properties(DEVICE.index)

In [5]:
properties

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [None]:
@triton.jit
def conv2d_kernel(input_ptr: torch.Tensor, kernel_ptr: torch.Tensor, output: torch.Tensor, 
           height: int, width: int,  stride: int, kH: int, kW: int, max_kH: tl.constexpr, max_kW: tl.constexpr, BLOCK_SIZE: tl.constexpr): 
    
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)
    
    out_height = (height - kH) // stride + 1
    out_width = (width - kW) // stride + 1
    
    OUT_TILE_HEIGHT = BLOCK_SIZE - kH + 1
    OUT_TILE_WIDTH = BLOCK_SIZE - kW + 1
    
    out_tile_row = tile_row * OUT_TILE_HEIGHT
    out_tile_col = tile_col * OUT_TILE_WIDTH
    
    kernel_row_offset = tl.arange(0, max_kH)
    
    kernel_col_offset = tl.arange(0, max_kW)

    # Create a mask for valid kernel indices.
    kernel_mask = (kernel_row_offset[:, None] < kH) & (kernel_col_offset[None, :] < kW)
    # Load a full kernel block using masked load.
    kernel_block = tl.load(
        kernel_ptr + kernel_row_offset[:, None] * kW + kernel_col_offset[None, :],
        mask=kernel_mask, other=0.0
    )
    
    
    for i in tl.range(0, OUT_TILE_HEIGHT):
        for j in tl.range(0, OUT_TILE_WIDTH):
            # acc = 0.0
            out_row = out_tile_row + i
            out_col = out_tile_col + j
            if out_row < out_height and out_col < out_width:
                input_row = out_row * stride
                input_col = out_col * stride
                
                patch_ptr = input_ptr + input_row * width + input_col
                input_patch = patch_ptr + kernel_row_offset[:, None] * width + kernel_col_offset[None, :]
                mask = (input_row + kernel_row_offset[:, None] < height) & (input_col + kernel_col_offset[None, :] < width)
                input_block = tl.load(input_patch, mask = mask, other=0.0)
                acc = tl.sum(input_block * kernel_block)
                tl.store(output + out_row * out_width + out_col, acc)
                
                
    

In [105]:
# Example dimensions and hyperparameters.
height, width = 34, 34     # Input dimensions.
kH, kW = 3, 3                # Kernel dimensions.
stride = 1
BLOCK_SIZE = 16              # Choose a BLOCK_SIZE such that BLOCK_SIZE > kH and BLOCK_SIZE > kW.

# Create dummy tensors.
input_tensor = torch.randn(height, width, device='cuda', dtype=torch.float32)
kernel_tensor = torch.randn(kH, kW, device='cuda', dtype=torch.float32)
out_height = (height - kH) // stride + 1
out_width = (width - kW) // stride + 1
output_tensor = torch.empty(out_height, out_width, device='cuda', dtype=torch.float32)


In [106]:
# Compute grid dimensions based on tile size.
OUT_TILE_HEIGHT = BLOCK_SIZE - kH + 1
OUT_TILE_WIDTH  = BLOCK_SIZE - kW + 1
num_tile_rows = (out_height + OUT_TILE_HEIGHT - 1) // OUT_TILE_HEIGHT
num_tile_cols = (out_width + OUT_TILE_WIDTH - 1) // OUT_TILE_WIDTH

# The grid is 2D.
grid = (num_tile_rows, num_tile_cols)
max_kH, max_kW = triton.next_power_of_2(kH), triton.next_power_of_2(kW)


In [107]:
conv2d_kernel[grid](input_tensor, kernel_tensor, output_tensor,
    height, width, stride, kH, kW, max_kH, max_kW, BLOCK_SIZE)

torch_output = F.conv2d(input_tensor[None, None, :, :], kernel_tensor[None, None, :, :]).squeeze()
torch.allclose(output_tensor, torch_output)

False