In [1]:
import math
from typing import Optional, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import time
from tqdm import tqdm

# Module-wide switch for activation checkpointing. It is read when SAB, ISAB,
# PMA and TransformerDecoder construct their blocks, so it must be set before
# a model is built; changing it afterwards does not affect existing instances.
USE_CHECKPOINTING = True

class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    Args:
        dim: Size of the last (feature) axis; the gain vector has this length.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS taken over its last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` in float32, cast back to ``x``'s dtype, apply the gain."""
        # The statistics are computed in float32 and the normalised tensor is
        # cast back to the input dtype *before* the gain is applied, so the
        # returned dtype follows normal promotion between x and self.weight.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight


class SwiGLU(nn.Module):
    """
    Gated feed-forward network using the SwiGLU activation.

    The gate and up projections are fused into a single linear layer (``w13``)
    whose output is split in two; the result ``silu(gate) * up`` is projected
    back to ``dim`` by ``w2`` and passed through dropout.
    Reference: "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202)

    Args:
        dim: Input and output feature size.
        hidden_dim: Nominal hidden size before the 2/3 rescaling; a falsy
            value (``None`` or ``0``) selects ``4 * dim``.
        dropout: Dropout probability applied to the output.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        nominal = hidden_dim or 4 * dim
        # Shrink the nominal width by 2/3, then round up to a multiple of 8.
        inner_dim = int(2 * nominal / 3)
        inner_dim = ((inner_dim + 7) // 8) * 8

        self.w13 = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated FFN along the last axis of ``x``."""
        gate, up = torch.chunk(self.w13(x), 2, dim=-1)
        projected = self.w2(F.silu(gate) * up)
        return self.dropout(projected)


class MultiheadAttentionBlock(nn.Module):
    """
    Multi-head scaled-dot-product attention with an output projection.

    In self-attention mode a single fused linear layer produces Q, K and V
    from ``query``; otherwise Q comes from ``query`` and a fused K/V projection
    is applied to ``key_value``.

    Args:
        dim_q: Feature size of ``query``.
        dim_kv: Feature size of ``key_value``.
        dim_out: Total attention width (split evenly across heads) and the
            output feature size.
        n_heads: Number of attention heads; must divide ``dim_out``.
        dropout: Attention dropout probability, used only in training mode.
        bias: Whether the linear projections carry a bias term.
        is_self_attention: Select the fused-QKV path.

    Raises:
        ValueError: If ``dim_out`` is not divisible by ``n_heads``, or if
            self-attention is requested with ``dim_q != dim_kv``.
    """

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        dim_out: int,
        n_heads: int,
        dropout: float = 0.0,
        bias: bool = False,
        is_self_attention: bool = False
    ):
        super().__init__()
        if dim_out % n_heads != 0:
            raise ValueError(f"dim_out ({dim_out}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.head_dim = dim_out // n_heads
        self.dropout = dropout
        self.is_self_attention = is_self_attention

        if not self.is_self_attention:
            self.w_q = nn.Linear(dim_q, dim_out, bias=bias)
            self.w_kv = nn.Linear(dim_kv, 2 * dim_out, bias=bias)
        else:
            if dim_q != dim_kv:
                raise ValueError("For self-attention, dim_q must be equal to dim_kv.")
            self.w_qkv = nn.Linear(dim_q, 3 * dim_out, bias=bias)

        self.w_o = nn.Linear(dim_out, dim_out, bias=bias)

    def _split_heads(self, t: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """Reshape (batch, length, heads * head_dim) to (batch, heads, length, head_dim)."""
        return t.view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from ``query`` to ``key_value``.

        Args:
            query: 3-D tensor (batch, seq_q, dim_q).
            key_value: Tensor whose axis 1 is the key/value sequence length.
                In self-attention mode only that length is read; the values
                come from ``query``, so the caller must pass the same tensor.
            attn_mask: Optional mask forwarded unchanged to
                ``F.scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, seq_q, dim_out).
        """
        n_batch, len_q, _ = query.shape
        len_kv = key_value.shape[1]

        if self.is_self_attention:
            # Fused projection: Q, K and V all come from `query`.
            q, k, v = torch.chunk(self.w_qkv(query), 3, dim=-1)
        else:
            q = self.w_q(query)
            k, v = torch.chunk(self.w_kv(key_value), 2, dim=-1)

        q = self._split_heads(q, n_batch, len_q)
        k = self._split_heads(k, n_batch, len_kv)
        v = self._split_heads(v, n_batch, len_kv)

        # Attention dropout is active only while training.
        drop_p = self.dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_p
        )

        merged = context.transpose(1, 2).contiguous().view(n_batch, len_q, -1)
        return self.w_o(merged)


class MAB(nn.Module):
    """Pre-norm attention block followed by a SwiGLU feed-forward network.

    The query and key/value inputs are RMS-normalised, passed through
    ``MultiheadAttentionBlock``, optionally added back to the un-normalised
    query, and then refined by a residual FFN. The whole block can be wrapped
    in activation checkpointing during training.

    Args:
        dim_q: Feature size of the query input.
        dim_kv: Feature size of the key/value input.
        dim: Attention/FFN width and output feature size.
        n_heads: Number of attention heads.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to ``SwiGLU``.
        dropout: Dropout probability for both attention and FFN.
        use_checkpointing: Recompute activations in backward while training.
        is_self_attention: Use the fused-QKV self-attention path; the caller
            must then pass the same tensor as ``query`` and ``key_value``.
    """
    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        dim: int,
        n_heads: int,
        ffn_hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
        use_checkpointing: bool = False,
        is_self_attention: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpointing = use_checkpointing
        self.is_self_attention = is_self_attention

        self.norm_q = RMSNorm(dim_q)
        # NOTE: when dim_kv == dim_q the key/value norm is the *same module* as
        # the query norm, so its gain is shared even for cross-attention.
        self.norm_kv = RMSNorm(dim_kv) if dim_kv != dim_q else self.norm_q

        self.attention = MultiheadAttentionBlock(
            dim_q=dim_q, dim_kv=dim_kv, dim_out=dim, n_heads=n_heads,
            dropout=dropout, is_self_attention=is_self_attention
        )
        self.norm_ffn = RMSNorm(dim)
        self.ffn = SwiGLU(dim=dim, hidden_dim=ffn_hidden_dim, dropout=dropout)

    def _forward(self, query: torch.Tensor, key_value: torch.Tensor) -> torch.Tensor:
        """Run attention + FFN without checkpointing."""
        normed_q = self.norm_q(query)
        if self.is_self_attention:
            # The fused-QKV attention reads K and V from the query, so a
            # separate normalisation of key_value would be computed and then
            # discarded. Reuse the normalised query instead.
            normed_kv = normed_q
        else:
            normed_kv = self.norm_kv(key_value)
        attn_output = self.attention(normed_q, normed_kv)

        # The residual connection is only possible when the query already has
        # the block's output width.
        if query.shape[-1] == self.dim:
            x = query + attn_output
        else:
            x = attn_output

        x = x + self.ffn(self.norm_ffn(x))
        return x

    def forward(self, query: torch.Tensor, key_value: torch.Tensor) -> torch.Tensor:
        """Apply the block, using activation checkpointing in training if enabled.

        Args:
            query: Query tensor with ``dim_q`` features on the last axis.
            key_value: Key/value tensor with ``dim_kv`` features on the last axis.

        Returns:
            Tensor with the query's leading shape and ``dim`` features.
        """
        if self.training and self.use_checkpointing:
            return checkpoint(self._forward, query, key_value, use_reentrant=False)
        return self._forward(query, key_value)


class SAB(nn.Module):
    """Set attention block: a self-attention ``MAB`` applied to one input.

    Args:
        dim: Feature size of the input and the output.
        n_heads: Number of attention heads.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to ``MAB``.
        dropout: Dropout probability forwarded to ``MAB``.
    """

    def __init__(self, dim: int, n_heads: int, ffn_hidden_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        # Query and key/value share one width, and the fused self-attention
        # path is enabled.
        self.mab = MAB(
            dim_q=dim,
            dim_kv=dim,
            dim=dim,
            n_heads=n_heads,
            ffn_hidden_dim=ffn_hidden_dim,
            dropout=dropout,
            use_checkpointing=USE_CHECKPOINTING,
            is_self_attention=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Let every element of ``x`` attend to all elements of ``x``."""
        attended = self.mab(x, x)
        return attended

class ISAB(nn.Module):
    """Induced set attention block built from two cross-attention ``MAB`` layers.

    Learned inducing points first attend to the input set (``mab_cross``);
    the input set then attends to that summary (``mab_self``). Both layers are
    constructed with ``is_self_attention=False``.

    Args:
        dim_in: Feature size of the input set.
        dim_out: Feature size of the inducing points and of the output.
        n_heads: Number of attention heads.
        n_inducing_points: Number of learned inducing vectors.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to ``MAB``.
        dropout: Dropout probability forwarded to ``MAB``.
    """

    def __init__(self, dim_in: int, dim_out: int, n_heads: int, n_inducing_points: int, ffn_hidden_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        self.inducing_points = nn.Parameter(torch.randn(1, n_inducing_points, dim_out))
        nn.init.xavier_uniform_(self.inducing_points)

        # Settings common to both attention stages.
        common = dict(
            dim=dim_out,
            n_heads=n_heads,
            ffn_hidden_dim=ffn_hidden_dim,
            dropout=dropout,
            use_checkpointing=USE_CHECKPOINTING,
            is_self_attention=False,
        )
        # Stage 1: inducing points (dim_out) query the input set (dim_in).
        self.mab_cross = MAB(dim_q=dim_out, dim_kv=dim_in, **common)
        # Stage 2: the input set (dim_in) queries the summary (dim_out).
        self.mab_self = MAB(dim_q=dim_in, dim_kv=dim_out, **common)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the inducing-point bottleneck and back."""
        # Broadcast the single learned set of inducing points over the batch.
        inducing = self.inducing_points.expand(x.shape[0], -1, -1)
        summary = self.mab_cross(inducing, x)
        return self.mab_self(x, summary)

class PMA(nn.Module):
    """Pooling by multi-head attention: learned seeds attend to the input set.

    Args:
        dim: Feature size of the seeds, the input and the output.
        n_heads: Number of attention heads.
        n_seeds: Number of learned seed vectors; this is the length of the
            output along axis 1.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to ``MAB``.
        dropout: Dropout probability forwarded to ``MAB``.
    """

    def __init__(self, dim: int, n_heads: int, n_seeds: int, ffn_hidden_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        self.seed_vectors = nn.Parameter(torch.randn(1, n_seeds, dim))
        nn.init.xavier_uniform_(self.seed_vectors)
        # Cross-attention: the seeds are the queries, the set supplies K/V.
        self.mab = MAB(
            dim_q=dim,
            dim_kv=dim,
            dim=dim,
            n_heads=n_heads,
            ffn_hidden_dim=ffn_hidden_dim,
            dropout=dropout,
            use_checkpointing=USE_CHECKPOINTING,
            is_self_attention=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` into ``n_seeds`` vectors per batch element."""
        # Broadcast the single learned seed set over the batch.
        seeds = self.seed_vectors.expand(x.shape[0], -1, -1)
        return self.mab(seeds, x)

class SetTransformer(nn.Module):
    """Set-to-set model: ISAB encoder, PMA pooling, SAB decoder, linear head.

    Args:
        input_dim: Feature size of each input element.
        output_dim: Feature size of each output element.
        model_dim: Internal width used by every block.
        n_heads: Number of attention heads.
        n_isab: Number of ISAB encoder blocks.
        n_dec_blocks: Number of SAB blocks applied after pooling.
        n_inducing_points: Inducing points per ISAB; either one int used for
            every block or a list with one entry per block.
        n_seeds: Number of PMA seed vectors.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to the blocks.
        dropout: Dropout probability forwarded to the blocks.

    Raises:
        ValueError: If ``n_inducing_points`` is a list whose length differs
            from ``n_isab``.
    """

    def __init__(
        self, input_dim: int, output_dim: int, model_dim: int = 256, n_heads: int = 8,
        n_isab: int = 2, n_dec_blocks: int = 1, n_inducing_points: Union[int, List[int]] = 32,
        n_seeds: int = 1, ffn_hidden_dim: Optional[int] = None, dropout: float = 0.0
    ):
        super().__init__()
        # Normalise the inducing-point spec to one count per encoder block.
        if isinstance(n_inducing_points, int):
            points_per_block = [n_inducing_points] * n_isab
        else:
            if len(n_inducing_points) != n_isab:
                raise ValueError(f"n_inducing_points must be an int or list of length {n_isab}")
            points_per_block = n_inducing_points

        self.input_projection = nn.Linear(input_dim, model_dim)

        encoder_blocks = []
        for block_idx in range(n_isab):
            encoder_blocks.append(
                ISAB(model_dim, model_dim, n_heads, points_per_block[block_idx], ffn_hidden_dim, dropout)
            )
        self.encoder = nn.ModuleList(encoder_blocks)

        self.pooling = PMA(model_dim, n_heads, n_seeds, ffn_hidden_dim, dropout)

        decoder_blocks = []
        for _ in range(n_dec_blocks):
            decoder_blocks.append(SAB(model_dim, n_heads, ffn_hidden_dim, dropout))
        self.decoder = nn.ModuleList(decoder_blocks)

        self.output_norm = RMSNorm(model_dim)
        self.output_projection = nn.Linear(model_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project, encode, pool, decode, normalise and project ``x``."""
        hidden = self.input_projection(x)
        for encoder_block in self.encoder:
            hidden = encoder_block(hidden)
        hidden = self.pooling(hidden)
        for decoder_block in self.decoder:
            hidden = decoder_block(hidden)
        return self.output_projection(self.output_norm(hidden))


