In [2]:
import torch
import math
from einops import einsum
import torch.nn.functional as F

In [14]:
class SDPA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask):
        d_k = K.shape[-1]
        attention_scores = einsum(
            Q, K, "... query d_k, ... key d_k -> ... query key"
        ) / math.sqrt(d_k)
        if mask is not None:
            attention_scores = torch.where(mask, attention_scores, float("-inf"))
        attention_weights = F.softmax(attention_scores, dim=-1)
        return einsum(
            attention_weights, V, "... query key, ... key d_v ->  ... query d_v"
        )


def generate_random_inputs(batch_size, seq_len, d_model, device):
    Q = torch.randn(batch_size, seq_len, d_model, device=device)
    K = torch.randn(batch_size, seq_len, d_model, device=device)
    V = torch.randn(batch_size, seq_len, d_model, device=device)
    mask = torch.ones(batch_size, seq_len, seq_len, device=device).bool()
    return Q, K, V, mask


B = 8
d_models = [16, 32, 64, 128]
seq_lens = [256, 1024, 4096, 8192, 16384]

In [15]:
import pandas as pd
import time

results = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = torch.compile(SDPA().to(device))

for d_model in d_models:
    for seq_len in seq_lens:
        try:
            # (c) Create random inputs
            Q, K, V, mask = generate_random_inputs(B, seq_len, d_model, device)
            Q.requires_grad = True
            K.requires_grad = True
            V.requires_grad = True

            # (f) Warm up
            for _ in range(5):
                out = model(Q, K, V, mask)
                loss = out.sum()
                loss.backward()
                Q.grad = None
                K.grad = None
                V.grad = None
            torch.cuda.synchronize()

            # (d) Time 100 forward passes
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            for _ in range(100):
                out = model(Q, K, V, mask)
                torch.cuda.synchronize()
            end_event.record()
            torch.cuda.synchronize()
            fwd_time = start_event.elapsed_time(end_event) / 100

            # (e) Measure memory before backward pass
            # Run one forward pass to populate the graph and activations
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            out = model(Q, K, V, mask)
            # Memory in use (activations + parameters + inputs)
            mem_allocated = torch.cuda.memory_allocated() / (1024**2)  # MB

            # (e) Time 100 backward passes
            grad_output = torch.randn_like(out)
            total_bwd_time = 0

            for _ in range(100):
                # Re-run forward to ensure valid graph for each backward pass
                # We don't time this forward pass
                out = model(Q, K, V, mask)
                torch.cuda.synchronize()

                start_event.record()
                out.backward(grad_output)
                torch.cuda.synchronize()
                end_event.record()
                torch.cuda.synchronize()

                total_bwd_time += start_event.elapsed_time(end_event)

                # Zero gradients
                Q.grad = None
                K.grad = None
                V.grad = None

            bwd_time = total_bwd_time / 100

            results.append(
                {
                    "d_model": d_model,
                    "seq_len": seq_len,
                    "fwd_time_ms": fwd_time,
                    "bwd_time_ms": bwd_time,
                    "mem_mb": mem_allocated,
                }
            )
            print(
                f"d_model={d_model}, seq_len={seq_len}: Fwd={fwd_time:.2f}ms, Bwd={bwd_time:.2f}ms, Mem={mem_allocated:.2f}MB"
            )

        except RuntimeError as e:
            print(f"Failed for d_model={d_model}, seq_len={seq_len}: {e}")
            torch.cuda.empty_cache()

df = pd.DataFrame(results)
print(df)

Using device: cuda


W1208 17:24:08.268000 686475 torch/fx/experimental/symbolic_shapes.py:6833] [0/1] _maybe_guard_rel() was called on non-relation expression Eq(s92, 1) | Eq(s92, s24)


d_model=16, seq_len=256: Fwd=0.16ms, Bwd=0.37ms, Mem=467.25MB
d_model=16, seq_len=1024: Fwd=0.22ms, Bwd=0.40ms, Mem=442.38MB
d_model=16, seq_len=4096: Fwd=3.25ms, Bwd=4.67ms, Mem=1048.75MB


W1208 17:24:16.033000 686475 torch/fx/experimental/symbolic_shapes.py:6833] [0/2] _maybe_guard_rel() was called on non-relation expression Eq(s92, 1) | Eq(s92, s24)


