In [14]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
import math

In [4]:
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')

In [6]:
properties = driver.active.utils.get_device_properties(DEVICE.index)

In [7]:
properties

{'max_shared_mem': 101376,
 'max_num_regs': 65536,
 'multiprocessor_count': 64,
 'warpSize': 32,
 'sm_clock_rate': 1695000,
 'mem_clock_rate': 8001000,
 'mem_bus_width': 384}

In [None]:
@triton.jit
def pack_even_odd(q_even: tl.tensor, q_odd: tl.tensor, BLOCK_SIZE: tl.constexpr, DTYPE: tl.constexpr):
    """
    Packs even and odd parts into a flattened 1D tensor.
    Assumes q_even and q_odd have shape (BLOCK_SIZE,).
    """
    # Create a temporary tensor with shape (BLOCK_SIZE, 2)
    q_temp = tl.zeros((BLOCK_SIZE, 2), dtype==DTYPE)
    # Loop over the BLOCK_SIZE dimension to assign even and odd values
    for i in range(BLOCK_SIZE):
        q_temp[i, 0] = q_even[i]
        q_temp[i, 1] = q_odd[i]
    # Flatten the tensor to shape (BLOCK_SIZE * 2,)
    return tl.reshape(q_temp, (-1,))

@triton.jit
def apply_rope_kernel(Q: tl.tensor, K: tl.tensor, 
                      Q_out: tl.tensor, K_out: tl.tensor,
                      pos_id: tl.tensor, cosines: tl.tensor, sines: tl.tensor,
                      seq_length: int, num_heads: int, head_dim: int,
                      batch_stride: int, seq_stride: int, head_stride: int, 
                      BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr, DTYPE: tl.constexpr):
    pid_batch = tl.program_id(0)
    pid_seq   = tl.program_id(1)
    pid_head  = tl.program_id(2)
    
    pos_ids = tl.load(pos_id + pid_batch * seq_length + pid_seq)
    base_offset = pid_batch * batch_stride + pid_seq * seq_stride + pid_head * head_stride
    Q_start_ptr = Q + base_offset
    K_start_ptr = K + base_offset
    Q_out_ptr   = Q_out + base_offset
    K_out_ptr   = K_out + base_offset
    
    for k in range(0, tl.cdiv(num_heads, BLOCK_SIZE * 2)):
        dim_offsets = k * (BLOCK_SIZE * 2) + tl.arange(0, BLOCK_SIZE * 2)
        mask = dim_offsets < head_dim
        
        Q_loaded = tl.load(Q_start_ptr + dim_offsets, mask=mask)
        K_loaded = tl.load(K_start_ptr + dim_offsets, mask=mask)
        
        # Split into even and odd parts using our computed indices manually
        q_even = tl.zeros((BLOCK_SIZE,), dtype=DTYPE)
        q_odd  = tl.zeros((BLOCK_SIZE,), dtype=DTYPE)
        k_even = tl.zeros((BLOCK_SIZE,), dtype=DTYPE)
        k_odd  = tl.zeros((BLOCK_SIZE,), dtype=DTYPE)
        
        # Load even and odd components manually
        for i in range(BLOCK_SIZE):
            even_idx = i * 2
            odd_idx = even_idx + 1
            q_even[i] = Q_loaded[even_idx]
            q_odd[i]  = Q_loaded[odd_idx]
            k_even[i] = K_loaded[even_idx]
            k_odd[i]  = K_loaded[odd_idx]
        
        # Compute cosine and sine offsets and load corresponding values
        cos_offset = cosines + pos_ids * (head_dim // 2)
        sin_offset = sines + pos_ids * (head_dim // 2)
        offsets = k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        cos_mask = offsets < (head_dim // 2)
        cos_values = tl.load(cos_offset + offsets, mask=cos_mask)
        sin_values = tl.load(sin_offset + offsets, mask=cos_mask)
        
        # Compute new values using the rotary transformation
        q_even_out = q_even * cos_values - q_odd * sin_values
        q_odd_out  = q_even * sin_values + q_odd * cos_values
        k_even_out = k_even * cos_values - k_odd * sin_values
        k_odd_out  = k_even * sin_values + k_odd * cos_values
        
        # Use the helper function to pack even and odd outputs into flat arrays.
        q_flat = pack_even_odd(q_even_out, q_odd_out, BLOCK_SIZE)
        k_flat = pack_even_odd(k_even_out, k_odd_out, BLOCK_SIZE)
        
        # Store the results
        tl.store(Q_out_ptr + dim_offsets, q_flat, mask=mask)
        tl.store(K_out_ptr + dim_offsets, k_flat, mask=mask)

In [None]:
def apply_rope(q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, cosines: torch.Tensor, sines: torch.Tensor):
    q_out = torch.zeros_like(q)
    k_out = torch.zeros_like(k)
    B, L, H, D = q.shape
    num_stages = 8
    BLOCK_SIZE = 32
    
    # Grid dimensions should be (batch, sequence_length, num_heads)
    grid = (B, L, H)
    
    apply_rope_kernel[grid](
        q, k, cosines, sines, pos_ids,  # Input tensors
        q_out, k_out,                   # Output tensors
        L, H, D,                        # Dimensions
        q.stride(0), q.stride(1), q.stride(2),  # Strides
        BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages,
        DTYPE=tl.float32
    )
    
    return q_out, k_out

In [33]:
# For comparing with native PyTorch implementation
def torch_apply_rope(q, k, position_ids, head_dim):
    # Create position embeddings
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=q.device).float() / head_dim))
    
    # Get sin and cos values
    batch_size, seq_len = position_ids.shape
    freqs = torch.einsum("b l, d -> b l d", position_ids.float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()[:, :, None, :].repeat(1, 1, q.shape[2], 1)
    sin = emb.sin()[:, :, None, :].repeat(1, 1, q.shape[2], 1)
    
    # Apply rotation
    q_out = torch.cat([-q[..., 1::2], q[..., ::2]], dim=-1).reshape(q.shape)
    k_out = torch.cat([-k[..., 1::2], k[..., ::2]], dim=-1).reshape(k.shape)
    
    q_out = q * cos + q_out * sin
    k_out = k * cos + k_out * sin
    
    return q_out, k_out

# Test function
def test_apply_rope():
    # Set up test data
    torch.manual_seed(0)
    device = torch.device('cuda')
    
    # Parameters
    batch_size = 2
    seq_len = 16
    num_heads = 4
    head_dim = 64
    
    # Generate test data
    q = torch.randn((batch_size, seq_len, num_heads, head_dim), device=device, dtype=torch.float32)
    k = torch.randn((batch_size, seq_len, num_heads, head_dim), device=device, dtype=torch.float32)
    
    # Position IDs
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1)
    
    # Precompute cosines and sines
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    
    # Initialize cosine and sine tables
    cos_table = torch.zeros((seq_len, head_dim // 2), device=device)
    sin_table = torch.zeros((seq_len, head_dim // 2), device=device)
    
    for pos in range(seq_len):
        freqs = pos * inv_freq
        cos_table[pos] = freqs.cos()
        sin_table[pos] = freqs.sin()
    
    # Run your Triton version
    triton_q_out, triton_k_out = apply_rope(q, k, position_ids, cos_table, sin_table)
    
    # Run PyTorch reference
    torch_q_out, torch_k_out = torch_apply_rope(q, k, position_ids, head_dim)
    
    # Compare results
    q_match = torch.allclose(triton_q_out, torch_q_out, atol=1e-5, rtol=1e-5)
    k_match = torch.allclose(triton_k_out, torch_k_out, atol=1e-5, rtol=1e-5)
    
    if q_match and k_match:
        print("✅ Triton and Torch RoPE implementations match")
    else:
        print("❌ Triton and Torch RoPE implementations differ")
        if not q_match:
            print(f"Q max difference: {(triton_q_out - torch_q_out).abs().max().item()}")
        if not k_match:
            print(f"K max difference: {(triton_k_out - torch_k_out).abs().max().item()}")

# Run the test
test_apply_rope()

CompilationError: at 26:17:
    Q_out_ptr   = Q_out + base_offset
    K_out_ptr   = K_out + base_offset

    for k in range(0, tl.cdiv(num_heads, BLOCK_SIZE * 2)):
        dim_offsets = k * (BLOCK_SIZE * 2) + tl.arange(0, BLOCK_SIZE * 2)
        mask = dim_offsets < head_dim

        Q_loaded = tl.load(Q_start_ptr + dim_offsets, mask=mask)
        K_loaded = tl.load(K_start_ptr + dim_offsets, mask=mask)

        # Split into even and odd parts using our computed indices manually
        q_even = tl.zeros((BLOCK_SIZE,), dtype=Q.dtype)
                 ^