# --- Benchmark ---
if __name__ == "__main__":
    # Training-throughput benchmark for SetTransformer on random data.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    # torch.cuda.is_bf16_supported() is a CUDA-only query, so consult it only
    # when running on CUDA; the CPU fallback uses bfloat16 autocast.
    if on_cuda and not torch.cuda.is_bf16_supported():
        amp_dtype = torch.float16
    else:
        amp_dtype = torch.bfloat16
    print(f"Using device: {device}, AMP dtype: {amp_dtype}")

    # Model and data parameters
    batch_size, set_size, input_dim_feat = 128, 512, 32
    input_dim = input_dim_feat * 10
    output_dim = 10
    model_dim = 512
    n_heads = 8
    n_isab = 6
    n_dec_blocks = 2
    n_inducing_points = 64
    n_seeds = 64 # This determines the output set size

    model = SetTransformer(
        input_dim=input_dim, output_dim=output_dim, model_dim=model_dim, n_heads=n_heads,
        n_isab=n_isab, n_dec_blocks=n_dec_blocks,
        n_inducing_points=n_inducing_points, n_seeds=n_seeds, dropout=0.1
    ).to(device)

    x = torch.randn(batch_size, set_size, input_dim, device=device)
    y = torch.randn(batch_size, n_seeds, output_dim, device=device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # Loss scaling is only needed for float16; bind the scaler to the device
    # actually in use instead of the "cuda" default.
    scaler = torch.amp.GradScaler(device.type, enabled=(amp_dtype == torch.float16))

    print("\n--- Starting Benchmark ---")
    print(f"Model dim: {model_dim}, Heads: {n_heads}, Encoder Blocks: {n_isab}, Decoder Blocks: {n_dec_blocks}")
    print(f"Batch size: {batch_size}, Input set size: {set_size}, Output set size: {n_seeds}")
    print(f"Checkpointing: {'Enabled' if USE_CHECKPOINTING else 'Disabled'}")

    # Make the mode explicit so warm-up and the timed loop exercise the same
    # code path (dropout and checkpointing are training-only).
    model.train()

    # Warm-up iterations
    print("Running warm-up iterations...")
    for _ in tqdm(range(5)):
        with torch.autocast(device_type=device.type, dtype=amp_dtype):
            output = model(x)
            loss = F.mse_loss(output, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    # CUDA kernels run asynchronously; drain the queue before starting the clock.
    if on_cuda:
        torch.cuda.synchronize()

    # Benchmark loop
    num_iterations = 50
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()

    for _ in tqdm(range(num_iterations)):
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=amp_dtype):
            output = model(x)
            loss = F.mse_loss(output, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    if on_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    total_time = end_time - start_time
    avg_time_per_iter = total_time / num_iterations
    throughput = (batch_size * num_iterations) / total_time

    print("\n--- Benchmark Results ---")
    print(f"Total time for {num_iterations} iterations: {total_time:.2f} seconds")
    print(f"Average time per iteration: {avg_time_per_iter * 1000:.2f} ms")
    print(f"Throughput: {throughput:.2f} samples/sec")


Using device: cuda, AMP dtype: torch.bfloat16

--- Starting Benchmark ---
Model dim: 512, Heads: 8, Encoder Blocks: 6, Decoder Blocks: 2
Batch size: 128, Input set size: 512, Output set size: 64
Checkpointing: Enabled
Running warm-up iterations...


100%|██████████| 5/5 [00:01<00:00,  2.82it/s]
100%|██████████| 50/50 [00:17<00:00,  2.91it/s]



--- Benchmark Results ---
Total time for 50 iterations: 17.69 seconds
Average time per iteration: 353.88 ms
Throughput: 361.71 samples/sec


In [None]:
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for Rotary Positional Embeddings (RoPE).

    The tables have shape (1, 1, max_seq_len, dim) and are stored as
    non-persistent buffers, so they follow the module across devices but are
    not written to the state dict.

    Args:
        dim: Width of the rotated feature axis; must be a positive even
            number because features are rotated in pairs.
        max_seq_len: Number of positions to precompute.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``dim`` is not a positive even integer.
    """
    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000):
        super().__init__()
        # An odd dim would yield ceil(dim / 2) frequencies and tables of width
        # dim + 1, which only fails later when broadcast against the queries.
        if dim <= 0 or dim % 2 != 0:
            raise ValueError(f"dim ({dim}) must be a positive even integer for rotary embeddings")

        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build the position index directly in inv_freq's dtype rather than
        # relying on implicit int/float promotion inside einsum.
        t = torch.arange(
            self.max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate the angles so both halves of the feature axis share them.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (cos, sin) tables for the first ``seq_len`` positions.

        Args:
            x: Reference tensor; only its dtype is used.
            seq_len: Number of leading positions to return.

        Returns:
            Two tensors of shape (1, 1, seq_len, dim) cast to ``x.dtype``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        if seq_len > self.max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last axis, negating the one moved to the front."""
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with the given cosine/sine tables.

    ``cos`` and ``sin`` must broadcast against ``xq`` and ``xk``; each result
    is cast back to the dtype of its own input.
    """
    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # t * cos + rotate_half(t) * sin, then restore t's dtype.
        rotated = (t * cos) + (rotate_half(t) * sin)
        return rotated.type_as(t)

    return _rotate(xq), _rotate(xk)


class Attention(nn.Module):
    """Multi-head attention with optional rotary embeddings and causal masking.

    Args:
        dim: Model width; input, output and total head width.
        n_heads: Number of attention heads; must divide ``dim``.
        dropout: Attention dropout probability, used only in training mode.
        use_rope: If true, ``forward`` requires ``rope_emb`` and rotates the
            queries and keys with it.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim: int, n_heads: int, dropout: float = 0.0, use_rope: bool = False):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.use_rope = use_rope

        self.w_q = nn.Linear(dim, dim, bias=False)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.w_v = nn.Linear(dim, dim, bias=False)
        self.w_o = nn.Linear(dim, dim, bias=False)
        self.dropout = dropout

    def _to_heads(self, t: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """Reshape (batch, length, dim) to (batch, heads, length, head_dim)."""
        return t.view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Attend from ``query`` to ``key_value``.

        Args:
            query: 3-D tensor (batch, seq_q, dim).
            key_value: Tensor whose axis 1 is the key/value sequence length.
            rope_emb: ``(cos, sin)`` pair; required when ``use_rope`` is set.
            is_causal: Forwarded to ``F.scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, seq_q, dim).

        Raises:
            ValueError: If ``use_rope`` is set and ``rope_emb`` is ``None``.
        """
        n_batch, len_q, _ = query.shape
        len_kv = key_value.shape[1]

        q = self._to_heads(self.w_q(query), n_batch, len_q)
        k = self._to_heads(self.w_k(key_value), n_batch, len_kv)
        v = self._to_heads(self.w_v(key_value), n_batch, len_kv)

        if self.use_rope:
            if rope_emb is None:
                raise ValueError("rope_emb must be provided when use_rope is True")
            q, k = apply_rotary_emb(q, k, *rope_emb)

        # Attention dropout is active only while training.
        drop_p = self.dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, dropout_p=drop_p
        )

        merged = context.transpose(1, 2).contiguous().view(n_batch, len_q, -1)
        return self.w_o(merged)


class TransformerDecoderBlock(nn.Module):
    """One pre-norm decoder layer: causal self-attention, cross-attention, FFN.

    Each of the three sub-layers is wrapped in a residual connection. The
    self-attention uses rotary embeddings; the cross-attention over the
    encoder memory does not.

    Args:
        dim: Model width.
        n_heads: Number of attention heads.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to ``SwiGLU``.
        dropout: Dropout probability for attention and FFN.
        use_checkpointing: Recompute activations in backward while training.
    """
    def __init__(
        self,
        dim: int,
        n_heads: int,
        ffn_hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
        use_checkpointing: bool = False,
    ):
        super().__init__()
        self.use_checkpointing = use_checkpointing

        self.self_attn_norm = RMSNorm(dim)
        self.self_attention = Attention(dim=dim, n_heads=n_heads, dropout=dropout, use_rope=True)

        self.cross_attn_norm = RMSNorm(dim)
        self.encoder_mem_norm = RMSNorm(dim)
        self.cross_attention = Attention(dim=dim, n_heads=n_heads, dropout=dropout, use_rope=False)

        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim=dim, hidden_dim=ffn_hidden_dim, dropout=dropout)

    def _forward(
        self,
        x: torch.Tensor,
        encoder_memory: torch.Tensor,
        rope_emb: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Run the three residual sub-layers without checkpointing."""
        # Normalise once and feed the same tensor as query and key/value;
        # evaluating the norm twice on identical input is wasted work.
        normed_x = self.self_attn_norm(x)
        x = x + self.self_attention(
            normed_x, normed_x,
            rope_emb=rope_emb, is_causal=True
        )
        x = x + self.cross_attention(
            self.cross_attn_norm(x), self.encoder_mem_norm(encoder_memory)
        )
        x = x + self.ffn(self.ffn_norm(x))
        return x

    def forward(
        self,
        x: torch.Tensor,
        encoder_memory: torch.Tensor,
        rope_emb: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Apply the layer, using activation checkpointing in training if enabled.

        Args:
            x: Decoder hidden states with ``dim`` features on the last axis.
            encoder_memory: Encoder output attended to by the cross-attention.
            rope_emb: ``(cos, sin)`` rotary tables for the self-attention.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        if self.training and self.use_checkpointing:
            # Pass cos and sin as separate tensor arguments and rebuild the
            # tuple inside the checkpointed function.
            def ckpt_fn(x, encoder_memory, cos, sin):
                return self._forward(x, encoder_memory, (cos, sin))
            cos, sin = rope_emb
            return checkpoint(ckpt_fn, x, encoder_memory, cos, sin, use_reentrant=False)
        return self._forward(x, encoder_memory, rope_emb)


class TransformerDecoder(nn.Module):
    """Token decoder: embedding, stacked decoder blocks, norm, tied output head.

    Rotary tables are built for a per-head width of ``model_dim // n_heads``,
    and the output projection shares its weight matrix with the token
    embedding.

    Args:
        vocab_size: Number of token ids; also the logit width.
        model_dim: Model width.
        n_layers: Number of ``TransformerDecoderBlock`` layers.
        n_heads: Number of attention heads.
        max_seq_len: Number of positions covered by the rotary tables.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to each block.
        dropout: Dropout probability forwarded to each block.
    """

    def __init__(
        self,
        vocab_size: int,
        model_dim: int,
        n_layers: int,
        n_heads: int,
        max_seq_len: int = 4096,
        ffn_hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.model_dim = model_dim
        self.tok_embeddings = nn.Embedding(vocab_size, model_dim)
        self.rope = RotaryEmbedding(dim=model_dim // n_heads, max_seq_len=max_seq_len)

        blocks = []
        for _ in range(n_layers):
            blocks.append(
                TransformerDecoderBlock(
                    dim=model_dim,
                    n_heads=n_heads,
                    ffn_hidden_dim=ffn_hidden_dim,
                    dropout=dropout,
                    use_checkpointing=USE_CHECKPOINTING,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.output_norm = RMSNorm(model_dim)
        self.output_projection = nn.Linear(model_dim, vocab_size, bias=False)
        # Weight tying: the output head reuses the embedding matrix.
        self.output_projection.weight = self.tok_embeddings.weight

    def forward(self, tokens: torch.Tensor, encoder_memory: torch.Tensor) -> torch.Tensor:
        """Return logits for every position of ``tokens``.

        Args:
            tokens: 2-D tensor of token ids, (batch, seq_len).
            encoder_memory: Encoder output handed to every decoder block.

        Returns:
            Logits with ``vocab_size`` entries on the last axis.
        """
        _, n_positions = tokens.shape
        hidden = self.tok_embeddings(tokens)
        rope_tables = self.rope(hidden, seq_len=n_positions)

        for block in self.layers:
            hidden = block(hidden, encoder_memory, rope_tables)

        return self.output_projection(self.output_norm(hidden))


# --- Final Encoder-Decoder Model ---

class SetEncoder(nn.Module):
    """Set encoder: input projection, ISAB stack, PMA pooling, SAB stack.

    Args:
        input_dim: Feature size of each input element.
        model_dim: Internal width and output feature size.
        n_heads: Number of attention heads.
        n_isab: Number of ISAB blocks applied before pooling.
        n_sab: Number of SAB blocks applied after pooling (stored as
            ``self.decoder``).
        n_inducing_points: Inducing points per ISAB; either one int used for
            every block or a list with one entry per block.
        n_seeds: Number of PMA seed vectors, i.e. output vectors per sample.
        ffn_hidden_dim: Nominal FFN hidden size forwarded to the blocks.
        dropout: Dropout probability forwarded to the blocks.

    Raises:
        ValueError: If ``n_inducing_points`` is a list whose length differs
            from ``n_isab``.
    """

    def __init__(
        self, input_dim: int, model_dim: int, n_heads: int, n_isab: int, n_sab: int,
        n_inducing_points: Union[int, List[int]], n_seeds: int,
        ffn_hidden_dim: Optional[int] = None, dropout: float = 0.0
    ):
        super().__init__()
        # Normalise the inducing-point spec to one count per ISAB block.
        if isinstance(n_inducing_points, int):
            points_per_block = [n_inducing_points] * n_isab
        else:
            if len(n_inducing_points) != n_isab:
                raise ValueError(f"n_inducing_points must be an int or list of length {n_isab}")
            points_per_block = n_inducing_points

        self.input_projection = nn.Linear(input_dim, model_dim)

        isab_blocks = []
        for block_idx in range(n_isab):
            isab_blocks.append(
                ISAB(model_dim, model_dim, n_heads, points_per_block[block_idx], ffn_hidden_dim, dropout)
            )
        self.encoder = nn.ModuleList(isab_blocks)

        self.pooling = PMA(model_dim, n_heads, n_seeds, ffn_hidden_dim, dropout)

        sab_blocks = []
        for _ in range(n_sab):
            sab_blocks.append(SAB(model_dim, n_heads, ffn_hidden_dim, dropout))
        self.decoder = nn.ModuleList(sab_blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the set ``x`` into ``n_seeds`` memory vectors per sample."""
        hidden = self.input_projection(x)
        for isab_block in self.encoder:
            hidden = isab_block(hidden)
        hidden = self.pooling(hidden)
        for sab_block in self.decoder:
            hidden = sab_block(hidden)
        return hidden


class NSRTransformer(nn.Module):
    """Set-to-sequence model: a ``SetEncoder`` feeding a ``TransformerDecoder``.

    Args:
        input_dim: Feature size of each element of the source set.
        n_isab: Number of ISAB blocks in the encoder.
        n_sab: Number of SAB blocks in the encoder.
        n_inducing_points: Inducing points per ISAB (int or per-block list).
        n_seeds: Number of memory vectors the encoder emits per sample.
        vocab_size: Decoder vocabulary size.
        n_dec_layers: Number of decoder layers.
        max_seq_len: Maximum target length supported by the decoder.
        model_dim: Width shared by encoder and decoder.
        n_heads: Attention heads shared by encoder and decoder.
        ffn_hidden_dim: Nominal FFN hidden size shared by both halves.
        dropout: Dropout probability shared by both halves.
    """

    def __init__(
        self,
        # Encoder params
        input_dim: int,
        n_isab: int,
        n_sab: int,
        n_inducing_points: Union[int, List[int]],
        n_seeds: int,
        # Decoder params
        vocab_size: int,
        n_dec_layers: int,
        max_seq_len: int,
        # Shared params
        model_dim: int,
        n_heads: int,
        ffn_hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Hyper-parameters that both halves of the model receive unchanged.
        shared = dict(
            model_dim=model_dim,
            n_heads=n_heads,
            ffn_hidden_dim=ffn_hidden_dim,
            dropout=dropout,
        )
        self.encoder = SetEncoder(
            input_dim=input_dim,
            n_isab=n_isab,
            n_sab=n_sab,
            n_inducing_points=n_inducing_points,
            n_seeds=n_seeds,
            **shared,
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            n_layers=n_dec_layers,
            max_seq_len=max_seq_len,
            **shared,
        )

    def forward(self, src_set: torch.Tensor, tgt_seq: torch.Tensor) -> torch.Tensor:
        """Encode ``src_set`` and decode ``tgt_seq`` against it, returning logits."""
        memory = self.encoder(src_set)
        return self.decoder(tgt_seq, memory)


# --- Benchmark ---
if __name__ == "__main__":
    # Training- and inference-throughput benchmark for NSRTransformer on
    # random data.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    # torch.cuda.is_bf16_supported() is a CUDA-only query, so consult it only
    # when running on CUDA; the CPU fallback uses bfloat16 autocast.
    if on_cuda and not torch.cuda.is_bf16_supported():
        amp_dtype = torch.float16
    else:
        amp_dtype = torch.bfloat16
    print(f"Using device: {device}, AMP dtype: {amp_dtype}")

    # Model and data parameters
    batch_size, set_size, input_dim = 128, 512, 32*10
    vocab_size, tgt_seq_len = 1024, 256
    model_dim = 512
    n_heads = 8
    n_isab = 8
    n_sab = 8
    n_dec_layers = 8
    n_inducing_points = 64
    n_seeds = 64 # Number of memory vectors from encoder
    max_seq_len = 1024

    model = NSRTransformer(
        input_dim=input_dim, n_isab=n_isab, n_sab=n_sab, n_inducing_points=n_inducing_points,
        n_seeds=n_seeds, vocab_size=vocab_size, n_dec_layers=n_dec_layers,
        max_seq_len=max_seq_len, model_dim=model_dim, n_heads=n_heads, dropout=0.1
    ).to(device)

    # Dummy data
    src_set = torch.randn(batch_size, set_size, input_dim, device=device)
    tgt_tokens = torch.randint(0, vocab_size, (batch_size, tgt_seq_len), device=device)
    # Random labels stand in for next-token targets. They are NOT a shifted
    # copy of tgt_tokens; that is sufficient for timing but not for training.
    labels = torch.randint(0, vocab_size, (batch_size, tgt_seq_len), device=device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # Loss scaling is only needed for float16; bind the scaler to the device
    # actually in use instead of the "cuda" default.
    scaler = torch.amp.GradScaler(device.type, enabled=(amp_dtype == torch.float16))

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\n--- Starting NSRTransformer Benchmark ---")
    print(f"Model parameters: {num_params / 1e6:.2f}M")
    print(f"Model dim: {model_dim}, Heads: {n_heads}, Encoder Blocks: {n_isab}, Decoder Layers: {n_dec_layers}")
    print(f"Batch size: {batch_size}, Input set size: {set_size}, Target seq len: {tgt_seq_len}")
    print(f"Checkpointing: {'Enabled' if USE_CHECKPOINTING else 'Disabled'}")

    # Make the mode explicit so warm-up and the timed loop exercise the same
    # code path (dropout and checkpointing are training-only).
    model.train()

    # Warm-up iterations
    print("Running warm-up iterations...")
    for _ in tqdm(range(5)):
        with torch.autocast(device_type=device.type, dtype=amp_dtype):
            logits = model(src_set, tgt_tokens)
            loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    # CUDA kernels run asynchronously; drain the queue before starting the clock.
    if on_cuda:
        torch.cuda.synchronize()

    # Benchmark loop
    num_iterations = 50
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()

    for _ in tqdm(range(num_iterations)):
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=amp_dtype):
            logits = model(src_set, tgt_tokens)
            loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    if on_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    total_time = end_time - start_time
    avg_time_per_iter = total_time / num_iterations
    throughput = (batch_size * num_iterations) / total_time

    print("\n--- Benchmark Results ---")
    print(f"Total time for {num_iterations} iterations: {total_time:.2f} seconds")
    print(f"Average time per iteration: {avg_time_per_iter * 1000:.2f} ms")
    print(f"Throughput: {throughput:.2f} samples/sec")


    # Test inference speed (full teacher-forced forward pass, no gradients)
    model.eval()
    with torch.no_grad():
        # Warm-up
        for _ in tqdm(range(5)):
            with torch.autocast(device_type=device.type, dtype=amp_dtype):
                logits = model(src_set, tgt_tokens)
        if on_cuda:
            torch.cuda.synchronize()

        start_time = time.perf_counter()
        for _ in tqdm(range(num_iterations)):
            with torch.autocast(device_type=device.type, dtype=amp_dtype):
                logits = model(src_set, tgt_tokens)
        if on_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        total_time = end_time - start_time
        avg_time_per_iter = total_time / num_iterations
        throughput = (batch_size * num_iterations) / total_time

        print("\n--- Inference Benchmark Results ---")
        print(f"Total time for {num_iterations} iterations: {total_time:.2f} seconds")
        print(f"Average time per iteration: {avg_time_per_iter * 1000:.2f} ms")
        print(f"Throughput: {throughput:.2f} samples/sec")


Using device: cuda, AMP dtype: torch.bfloat16

--- Starting NSRTransformer Benchmark ---
Model parameters: 113.36M
Model dim: 512, Heads: 8, Encoder Blocks: 8, Decoder Layers: 8
Batch size: 128, Input set size: 512, Target seq len: 256
Checkpointing: Enabled
Running warm-up iterations...


100%|██████████| 5/5 [00:03<00:00,  1.27it/s]
100%|██████████| 50/50 [00:38<00:00,  1.29it/s]



--- Benchmark Results ---
Total time for 50 iterations: 39.25 seconds
Average time per iteration: 784.94 ms
Throughput: 163.07 samples/sec

--- Inference Benchmark Results ---
Total time for 50 iterations: 8.42 seconds
Average time per iteration: 168.49 ms
Throughput: 759.70 samples/sec
