In [1]:
import os
import time
import math
import torch
import triton
import triton.language as tl
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, LlamaForCausalLM

@triton.jit
def chunk_attention_kernel(
    Q, K, V, Output,
    stride_qb, stride_qm, stride_qh,
    stride_kb, stride_km, stride_kh,
    stride_vb, stride_vm, stride_vh,
    stride_ob, stride_om, stride_oh,
    seq_len, chunk_size, head_dim,
    BLOCK_SIZE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_head = tl.program_id(1)
    pid_chunk = tl.program_id(2)
    
    chunk_start = pid_chunk * BLOCK_SIZE
    offs_m = tl.arange(0, BLOCK_SIZE)
    offs_n = tl.arange(0, HEAD_DIM)
    mask = offs_m < (seq_len - chunk_start)
    
    q_base = Q + pid_batch * stride_qb + pid_head * stride_qh
    k_base = K + pid_batch * stride_kb + pid_head * stride_kh
    v_base = V + pid_batch * stride_vb + pid_head * stride_vh
    
    # Compute pointers for the current chunk
    q_ptrs = q_base + chunk_start * stride_qm + offs_m[:, None] * stride_qm + offs_n[None, :] * 1
    k_ptrs = k_base + chunk_start * stride_km + offs_m[:, None] * stride_km + offs_n[None, :] * 1
    v_ptrs = v_base + chunk_start * stride_vm + offs_m[:, None] * stride_vm + offs_n[None, :] * 1
    
    q = tl.load(q_ptrs, mask=mask[:, None], other=0.0)
    k = tl.load(k_ptrs, mask=mask[:, None], other=0.0)
    v = tl.load(v_ptrs, mask=mask[:, None], other=0.0)
    
    # Intra-chunk attention
    scores = tl.dot(q, tl.trans(k)) * (1.0 / tl.sqrt(tl.cast(HEAD_DIM, tl.float32)))
    scores = tl.where(mask[:, None] & mask[None, :], scores, float("-inf"))
    
    max_scores = tl.max(scores, axis=1)
    scores = scores - max_scores[:, None]
    exp_scores = tl.exp(scores)
    sum_exp_scores = tl.sum(exp_scores, axis=1)
    probs = exp_scores / sum_exp_scores[:, None]
    
    intra_output = tl.dot(probs, v)
    
    # Inter-chunk attention (mean of Q*K)
    # Compute mean(Q) and mean(K) across the chunk
    mean_q = tl.sum(q, axis=0) / BLOCK_SIZE
    mean_k = tl.sum(k, axis=0) / BLOCK_SIZE
    mean_scores = tl.dot(mean_q, tl.trans(mean_k)) * (1.0 / tl.sqrt(tl.cast(HEAD_DIM, tl.float32)))
    
    # Apply softmax to mean_scores
    mean_scores = tl.where(tl.ones_like(mean_scores) > 0, mean_scores, float("-inf"))  # Masking if needed
    max_mean_scores = tl.max(mean_scores)
    mean_scores = mean_scores - max_mean_scores
    exp_mean_scores = tl.exp(mean_scores)
    sum_exp_mean_scores = tl.sum(exp_mean_scores)
    mean_probs = exp_mean_scores / sum_exp_mean_scores
    
    # Compute mean(V) across the chunk
    mean_v = tl.sum(v, axis=0) / BLOCK_SIZE
    inter_output = mean_probs * mean_v
    
    # Combine intra and inter chunk outputs
    combined_output = intra_output + inter_output
    
    # Store the combined output
    o_base = Output + pid_batch * stride_ob + pid_head * stride_oh
    out_ptrs = o_base + chunk_start * stride_om + offs_m[:, None] * stride_om + offs_n[None, :] * 1
    tl.store(out_ptrs, combined_output, mask=mask[:, None])

class ChunkedAttention(torch.nn.Module):
    def __init__(self, embed_dim, num_heads, chunk_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.chunk_size = chunk_size
        self.head_dim = embed_dim // num_heads

        # Ensure divisibility
        assert embed_dim % num_heads == 0, (
            f"Embedding dimension ({embed_dim}) must be divisible by the number of heads ({num_heads})."
        )

        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        batch_size, seq_len, embed_dim = hidden_states.shape

        # Validate dimensions
        assert embed_dim == self.embed_dim, (
            f"Input embedding dimension ({embed_dim}) does not match expected dimension ({self.embed_dim})."
        )

        q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        output = torch.zeros_like(q)
        num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size

        grid = (batch_size, self.num_heads, num_chunks)
        chunk_attention_kernel[grid](
            q, k, v, output,
            q.stride(0), q.stride(2), q.stride(1),
            k.stride(0), k.stride(2), k.stride(1),
            v.stride(0), v.stride(2), v.stride(1),
            output.stride(0), output.stride(2), output.stride(1),
            seq_len, self.chunk_size, self.head_dim,
            BLOCK_SIZE=min(128, self.chunk_size),
            HEAD_DIM=self.head_dim
        )
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        return self.out_proj(output)


def replace_llamaattention_with_chunked_attention(model, chunk_size: int):
    """
    Replaces all instances of LlamaAttention in the model with ChunkedLlamaAttention.
    """
    from transformers.models.llama.modeling_llama import LlamaAttention

    class ChunkedLlamaAttention(LlamaAttention):
        def __init__(self, config):
            super().__init__(config)
            self.chunked_attention = ChunkedAttention(
                embed_dim=config.hidden_size,
                num_heads=config.num_attention_heads,
                chunk_size=chunk_size
            )
        
        def forward(
            self,
            hidden_states,
            attention_mask=None,
            position_ids=None,
            past_key_value=None,
            use_cache=False,
            output_attentions=False,
            **kwargs
        ):
            return self.chunked_attention(hidden_states, attention_mask, position_ids, **kwargs)

    # Recursively replace LlamaAttention with ChunkedLlamaAttention
    for name, module in model.named_modules():
        if isinstance(module, LlamaAttention):
            parent = model
            components = name.split('.')
            for comp in components[:-1]:
                parent = getattr(parent, comp)
            setattr(parent, components[-1], ChunkedLlamaAttention(model.config))
    
    return model

def benchmark_attention(model_name, chunk_size, sequence_length, batch_size):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval().cuda()

    input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, sequence_length)).cuda()

    # Standard Attention Benchmark
    torch.cuda.empty_cache()
    start_time = time.time()
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    standard_time = time.time() - start_time
    standard_loss = torch.nn.functional.cross_entropy(
        outputs.logits[:, :-1].reshape(-1, outputs.logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )

    # Replace Attention Mechanism
    model = replace_llamaattention_with_chunked_attention(model, chunk_size)
    model.cuda()

    # Chunked Attention Benchmark
    torch.cuda.empty_cache()
    start_time = time.time()
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    chunked_time = time.time() - start_time
    chunked_loss = torch.nn.functional.cross_entropy(
        outputs.logits[:, :-1].reshape(-1, outputs.logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )

    print(f"Standard: Time={standard_time:.3f}s, Perplexity={torch.exp(standard_loss):.3f}")
    print(f"Chunked: Time={chunked_time:.3f}s, Perplexity={torch.exp(chunked_loss):.3f}")

# Example usage
if __name__ == "__main__":
    # Determine chunk size as sqrt(N)
    sequence_length = 512
    chunk_size = int(math.sqrt(sequence_length))  # e.g., 22 for N=512
    benchmark_attention(
        model_name="meta-llama/Llama-3.2-1B",
        chunk_size=chunk_size,
        sequence_length=sequence_length,
        batch_size=1
    )


Instantiating ChunkedLlamaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.


AttributeError: module 'triton.language' has no attribute 'ones_like'