In [141]:
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
from triton.runtime import driver

In [142]:
# All tensors in this notebook live on whichever CUDA device is currently selected.
DEVICE = torch.device('cuda', torch.cuda.current_device())

In [143]:
# Ask Triton's active driver backend for DEVICE's hardware limits (the dict shown
# below: shared memory, register file, SM count, warp size, clocks, bus width).
# Handy when choosing BLOCK_SIZE / num_warps / num_stages.
properties = driver.active.utils.get_device_properties(DEVICE.index)

In [144]:
# Display the device-property dict returned by get_device_properties.
properties

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [145]:
def get_cuda_autotune_config():
    """Return candidate Triton launch configs for autotuning.

    Sweeps BLOCK_SIZE over {16, 32} for each pipeline depth num_stages in {3, 4},
    always using 8 warps. Ordering: every num_stages=3 config precedes the
    num_stages=4 ones, and within a stage count BLOCK_SIZE ascends.
    """
    # NOTE(review): nothing in this notebook feeds these into @triton.autotune;
    # conv2d_kernel is launched with an explicit BLOCK_SIZE instead.
    return [
        triton.Config({'BLOCK_SIZE': block}, num_stages=stages, num_warps=8)
        for stages in (3, 4)
        for block in (16, 32)
    ]

In [154]:
@triton.jit
def conv2d_kernel(input_ptr: torch.Tensor, kernel_ptr: torch.Tensor, output: torch.Tensor, 
           height: int, width: int,  stride: int, kH: int, kW: int, 
           max_kH: tl.constexpr, max_kW: tl.constexpr, BLOCK_SIZE: tl.constexpr): 
    """Single-channel 2-D "valid" cross-correlation (no padding, no kernel flip).

    Computes
        output[r, c] = sum_{i < kH, j < kW} input[r*stride + i, c*stride + j] * kernel[i, j]
    for 0 <= r < out_height, 0 <= c < out_width, where
    out_height = (height - kH) // stride + 1 and out_width = (width - kW) // stride + 1.

    Launch grid is 2-D: program (tile_row, tile_col) owns the output rows
    [tile_row*OUT_TILE_HEIGHT, +OUT_TILE_HEIGHT) and columns
    [tile_col*OUT_TILE_WIDTH, +OUT_TILE_WIDTH), with OUT_TILE_* = BLOCK_SIZE - k + 1.

    Args:
        input_ptr: input image, addressed row-major with row stride `width`
            (assumes a contiguous (height, width) tensor).
        kernel_ptr: filter, addressed row-major with row stride `kW`.
        output: result, addressed row-major with row stride `out_width`.
        height, width: input size.
        stride: step between neighbouring output positions, same on both axes.
        kH, kW: filter size.
        max_kH, max_kW: compile-time extents for the tl.arange over filter
            rows/cols; the host passes triton.next_power_of_2(kH/kW). Lanes at
            or beyond kH/kW are masked out.
        BLOCK_SIZE: tile parameter. If BLOCK_SIZE < kH (or kW) the tile size is
            <= 0, the loops below run zero times and this program writes nothing.

    NOTE(review): each output element is produced serially (one masked
    max_kH x max_kW load + tl.sum per element), so a program does
    OUT_TILE_HEIGHT * OUT_TILE_WIDTH dependent iterations. A tile-vectorized
    formulation (a block of outputs, looping over the kH*kW taps) would use the
    GPU much better.
    NOTE(review): the tile-size formula is derived for stride 1. For stride > 1
    every output is still covered (the loops walk output coordinates), but a
    tile's input footprint is larger than BLOCK_SIZE.
    """
    
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)
    
    # Output size of an unpadded convolution with the given stride.
    out_height = (height - kH) // stride + 1
    out_width = (width - kW) // stride + 1
    
    # Number of output rows/cols handled by one program.
    OUT_TILE_HEIGHT = BLOCK_SIZE - kH + 1
    OUT_TILE_WIDTH = BLOCK_SIZE - kW + 1
    
    # Top-left output coordinate of this program's tile.
    out_tile_row = tile_row * OUT_TILE_HEIGHT
    out_tile_col = tile_col * OUT_TILE_WIDTH
    
    # Row/col offsets into the filter, padded up to the power-of-two extents.
    kernel_row_offset = tl.arange(0, max_kH)
    
    kernel_col_offset = tl.arange(0, max_kW)

    # Create a mask for valid kernel indices (drops the power-of-two padding lanes).
    kernel_mask = (kernel_row_offset[:, None] < kH) & (kernel_col_offset[None, :] < kW)
    # Load the whole filter once per program; padded lanes read as 0.0 so they
    # contribute nothing to the per-output sums below.
    kernel_block = tl.load(
        kernel_ptr + kernel_row_offset[:, None] * kW + kernel_col_offset[None, :],
        mask=kernel_mask, other=0.0
    )
    
    
    for i in tl.range(0, OUT_TILE_HEIGHT):
        for j in tl.range(0, OUT_TILE_WIDTH):
            # Output element (out_row, out_col) is computed fully in this iteration.
            out_row = out_tile_row + i
            out_col = out_tile_col + j
            # Edge tiles may extend past the output; skip those positions.
            if out_row < out_height and out_col < out_width:
                # Top-left input pixel of the receptive field.
                input_row = out_row * stride
                input_col = out_col * stride
                
                patch_ptr = input_ptr + input_row * width + input_col
                input_patch = patch_ptr + kernel_row_offset[:, None] * width + kernel_col_offset[None, :]
                # Bounds mask for the input window, combined with the filter
                # padding mask. For in-range outputs the height/width terms should
                # always hold; they act as a defensive guard.
                mask = (input_row + kernel_row_offset[:, None] < height) & (input_col + kernel_col_offset[None, :] < width) & kernel_mask
                input_block = tl.load(input_patch, mask = mask, other=0.0)
                # Elementwise product of window and filter, reduced to one scalar.
                acc = tl.sum(input_block * kernel_block)
                tl.store(output + out_row * out_width + out_col, acc)
                
                
    

