In [1]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.runtime import driver
from pprint import pprint
import torch.nn as nn

In [2]:
# All tensors and kernels in this notebook run on the currently selected CUDA device.
DEVICE = torch.device("cuda", torch.cuda.current_device())

In [3]:
# Query hardware limits of the active GPU through Triton's driver.
# NUM_SM / NUM_REGS / SIZE_SMEM / WARP_SIZE / target / kernels are presumably consumed by
# tuning code elsewhere in the notebook — TODO confirm, they are unused in the cells shown here.
properties = driver.active.utils.get_device_properties(DEVICE.index)
# Pretty-print the dict itself: `underscore_numbers` only formats ints, so passing an
# f-string (as before) silently disabled it.
pprint(properties, underscore_numbers=True)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = driver.active.get_current_target()
kernels = {}

("properties={'max_shared_mem': 101376, 'max_num_regs': 65536, "
 "'multiprocessor_count': 64, 'warpSize': 32, 'sm_clock_rate': 1695000, "
 "'mem_clock_rate': 8001000, 'mem_bus_width': 384}")


In [4]:
# Debug toggle: uncomment to set TRITON_INTERPRET=1, which asks Triton to interpret the
# @triton.jit kernels below instead of compiling them for the GPU.
# NOTE(review): this likely has to be set before the kernels are defined — confirm for the
# installed Triton version.
# import os
# os.environ["TRITON_INTERPRET"] = "1"

In [5]:
@triton.jit
def _swiglu_forward_kernel(input_ptr: torch.Tensor, up_ptr: torch.Tensor, gate_ptr:torch.Tensor,
                         output_ptr: torch.Tensor, input_batch_stride: int, input_seq_stride: int, 
                         output_batch_stride: int, output_seq_stride: int,
                         L: int, H: int, O:int, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_O: tl.constexpr):
    '''
    Triton kernel for the SwiGLU forward pass.

    With silu(z) = z * sigmoid(z), every program computes, for one (b, l) row,
        out[b, l, o] = (x[b, l, :] . up[o, :]) * silu(x[b, l, :] . gate[o, :])
    for a BLOCK_SIZE_O-wide slice of output features o.
    Shapes: x is (B, L, H); up and gate are (O, H); out is (B, L, O).

    Launch grid:
        axis 0: B * L programs, one per (batch, seq) row.
        axis 1: cdiv(O, BLOCK_SIZE_O) programs, one per output-feature slice. Each program
                reduces over the FULL H dimension in BLOCK_SIZE_H chunks, so H is not split
                across the grid.

    Layout assumptions (not checked here): the last dim of input and output is unit-stride,
    and up / gate are row-major contiguous (row stride H). BLOCK_SIZE_H and BLOCK_SIZE_O
    must be powers of two (tl.arange requirement).

    Args:
        input_ptr: Pointer to the input, shape: (B, L, H)
        up_ptr: Pointer to the weights of the "up" Linear layer, shape: (O, H)
        gate_ptr: Pointer to the weights of the "gate" Linear layer, shape: (O, H)
        output_ptr: Pointer to the output, shape: (B, L, O)
        input_batch_stride: number of elements to move to reach the next batch in the input
        input_seq_stride: number of elements to move to reach the next position in the input
        output_batch_stride: number of elements to move to reach the next batch in the output
        output_seq_stride: number of elements to move to reach the next position in the output
        L: Sequence length
        H: Embedding (input) dimension
        O: Output dimension
    '''
    # Axis 0: locate this program's (batch, seq) row. Widen to int64 so that
    # batch * stride cannot overflow 32-bit index math for large tensors.
    pid_bl = tl.program_id(0).to(tl.int64)
    batch_idx = pid_bl // L
    seq_idx = pid_bl % L
    input_row_ptr = input_ptr + batch_idx * input_batch_stride + seq_idx * input_seq_stride
    output_row_ptr = output_ptr + batch_idx * output_batch_stride + seq_idx * output_seq_stride

    # Axis 1: the output-feature slice owned by this program. If a launcher puts more than
    # cdiv(O, BLOCK_SIZE_O) programs on this axis, the extras alias an existing slice and store
    # identical values, so results stay correct (only work is wasted).
    num_pid_o = tl.cdiv(O, BLOCK_SIZE_O)
    pid_o = tl.program_id(1) % num_pid_o
    offs_o = pid_o * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    mask_o = offs_o < O  # last slice may be partial when O % BLOCK_SIZE_O != 0
    offs_h = tl.arange(0, BLOCK_SIZE_H)

    # Loop-invariant row pointers into both weight matrices, shape (BLOCK_SIZE_O, 1);
    # int64 so that row * H cannot overflow for large O * H.
    weight_row_offs = offs_o[:, None].to(tl.int64) * H
    gate_row_ptrs = gate_ptr + weight_row_offs
    up_row_ptrs = up_ptr + weight_row_offs

    # fp32 accumulators for the two projections of this row.
    gate_acc = tl.zeros((BLOCK_SIZE_O,), dtype=tl.float32)
    up_acc = tl.zeros((BLOCK_SIZE_O,), dtype=tl.float32)

    for h_start in tl.range(0, H, BLOCK_SIZE_H):
        cols = h_start + offs_h
        mask_h = cols < H
        # Mask both dims: rows past O and columns past H must not be read.
        w_mask = mask_o[:, None] & mask_h[None, :]
        a = tl.load(input_row_ptr + cols, mask=mask_h, other=0.0).to(tl.float32)
        gate_w = tl.load(gate_row_ptrs + cols[None, :], mask=w_mask, other=0.0).to(tl.float32)
        up_w = tl.load(up_row_ptrs + cols[None, :], mask=w_mask, other=0.0).to(tl.float32)
        gate_acc += tl.sum(gate_w * a[None, :], axis=1)
        up_acc += tl.sum(up_w * a[None, :], axis=1)

    silu = gate_acc * tl.sigmoid(gate_acc)
    output = up_acc * silu
    tl.store(output_row_ptr + offs_o, output, mask=mask_o)