d_model=16, seq_len=8192: Fwd=13.17ms, Bwd=18.36ms, Mem=2978.25MB


W1208 17:24:49.935000 686475 torch/fx/experimental/symbolic_shapes.py:6833] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s46, 1) | Eq(s46, s60)
W1208 17:24:50.038000 686475 torch/fx/experimental/symbolic_shapes.py:6833] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s92, 1) | Eq(s92, s24)


d_model=16, seq_len=16384: Fwd=52.72ms, Bwd=72.40ms, Mem=10676.25MB
d_model=32, seq_len=256: Fwd=0.18ms, Bwd=0.40ms, Mem=411.75MB
d_model=32, seq_len=1024: Fwd=0.22ms, Bwd=0.41ms, Mem=444.50MB
d_model=32, seq_len=4096: Fwd=3.45ms, Bwd=4.99ms, Mem=1057.25MB


W1208 17:24:58.492000 686475 torch/fx/experimental/symbolic_shapes.py:6833] [0/4] _maybe_guard_rel() was called on non-relation expression Eq(s46, 1) | Eq(s46, s60)
W1208 17:24:58.580000 686475 torch/fx/experimental/symbolic_shapes.py:6833] [0/4] _maybe_guard_rel() was called on non-relation expression Eq(s92, 1) | Eq(s92, s24)


d_model=32, seq_len=8192: Fwd=15.64ms, Bwd=22.25ms, Mem=2996.25MB
d_model=32, seq_len=16384: Fwd=54.65ms, Bwd=77.90ms, Mem=10712.25MB
d_model=64, seq_len=256: Fwd=0.18ms, Bwd=0.38ms, Mem=420.75MB
d_model=64, seq_len=1024: Fwd=0.29ms, Bwd=0.44ms, Mem=448.75MB
d_model=64, seq_len=4096: Fwd=2.90ms, Bwd=5.35ms, Mem=1074.25MB
d_model=64, seq_len=8192: Fwd=13.66ms, Bwd=21.97ms, Mem=3032.25MB
d_model=64, seq_len=16384: Fwd=60.03ms, Bwd=89.50ms, Mem=10784.25MB
d_model=128, seq_len=256: Fwd=0.18ms, Bwd=0.38ms, Mem=438.75MB
d_model=128, seq_len=1024: Fwd=0.38ms, Bwd=0.57ms, Mem=457.25MB
d_model=128, seq_len=4096: Fwd=4.60ms, Bwd=8.40ms, Mem=1108.25MB
d_model=128, seq_len=8192: Fwd=20.09ms, Bwd=33.62ms, Mem=3104.25MB
d_model=128, seq_len=16384: Fwd=87.35ms, Bwd=136.72ms, Mem=10928.25MB
    d_model  seq_len  fwd_time_ms  bwd_time_ms        mem_mb
0        16      256     0.162686     0.365602    467.250488
1        16     1024     0.216485     0.400899    442.375488
2        16     4096     3.2506

In [16]:
df

Unnamed: 0,d_model,seq_len,fwd_time_ms,bwd_time_ms,mem_mb
0,16,256,0.162686,0.365602,467.250488
1,16,1024,0.216485,0.400899,442.375488
2,16,4096,3.250614,4.668465,1048.750488
3,16,8192,13.170645,18.360616,2978.250488
4,16,16384,52.716768,72.403973,10676.250488
5,32,256,0.180307,0.398755,411.750488
6,32,1024,0.222639,0.410617,444.500488
7,32,4096,3.451696,4.988198,1057.250488
8,32,8192,15.641556,22.247633,2996.250488
9,32,16384,54.651079,77.900914,10712.250488


In [18]:
df.to_csv("compiled_pytorch_results.csv", index=False)

# Weighted Sum Example

In [17]:
def weighted_sum(x, weight):
    # Here, assume that x has n-dim shape [..., D], and weight has 1D shape [D]
    return (weight * x).sum(dim=-1)

In [30]:
from tkinter import W
import triton
import triton.language as tl
from einops import rearrange
import torch