In [155]:
def conv2d(height: int, width: int, kH: int, kW: int, stride: int, BLOCK_SIZE: int):
    # Create dummy tensors.
    input_tensor = torch.randn(height, width, device='cuda', dtype=torch.float32)
    kernel_tensor = torch.randn(kH, kW, device='cuda', dtype=torch.float32)
    out_height = (height - kH) // stride + 1
    out_width = (width - kW) // stride + 1
    output_tensor = torch.empty(out_height, out_width, device='cuda', dtype=torch.float32)
    # Compute grid dimensions based on tile size.
    OUT_TILE_HEIGHT = BLOCK_SIZE - kH + 1
    OUT_TILE_WIDTH  = BLOCK_SIZE - kW + 1
    num_tile_rows = (out_height + OUT_TILE_HEIGHT - 1) // OUT_TILE_HEIGHT
    num_tile_cols = (out_width + OUT_TILE_WIDTH - 1) // OUT_TILE_WIDTH

    # The grid is 2D.
    grid = (num_tile_rows, num_tile_cols)
    max_kH, max_kW = triton.next_power_of_2(kH), triton.next_power_of_2(kW)
    
    conv2d_kernel[grid](input_tensor, kernel_tensor, output_tensor,
    height, width, stride, kH, kW, max_kH, max_kW, BLOCK_SIZE=BLOCK_SIZE)
    
    return output_tensor


In [156]:
# Example dimensions and hyperparameters.
height, width = 512, 512     # Input dimensions.
kH, kW = 3, 3                # Kernel dimensions.
stride = 1
BLOCK_SIZE = 16              # Choose a BLOCK_SIZE such that BLOCK_SIZE > kH and BLOCK_SIZE > kW.

output_tensor = conv2d(height=height, width=width, kH=kH, kW=kW, stride=stride, BLOCK_SIZE=BLOCK_SIZE)
torch_output = F.conv2d(input_tensor[None, None, :, :], kernel_tensor[None, None, :, :]).squeeze()
if torch.allclose(output_tensor, torch_output, rtol=1e-5, atol=1e-6):
    print("✅ Triton and Torch conv2d implementations match")
else:
    print("❌ Triton and Torch conv2d implementations differ")

❌ Triton and Torch conv2d implementations differ


In [149]:
# Inspect the PyTorch reference result.
torch_output

tensor([[ 0.1862, -2.3624,  1.5896,  ..., -3.7066,  1.3915, -0.1006],
        [-1.4170, -3.6616, -3.4478,  ...,  2.1552,  0.9247,  0.5250],
        [-3.0319,  5.6185, -1.8613,  ...,  1.0411,  0.4177, -2.6132],
        ...,
        [-3.3992,  6.8019,  3.3175,  ...,  0.6340,  0.2065, -0.1600],
        [ 1.6621,  3.1356,  4.1521,  ...,  0.9306, -0.6480,  0.1893],
        [ 3.5988,  1.5337,  1.2764,  ..., -0.9792,  1.9511,  1.3613]],
       device='cuda:0')

