In [79]:
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
from triton.runtime import driver

In [2]:
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')

In [4]:
properties = driver.active.utils.get_device_properties(DEVICE.index)

In [5]:
properties

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [112]:
@triton.jit
def conv2d_kernel(input_ptr: torch.Tensor, kernel_ptr: torch.Tensor, output: torch.Tensor, 
           height: int, width: int,  stride: int, kH: int, kW: int, max_kH: tl.constexpr, max_kW: tl.constexpr, BLOCK_SIZE: tl.constexpr): 
    
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)
    
    out_height = (height - kH) // stride + 1
    out_width = (width - kW) // stride + 1
    
    OUT_TILE_HEIGHT = BLOCK_SIZE - kH + 1
    OUT_TILE_WIDTH = BLOCK_SIZE - kW + 1
    
    out_tile_row = tile_row * OUT_TILE_HEIGHT
    out_tile_col = tile_col * OUT_TILE_WIDTH
    
    kernel_row_offset = tl.arange(0, max_kH)
    
    kernel_col_offset = tl.arange(0, max_kW)

    # Create a mask for valid kernel indices.
    kernel_mask = (kernel_row_offset[:, None] < kH) & (kernel_col_offset[None, :] < kW)
    # Load a full kernel block using masked load.
    kernel_block = tl.load(
        kernel_ptr + kernel_row_offset[:, None] * kW + kernel_col_offset[None, :],
        mask=kernel_mask, other=0.0
    )
    
    
    for i in tl.range(0, OUT_TILE_HEIGHT):
        for j in tl.range(0, OUT_TILE_WIDTH):
            # acc = 0.0
            out_row = out_tile_row + i
            out_col = out_tile_col + j
            if out_row < out_height and out_col < out_width:
                input_row = out_row * stride
                input_col = out_col * stride
                
                patch_ptr = input_ptr + input_row * width + input_col
                input_patch = patch_ptr + kernel_row_offset[:, None] * width + kernel_col_offset[None, :]
                mask = (input_row + kernel_row_offset[:, None] < height) & (input_col + kernel_col_offset[None, :] < width) & kernel_mask
                input_block = tl.load(input_patch, mask = mask, other=0.0)
                acc = tl.sum(input_block * kernel_block)
                tl.store(output + out_row * out_width + out_col, acc)
                
                
    

In [124]:
# Example dimensions and hyperparameters.
height, width = 512, 512     # Input dimensions.
kH, kW = 3, 3                # Kernel dimensions.
stride = 1
BLOCK_SIZE = 16              # Choose a BLOCK_SIZE such that BLOCK_SIZE > kH and BLOCK_SIZE > kW.

# Create dummy tensors.
input_tensor = torch.randn(height, width, device='cuda', dtype=torch.float32)
kernel_tensor = torch.randn(kH, kW, device='cuda', dtype=torch.float32)
out_height = (height - kH) // stride + 1
out_width = (width - kW) // stride + 1
output_tensor = torch.empty(out_height, out_width, device='cuda', dtype=torch.float32)


In [125]:
# Compute grid dimensions based on tile size.
OUT_TILE_HEIGHT = BLOCK_SIZE - kH + 1
OUT_TILE_WIDTH  = BLOCK_SIZE - kW + 1
num_tile_rows = (out_height + OUT_TILE_HEIGHT - 1) // OUT_TILE_HEIGHT
num_tile_cols = (out_width + OUT_TILE_WIDTH - 1) // OUT_TILE_WIDTH

# The grid is 2D.
grid = (num_tile_rows, num_tile_cols)
max_kH, max_kW = triton.next_power_of_2(kH), triton.next_power_of_2(kW)


In [126]:
conv2d_kernel[grid](input_tensor, kernel_tensor, output_tensor,
    height, width, stride, kH, kW, max_kH, max_kW, BLOCK_SIZE)

torch_output = F.conv2d(input_tensor[None, None, :, :], kernel_tensor[None, None, :, :]).squeeze()
torch.allclose(output_tensor, torch_output, rtol=1e-5, atol=1e-6)

True

In [116]:
torch_output

tensor([[-5.0729,  2.0847, -2.1746,  ..., -3.2092,  2.9327, -0.4276],
        [ 0.6600,  1.2153, -6.4636,  ..., -1.2946,  3.0366, -7.5930],
        [-0.2889, -1.5446, -0.2909,  ...,  3.0190,  2.6954, -3.4862],
        ...,
        [ 3.8405,  0.9267, -3.0867,  ..., -5.4856,  2.1411,  3.7709],
        [ 1.4122, -2.3527, -3.4199,  ..., -0.9416, -2.9081,  2.3585],
        [ 1.0550, -0.7478,  1.3724,  ...,  0.2417, -6.5702,  3.9521]],
       device='cuda:0')

In [117]:
output_tensor

tensor([[-5.0729,  2.0847, -2.1746,  ..., -3.2092,  2.9327, -0.4276],
        [ 0.6600,  1.2153, -6.4636,  ..., -1.2946,  3.0366, -7.5930],
        [-0.2889, -1.5446, -0.2909,  ...,  3.0190,  2.6954, -3.4862],
        ...,
        [ 3.8405,  0.9267, -3.0867,  ..., -5.4856,  2.1411,  3.7709],
        [ 1.4122, -2.3527, -3.4199,  ..., -0.9416, -2.9081,  2.3585],
        [ 1.0550, -0.7478,  1.3724,  ...,  0.2417, -6.5702,  3.9521]],
       device='cuda:0')

In [119]:
output_tensor.shape

torch.Size([32, 32])

In [120]:
torch_output.shape

torch.Size([32, 32])

In [122]:
mismatch_indices = torch.where(output_tensor != torch_output)

# Print some of the mismatches to analyze the pattern
for i in range(min(10, len(mismatch_indices[0]))):
    row, col = mismatch_indices[0][i], mismatch_indices[1][i]
    print(f"Mismatch at position ({row}, {col}):")
    print(f"  Your output: {output_tensor[row, col]}")
    print(f"  PyTorch output: {torch_output[row, col]}")
    print(f"  Difference: {output_tensor[row, col] - torch_output[row, col]}")

Mismatch at position (0, 0):
  Your output: -5.0729498863220215
  PyTorch output: -5.072949409484863
  Difference: -4.76837158203125e-07
Mismatch at position (0, 2):
  Your output: -2.1745619773864746
  PyTorch output: -2.1745622158050537
  Difference: 2.384185791015625e-07
Mismatch at position (0, 3):
  Your output: -0.8066891431808472
  PyTorch output: -0.8066893815994263
  Difference: 2.384185791015625e-07
Mismatch at position (0, 4):
  Your output: 6.893205642700195
  PyTorch output: 6.8932061195373535
  Difference: -4.76837158203125e-07
Mismatch at position (0, 5):
  Your output: 0.8722105026245117
  PyTorch output: 0.8722104430198669
  Difference: 5.960464477539063e-08
Mismatch at position (0, 6):
  Your output: -3.0692574977874756
  PyTorch output: -3.0692572593688965
  Difference: -2.384185791015625e-07
Mismatch at position (0, 7):
  Your output: 1.5042240619659424
  PyTorch output: 1.5042235851287842
  Difference: 4.76837158203125e-07
Mismatch at position (0, 8):
  Your output