@triton.jit
def weighted_sum_fwd(
    x_ptr,
    weight_ptr,
    output_ptr,
    x_stride_row,
    x_stride_dim,
    weight_stride_dim,
    output_stride_row,
    ROWS,
    D,
    ROWS_TILE_SIZE: tl.constexpr,
    D_TILE_SIZE: tl.constexpr,
):
    row_tile_idx = tl.program_id(0)
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(
            ROWS,
            D,
        ),
        strides=(
            x_stride_row,
            x_stride_dim,
        ),
        offsets=(
            row_tile_idx * ROWS_TILE_SIZE,
            0,
        ),
        block_shape=(
            ROWS_TILE_SIZE,
            D_TILE_SIZE,
        ),
        order=(
            1,
            0,
        ),
    )

    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(weight_stride_dim,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )

    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(ROWS,),
        strides=(output_stride_row,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )

    output = tl.zeros((ROWS_TILE_SIZE,), dtype=tl.float32)
    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")
        output += tl.sum(row * weight[None, :], axis=1)
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE))
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,))
    tl.store(output_block_ptr, output, boundary_check=(0,))


@triton.jit
def weighted_sum_backward(
    x_ptr,
    weight_ptr,  # Input
    grad_output_ptr,  # Grad input
    grad_x_ptr,
    partial_grad_weight_ptr,  # Grad outputs
    stride_xr,
    stride_xd,
    stride_wr,
    stride_gr,
    stride_gxr,
    stride_gxd,
    stride_gwb,
    stride_gwd,
    NUM_ROWS,
    D,
    ROWS_TILE_SIZE: tl.constexpr,
    D_TILE_SIZE: tl.constexpr,
):
    row_tile_idx = tl.program_id(0)
    n_row_tiles = tl.num_programs(0)

    # Inputs
    grad_output_block_ptr = tl.make_block_ptr(
        grad_output_ptr,
        shape=(NUM_ROWS,),
        strides=(stride_gr,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )

    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(
            NUM_ROWS,
            D,
        ),
        strides=(
            stride_xr,
            stride_xd,
        ),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )

    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(stride_wr,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )

    grad_x_block_ptr = tl.make_block_ptr(
        grad_x_ptr,
        shape=(
            NUM_ROWS,
            D,
        ),
        strides=(
            stride_gxr,
            stride_gxd,
        ),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )

    partial_grad_weight_block_ptr = tl.make_block_ptr(
        partial_grad_weight_ptr,
        shape=(
            n_row_tiles,
            D,
        ),
        strides=(
            stride_gwb,
            stride_gwd,
        ),
        offsets=(row_tile_idx, 0),
        block_shape=(1, D_TILE_SIZE),
        order=(1, 0),
    )

    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        grad_output = tl.load(
            grad_output_block_ptr, boundary_check=(0,), padding_option="zero"
        )  # (ROWS_TILE_SIZE,)

        # Outer product for grad_x
        weight = tl.load(
            weight_block_ptr, boundary_check=(0,), padding_option="zero"
        )  # (D_TILE_SIZE,)
        grad_x_row = grad_output[:, None] * weight[None, :]
        tl.store(grad_x_block_ptr, grad_x_row, boundary_check=(0, 1))

        # Reduce as many rows as possible for the grad_weight result
        row = tl.load(
            x_block_ptr, boundary_check=(0, 1), padding_option="zero"
        )  # (ROWS_TILE_SIZE, D_TILE_SIZE)
        grad_weight_row = tl.sum(row * grad_output[:, None], axis=0, keepdim=True)
        tl.store(partial_grad_weight_block_ptr, grad_weight_row, boundary_check=(1,))

        # Move the pointers to the next tile along D
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE))
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,))
        partial_grad_weight_block_ptr = partial_grad_weight_block_ptr.advance(
            (0, D_TILE_SIZE)
        )
        grad_x_block_ptr = grad_x_block_ptr.advance((0, D_TILE_SIZE))


class WeightedSumFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        D, output_dims = x.shape[-1], x.shape[:-1]
        input_shape = x.shape
        ROWS = 1
        x = rearrange(x, "... d -> (...) d")
        ctx.save_for_backward(x, weight)
        assert len(weight.shape) == 1 and weight.shape[0] == D, "Dimension mismatch"
        assert x.is_cuda and weight.is_cuda, "Expected CUDA tensors"
        assert x.is_contiguous(), "Our pointer arithmetic will assume contiguous x"
        ctx.D_TILE_SIZE = triton.next_power_of_2(D) // 16
        ctx.ROWS_TILE_SIZE = 16
        ctx.input_shape = input_shape
        y = torch.empty(output_dims, device=x.device)
        n_rows = y.numel()
        weighted_sum_fwd[(triton.cdiv(n_rows, ctx.ROWS_TILE_SIZE),)](
            x,
            weight,
            y,
            x.stride(0),
            x.stride(1),
            weight.stride(0),
            y.stride(0),
            ROWS=n_rows,
            D=D,
            ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE,
            D_TILE_SIZE=ctx.D_TILE_SIZE,
        )
        return y.view(input_shape[:-1])

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        ROWS_TILE_SIZE, D_TILE_SIZE = (
            ctx.ROWS_TILE_SIZE,
            ctx.D_TILE_SIZE,
        )  # These don't have to be the same
        n_rows, D = x.shape

        # Our strategy is for each thread block to first write to a partial buffer,
        # then we reduce over this buffer to get the final gradient.
        partial_grad_weight = torch.empty(
            (triton.cdiv(n_rows, ROWS_TILE_SIZE), D), device=x.device, dtype=x.dtype
        )
        grad_x = torch.empty_like(x)

        weighted_sum_backward[(triton.cdiv(n_rows, ROWS_TILE_SIZE),)](
            x,
            weight,
            grad_out,
            grad_x,
            partial_grad_weight,
            x.stride(0),
            x.stride(1),
            weight.stride(0),
            grad_out.stride(0),
            grad_x.stride(0),
            grad_x.stride(1),
            partial_grad_weight.stride(0),
            partial_grad_weight.stride(1),
            NUM_ROWS=n_rows,
            D=D,
            ROWS_TILE_SIZE=ROWS_TILE_SIZE,
            D_TILE_SIZE=D_TILE_SIZE,
        )

        grad_weight = partial_grad_weight.sum(axis=0)
        return grad_x, grad_weight


f_weightedsum = WeightedSumFunc.apply

In [29]:
x = torch.rand(128, 16).to("cuda")
w = torch.rand(16).to("cuda")
a = f_weightedsum(x, w)
b = weighted_sum(x, w)

print("forward a shape:", a.shape)
print("forward b shape:", b.shape)
print("forward max abs diff:", (a - b).abs().max().item())

forward a shape: torch.Size([128])
forward b shape: torch.Size([128])
forward max abs diff: 4.76837158203125e-07


In [31]:
x1 = x.clone().detach().requires_grad_(True)
w1 = w.clone().detach().requires_grad_(True)
a1 = f_weightedsum(x1, w1)
loss_a = a1.sum()
loss_a.backward()
grad_x_a = x1.grad.clone()
grad_w_a = w1.grad.clone()

# 另一套用于 PyTorch baseline
x2 = x.clone().detach().requires_grad_(True)
w2 = w.clone().detach().requires_grad_(True)
b1 = weighted_sum(x2, w2)
loss_b = b1.sum()
loss_b.backward()
grad_x_b = x2.grad.clone()
grad_w_b = w2.grad.clone()

print("grad_x max abs diff:", (grad_x_a - grad_x_b).abs().max().item())
print("grad_w max abs diff:", (grad_w_a - grad_w_b).abs().max().item())

CompilationError: at 103:26:
        # Outer product for grad_x
        weight = tl.load(
            weight_block_ptr, boundary_check=(0,), padding_option="zero"
        )  # (D_TILE_SIZE,)
        grad_x_row = grad_output[:, None] * weight[None, :]
        tl.store(grad_x_block_ptr, grad_x_row, boundary_check=(0, 1))

        # Reduce as many rows as possible for the grad_weight result
        row = tl.load(
            x_block_ptr, boundary_check=(0, 1), padding_option="zero"
        )  # (ROWS_TILE_SIZE, D_TILE_SIZE)
        grad_weight_row = tl.sum(row * grad_output[:, None], axis=0, keepdim=True)
                          ^
TypeError("sum() got an unexpected keyword argument 'keepdim'")

In [1]:
import torch
from torch import Tensor
from jaxtyping import Float, Bool, jaxtyped
from beartype import beartype
import math
import torch.nn.functional as F
import triton
import triton.language as tl