In [150]:
# Inspect the Triton kernel's result.
output_tensor

tensor([[-2.4613,  4.0595, -1.2003,  ...,  3.5653,  2.1818, -2.1646],
        [ 2.2217, -1.0708, -4.5527,  ...,  0.4244,  0.4346,  3.8623],
        [ 0.5882,  1.0266, -2.6549,  ..., -0.3509, -3.8504,  4.2345],
        ...,
        [-1.3854,  2.0151, -1.3444,  ..., -1.4743, -2.7378,  0.3158],
        [-0.3509,  0.1170,  1.7632,  ...,  0.3835, -3.3642,  1.6686],
        [-0.1533,  1.1206, -1.5568,  ...,  2.1246, -2.5077, -3.0377]],
       device='cuda:0')

In [None]:
# Configure the benchmarks: one plot per filter size, square image size on the x-axis.
configs = []
ref_lib = "torch"
kernel_sizes = [3, 5, 7, 11]

for ksz in kernel_sizes:
    configs.append(
        triton.testing.Benchmark(
            x_names=["image_size"],  # Argument names to use as an x-axis for the plot
            x_vals=[128, 256, 512],  # Square input sizes (height == width)
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            line_vals=[ref_lib.lower(), "triton"],  # Possible values for `line_arg`
            line_names=[ref_lib, "Triton"],  # Label name for the lines
            styles=[("green", "-"), ("blue", "-")],  # Line styles
            ylabel="GB/s",  # Label name for the y-axis
            plot_name=f"conv2d-performance-k{ksz}-fp32",  # Name for the plot
            args={"kernel_size": ksz, "stride": 1, "block_size": 32},
        ))


@triton.testing.perf_report(configs)
def benchmark(image_size, provider, kernel_size, stride, block_size):
    """Time one conv2d provider on a square single-channel float32 image.

    Args:
        image_size: input height and width.
        provider: "torch" (F.conv2d) or "triton" (conv2d_kernel).
        kernel_size: filter height and width.
        stride: convolution stride.
        block_size: BLOCK_SIZE passed to conv2d_kernel (must be >= kernel_size).

    Returns:
        (median, at-max-time, at-min-time) effective bandwidth in GB/s.
    """
    height = width = image_size
    kH = kW = kernel_size
    x = torch.randn(height, width, device=DEVICE, dtype=torch.float32)
    w = torch.randn(kH, kW, device=DEVICE, dtype=torch.float32)
    out_height = (height - kH) // stride + 1
    out_width = (width - kW) // stride + 1
    out = torch.empty(out_height, out_width, device=DEVICE, dtype=torch.float32)

    # Same launch geometry as conv2d(), but tensors are allocated once up front so
    # only the kernel itself is timed.
    out_tile_h = block_size - kH + 1
    out_tile_w = block_size - kW + 1
    grid = (triton.cdiv(out_height, out_tile_h), triton.cdiv(out_width, out_tile_w))
    max_kH, max_kW = triton.next_power_of_2(kH), triton.next_power_of_2(kW)

    quantiles = [0.5, 0.2, 0.8]
    if provider == ref_lib.lower():
        # NOTE: cuDNN may run float32 convolutions in TF32 on Ampere+ GPUs
        # (torch.backends.cudnn.allow_tf32), i.e. at reduced precision.
        x4d, w4d = x[None, None], w[None, None]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: F.conv2d(x4d, w4d, stride=stride),
            quantiles=quantiles
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: conv2d_kernel[grid](x, w, out, height, width, stride, kH, kW,
                                        max_kH, max_kW, BLOCK_SIZE=block_size),
            quantiles=quantiles
        )
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    # Minimum memory traffic for a single-channel convolution: read the image and
    # the filter once, write the output once.
    bytes_per_element = 4  # float32
    bytes_accessed = (height * width + kH * kW + out_height * out_width) * bytes_per_element

    def gb_per_s(t_ms):
        return bytes_accessed * 1e-9 / (t_ms * 1e-3)

    # Longest time -> lowest throughput, so max_ms maps to the lower bound.
    return gb_per_s(ms), gb_per_s(max_ms), gb_per_s(min_ms)


# Run the benchmark.
benchmark.run(show_plots=True, print_data=True)