In [23]:
def swiglu(x: torch.Tensor, up_weights: torch.Tensor, gate_weights: torch.Tensor, O: int,
           block_size_h: int = 256, block_size_o: int = 512, num_stages: int = 8) -> torch.Tensor:
    """Compute SwiGLU, ``(x @ up_weights.T) * silu(x @ gate_weights.T)``, with the Triton kernel.

    Args:
        x: input of shape (B, L, H).
        up_weights: weight of the "up" projection, shape (O, H) (nn.Linear layout).
        gate_weights: weight of the "gate" projection, shape (O, H).
        O: output dimension; must equal ``up_weights.shape[0]``.
        block_size_h: tile width along H per kernel loop iteration (power of two).
        block_size_o: number of output features handled by one program (power of two).
        num_stages: passed through to the Triton kernel launch.

    Returns:
        float32 tensor of shape (B, L, O) on ``x.device``.

    Raises:
        ValueError: if ``x`` is not 3-D or the weights are not both of shape (O, H).
    """
    if x.dim() != 3:
        raise ValueError(f"expected x with shape (B, L, H), got {tuple(x.shape)}")
    B, L, H = x.shape
    if tuple(up_weights.shape) != (O, H) or tuple(gate_weights.shape) != (O, H):
        raise ValueError(
            f"expected up/gate weights of shape {(O, H)}, got "
            f"{tuple(up_weights.shape)} and {tuple(gate_weights.shape)}"
        )

    # The kernel adds feature offsets directly to each row pointer and indexes the weights as
    # row * H + col, so it needs a unit-stride last input dim and row-major contiguous weights.
    if x.stride(-1) != 1:
        x = x.contiguous()
    up_weights = up_weights.contiguous()
    gate_weights = gate_weights.contiguous()

    output = torch.empty((B, L, O), dtype=torch.float32, device=x.device)
    if output.numel() == 0:
        return output  # nothing to compute; avoid launching an empty grid

    # Axis 0: one program per (batch, seq) row. Axis 1: one program per block_size_o slice of
    # output features. The kernel loops over all of H itself, so the grid is NOT split along H
    # (doing so only launches programs that redo, and race on, the same outputs).
    grid = lambda META: (B * L, triton.cdiv(O, META['BLOCK_SIZE_O']))
    _swiglu_forward_kernel[grid](
        x, up_weights, gate_weights, output,
        x.stride(0), x.stride(1), output.stride(0), output.stride(1),
        L, H, O,
        BLOCK_SIZE_H=block_size_h, BLOCK_SIZE_O=block_size_o, num_stages=num_stages,
    )
    return output