class TorchFlashAttention2(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        Q: Float[Tensor, "batch q_len d_model"],
        K: Float[Tensor, "batch k_len d_model"],
        V: Float[Tensor, "batch v_len d_model"],
        mask: Float[Tensor, "batch q_len k_len"],
    ):
        B_q = 32
        B_kv = 16
        bs, q_len, d_model = Q.shape
        T_q = q_len // B_q
        T_k = K.shape[1] // B_kv
        O = torch.zeros((bs, q_len, d_model))
        L = torch.zeros((bs, T_q, B_q))

        for b in range(bs):
            for i in range(T_q):
                q_tile = Q[b, i * B_q : (i + 1) * B_q, :]
                o_tile = torch.zeros((B_q, d_model))
                l = torch.zeros((B_q,))
                m = torch.full((B_q,), -torch.inf)
                for j in range(T_k):
                    k_tile = K[b, j * B_kv : (j + 1) * B_kv, :]
                    v_tile = V[b, j * B_kv : (j + 1) * B_kv, :]
                    s_tile = q_tile @ k_tile.T / math.sqrt(d_model)
                    row_max = torch.max(s_tile, dim=-1).values
                    old_m = m
                    m = torch.maximum(row_max, m)
                    scale_factor = torch.exp(old_m - m)
                    p_tile = torch.exp(s_tile - m.unsqueeze(-1))  #  softmax 分子
                    l = scale_factor * l + torch.sum(p_tile, dim=-1)  # softmax 分母
                    o_tile = scale_factor.unsqueeze(-1) * o_tile + p_tile @ v_tile
                L[b, i, :] = m + torch.log(l)
                O[b, i * B_q : (i + 1) * B_q, :] = o_tile / l.unsqueeze(-1)
        ctx.save_for_backward(Q, K, V, mask, O, L)
        return O

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError()

In [8]:
import torch
from torch import Tensor
from jaxtyping import Float, Bool, jaxtyped
from beartype import beartype
import math
import torch.nn.functional as F
import triton
import triton.language as tl

# fmt: off
@triton.jit
def flash_fwd_kernel(
    Q_ptr, K_ptr, V_ptr, # Inputs
    O_ptr, L_ptr, # Outputs
    stride_qb, stride_qq, stride_qd,
    stride_kb, stride_kk, stride_kd,
    stride_vb, stride_vk, stride_vd,
    stride_ob, stride_oq, stride_od,
    stride_lb, stride_lq,
    N_QUERIES, N_KEYS,
    scale,
    D: tl.constexpr,
    Q_TILE_SIZE: tl.constexpr,
    K_TILE_SIZE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
): 
# fmt: on
    query_tile_index = tl.program_id(0) # parallize over queries because seq_len dim is embarassingly parallelized
    batch_index = tl.program_id(1)
    
    Q_block_ptr = tl.make_block_ptr(
        Q_ptr + batch_index * stride_qb,
        shape=(N_QUERIES, D),
        strides=(stride_qq, stride_qd),
        offsets=(query_tile_index * Q_TILE_SIZE, 0),
        block_shape=(Q_TILE_SIZE, D),
        order=(1, 0),
    )

    O_block_ptr = tl.make_block_ptr(
        O_ptr + batch_index * stride_ob,
        shape=(N_QUERIES, D),
        strides=(stride_oq, stride_od),
        offsets=(query_tile_index * Q_TILE_SIZE, 0),
        block_shape=(Q_TILE_SIZE, D),
        order=(1, 0),
    )
    
    L_block_ptr = tl.make_block_ptr(
        L_ptr + batch_index * stride_lb,
        shape=(N_QUERIES,),
        strides=(stride_lq,),
        offsets=(query_tile_index * Q_TILE_SIZE,),
        block_shape=(Q_TILE_SIZE,),
        order=(0,),
    )
    
    K_block_ptr = tl.make_block_ptr(
        K_ptr + batch_index * stride_kb,
        shape=(N_KEYS, D),
        strides=(stride_kk, stride_kd),
        offsets=(0, 0),
        block_shape=(K_TILE_SIZE, D),
        order=(1, 0),
    )

    V_block_ptr = tl.make_block_ptr(
        V_ptr + batch_index * stride_vb,
        shape=(N_KEYS, D),
        strides=(stride_vk, stride_vd),
        offsets=(0, 0),
        block_shape=(K_TILE_SIZE, D),
        order=(1, 0),
    )

    q = tl.load(Q_block_ptr)
    o = tl.zeros((Q_TILE_SIZE, D), dtype=tl.float32)
    l = tl.zeros((Q_TILE_SIZE,), dtype=tl.float32)
    m = tl.full((Q_TILE_SIZE,), -float("inf"), dtype=tl.float32)
    
    hi = N_KEYS
    if IS_CAUSAL:
        hi = min((query_tile_index + 1) * Q_TILE_SIZE, N_KEYS)
    
    offs_m = query_tile_index * Q_TILE_SIZE + tl.arange(0, Q_TILE_SIZE)   
    for start_k in range(0, hi, K_TILE_SIZE):
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        s = tl.dot(q, k.T) * scale
        offs_n = start_k + tl.arange(0, K_TILE_SIZE)
        if IS_CAUSAL and (start_k + K_TILE_SIZE > query_tile_index * Q_TILE_SIZE):
                mask = offs_n[None, :] > offs_m[:, None]
                s = tl.where(mask, -1.0e6, s)
        row_max = tl.max(s, axis=-1)
        old_m = m
        m = tl.maximum(row_max, m)
        scale_factor = tl.exp(old_m - m)
        p = tl.exp(s - m[:, None])
        l = scale_factor * l + tl.sum(p, axis=-1)
        o = o * scale_factor[:, None]
        o = tl.dot(p.to(v.dtype), v, acc=o)
        K_block_ptr = tl.advance(K_block_ptr, (K_TILE_SIZE, 0))
        V_block_ptr = tl.advance(V_block_ptr, (K_TILE_SIZE, 0))
    o = o / l[:, None]
    l = m + tl.log(l)
    tl.store(O_block_ptr, o.to(O_block_ptr.type.element_ty))
    tl.store(L_block_ptr, l)

    

    
class TritonFlashAttention2(torch.autograd.Function):
    @staticmethod
    def forward(ctx,
        Q: Float[Tensor, "batch q_len d_model"],
        K: Float[Tensor, "batch k_len d_model"],
        V: Float[Tensor, "batch v_len d_model"],
        is_causal: bool
    ):
        Q_TILE_SIZE = 32
        K_TILE_SIZE = 16
        bs, N_QUERIES, d_model = Q.shape
        N_KEYS = K.shape[1]
        T_q = triton.cdiv(N_QUERIES, Q_TILE_SIZE)
        O = torch.zeros((bs, N_QUERIES, d_model), device = Q.device)
        L = torch.zeros((bs, N_QUERIES), device = Q.device)
        grid = (T_q, bs)
        flash_fwd_kernel[grid](
            Q, K, V, 
            O, L,Q.stride(0), Q.stride(1), Q.stride(2),
            K.stride(0), K.stride(1), K.stride(2),
            V.stride(0), V.stride(1), V.stride(2),
            O.stride(0), O.stride(1), O.stride(2),
            L.stride(0), L.stride(1),
            N_QUERIES, N_KEYS,
            scale=1.0 / (d_model ** 0.5),
            D=d_model, # type: ignore
            Q_TILE_SIZE=Q_TILE_SIZE, # type: ignore
            K_TILE_SIZE=K_TILE_SIZE, # type: ignore
            IS_CAUSAL=is_causal, # type: ignore
        )
        ctx.save_for_backward(Q, K, V, O, L)
        ctx.is_causal = is_causal
        return O

    @staticmethod
    def backward(ctx, grad_out):
        raise NotImplementedError()
    
from cs336_systems.flash_attn import TorchFlashAttention2

f_torch_flash_attn = TorchFlashAttention2.apply
f_triton_flash_attn = TritonFlashAttention2.apply
# Test
bs = 2
seq_len = 32
d_model = 128
size = (bs, seq_len, d_model)


with torch.device("cuda:0"):
    q = torch.rand(size)
    k = torch.rand(size)
    v = torch.rand(size)
    mask = torch.rand((seq_len, seq_len))
    res = f_triton_flash_attn(q, k, v, True)
    ref = f_torch_flash_attn(q, k, v, True)
    # ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    torch.testing.assert_close(res, ref, rtol=1e-2, atol=1e-2)