In [24]:
class Swiglu(nn.Module):
    """Reference SwiGLU layer in plain PyTorch: ``up(x) * silu(gate(x))``.

    Both projections are bias-free ``nn.Linear(input_dim, output_dim)`` layers, so their
    weights have shape (output_dim, input_dim).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Construction order (up, then gate) fixes which random draws each weight gets.
        self.up = nn.Linear(input_dim, output_dim, bias=False)
        self.gate = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            input_batch: tensor of shape (B, L, D) with D == input_dim.

        Returns:
            Tensor of shape (B, L, O) with O == output_dim.
        """
        gate_proj = self.gate(input_batch)
        swish = gate_proj * F.sigmoid(gate_proj)  # silu(z) = z * sigmoid(z)
        up_proj = self.up(input_batch)
        return up_proj * swish

In [25]:
# Correctness check: compare the Triton kernel against the PyTorch reference layer.
torch.manual_seed(0)  # reproducible inputs and weight init under Restart & Run All
B, L, D, O = 16, 64, 256, 1024
x = torch.randn((B, L, D), device=DEVICE, dtype=torch.float32)
torch_swiglu = Swiglu(D, O).to(x.device)
up_weights = torch_swiglu.up.weight
gate_weights = torch_swiglu.gate.weight
torch_output = torch_swiglu(x)
triton_output = swiglu(x, up_weights, gate_weights, O)
# Both paths are fp32 but reduce in different orders; the tolerance leaves room for that rounding.
torch.testing.assert_close(triton_output, torch_output.detach(), rtol=1e-4, atol=1e-4)



In [26]:
torch_output.size()

torch.Size([16, 64, 1024])

In [27]:
triton_output.size()

torch.Size([16, 64, 1024])

In [33]:
(torch_output - triton_output).abs().mean()

tensor(4.3280e-08, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
configs = []
ref_lib = "torch"
output_sizes = [2048]
L = 64
for o in output_sizes:
    configs.append(
        triton.testing.Benchmark(
            x_names=["H"],  # argument swept along the x-axis
            # NOTE(review): at H = 2**19 the fp32 weights alone take 2 * O * H * 4 bytes
            # (~8.6 GB for O = 2048); trim this range on GPUs with less memory.
            x_vals=[2**i for i in range(8, 20)],  # values of H to benchmark
            line_arg="provider",  # argument whose value selects the line in the plot
            line_vals=[ref_lib.lower(), "triton"],  # provider values, one line each
            line_names=[ref_lib, "Triton"],  # legend labels for those lines
            styles=[("green", "-"), ("blue", "-")],
            ylabel="GB/s",  # benchmark() returns memory throughput
            plot_name=f"swiglu-performance-fp32-O{o}",  # also the file name when saving; unique per O
            args={"L": L, "O": o},
        ))


@triton.testing.perf_report(configs)
def benchmark(H, provider, L, O):
    """Time one SwiGLU forward for `provider` and report (median, low, high) throughput in GB/s."""
    B = 16
    a = torch.randn((B, L, H), device=DEVICE, dtype=torch.float32)
    torch_swiglu = Swiglu(H, O).to(a.device)
    up_weights, gate_weights = torch_swiglu.up.weight, torch_swiglu.gate.weight
    quantiles = [0.5, 0.2, 0.8]
    # no_grad: the Triton path builds no autograd graph, so don't charge the torch path for one.
    with torch.no_grad():
        if provider == ref_lib.lower():
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_swiglu(a), quantiles=quantiles)
        elif provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: swiglu(a, up_weights, gate_weights, O), quantiles=quantiles)
        else:
            raise ValueError(f"unknown provider: {provider!r}")
    # Minimum fp32 traffic (4 bytes/element): read input (B*L*H) and both weights (2*O*H),
    # write output (B*L*O).
    bytes_accessed = 4 * (B * L * H + 2 * O * H + B * L * O)
    gb_per_s = lambda ms: bytes_accessed * 1e-9 / (ms * 1e-3)

    # Longer time -> lower throughput, hence max_ms maps to the lower bound.
    return gb_per_s(ms), gb_per_s(max_ms), gb_per_s(min_ms)


# run() prints the table itself and returns None unless return_df=True.
benchmark.run(show_plots=True, print_data=True)