In [1]:
# Use the %pip magic so the package is installed into the environment of the
# running kernel (a bare `!pip` runs whichever pip is first on the shell PATH).
%pip install torch

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
import torch
import torch.nn as nn
import math

# Hyperparameters (prototype-sized; grouped by what they control)

# -- Model layout --
dim = 64        # Width of every token embedding (kept small for the prototype)
num_heads = 4   # Attention heads; each head gets dim // num_heads features
num_layers = 4  # Layers alternate local/global, i.e. 2 local + 2 global

# -- Sequence layout --
total_seq_len = 512  # Number of tokens in the demo input
chunk_size = 128     # Tokens per local-attention chunk (e.g., 8K in the real model)

# -- Temperature scaling used by the global layers --
alpha = 128  # α: positions per scaling step (scaled down from 8K for prototype)
beta = 0.1   # β: weight of the logarithmic scaling term

# Simplified RoPE (Rotary Position Embeddings)
def apply_rotary_pos_emb(q, k, seq_len, head_dim):
    """Apply rotary position embeddings (RoPE) to queries and keys.

    Each vector is split into ``head_dim // 2`` feature pairs
    ``(x[i], x[i + head_dim // 2])``; the pair with index ``i`` at position
    ``p`` is rotated by the angle ``p * 10000 ** (-2 * i / head_dim)``.
    Because every vector is rotated on its own, norms are preserved and the
    dot product of a rotated query and a rotated key depends only on the
    difference of their positions.

    Args:
        q: Query tensor of shape [batch, heads, seq_len, head_dim].
        k: Key tensor of shape [batch, heads, seq_len, head_dim].
        seq_len: Number of positions; must match dimension 2 of ``q``/``k``.
        head_dim: Size of the last dimension of ``q``/``k``; must be even.

    Returns:
        Tuple ``(q_rot, k_rot)`` with the same shapes as ``q`` and ``k``.

    Raises:
        ValueError: If ``head_dim`` is odd (features cannot be paired).
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")

    # One frequency per feature pair: 10000^(-2i / head_dim), i = 0 .. head_dim/2 - 1.
    # arange(0, head_dim, 2) already equals 2i, so it is divided by head_dim only once.
    even_index = torch.arange(0, head_dim, 2, device=q.device, dtype=torch.float32)
    inv_freq = 10000.0 ** (-even_index / head_dim)  # [head_dim // 2]
    positions = torch.arange(seq_len, device=q.device, dtype=torch.float32).unsqueeze(1)  # [seq_len, 1]
    angles = positions * inv_freq.unsqueeze(0)  # [seq_len, head_dim // 2]

    # Duplicate so both members of a pair (i and i + head_dim // 2) share an angle,
    # then add leading axes to broadcast over batch and heads.
    cos_angles = torch.cat([torch.cos(angles), torch.cos(angles)], dim=-1)  # [seq_len, head_dim]
    sin_angles = torch.cat([torch.sin(angles), torch.sin(angles)], dim=-1)  # [seq_len, head_dim]
    cos_angles = cos_angles.unsqueeze(0).unsqueeze(0).to(q.dtype)  # [1, 1, seq_len, head_dim]
    sin_angles = sin_angles.unsqueeze(0).unsqueeze(0).to(q.dtype)  # [1, 1, seq_len, head_dim]

    # Rotate q and k independently; q must never be mixed with k here.
    q_rot = q * cos_angles + _rotate_half(q) * sin_angles  # [batch, heads, seq_len, head_dim]
    k_rot = k * cos_angles + _rotate_half(k) * sin_angles  # [batch, heads, seq_len, head_dim]
    return q_rot, k_rot


def _rotate_half(x):
    """Map the last-dimension halves (x1, x2) to (-x2, x1), the 90-degree partner used by RoPE."""
    half = x.shape[-1] // 2
    return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

# Local Attention with RoPE
class LocalAttention(nn.Module):
    """Multi-head self-attention restricted to fixed-size chunks, with RoPE.

    The sequence is cut into consecutive chunks of ``chunk_size`` tokens (the
    last chunk may be shorter). Tokens attend to every token of their own
    chunk and to nothing outside it; no causal mask is applied. RoPE is
    applied to the full sequence before chunking, so rotation angles use
    positions counted from the start of the sequence, not of the chunk.

    Args:
        dim: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        chunk_size: Number of tokens per local chunk; must be positive.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads`` or
            ``chunk_size`` is not positive.
    """

    def __init__(self, dim, num_heads, chunk_size):
        super().__init__()
        # Fail at construction with a clear message instead of an opaque
        # reshape / torch.cat error during the first forward pass.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.chunk_size = chunk_size
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim)

        self.qkv = nn.Linear(dim, dim * 3)  # Query, Key, Value projections
        self.proj = nn.Linear(dim, dim)  # Output projection

    def forward(self, x):
        """Run chunked self-attention.

        Args:
            x: Input of shape [batch, length, dim].

        Returns:
            Tensor of shape [batch, length, dim].
        """
        B, L, D = x.shape  # Batch, Length, Dimension
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [B, heads, L, head_dim]

        # Apply RoPE to q and k over the whole sequence
        q, k = apply_rotary_pos_emb(q, k, L, self.head_dim)

        # Chunked attention: each chunk is attended independently
        attn_outputs = []
        for start in range(0, L, self.chunk_size):
            stop = start + self.chunk_size  # slicing clamps the final, possibly shorter chunk
            q_chunk = q[:, :, start:stop, :]
            k_chunk = k[:, :, start:stop, :]
            v_chunk = v[:, :, start:stop, :]

            # Standard scaled dot-product attention within chunk
            attn_scores = torch.matmul(q_chunk, k_chunk.transpose(-1, -2)) * self.scale
            attn_probs = torch.softmax(attn_scores, dim=-1)
            attn_outputs.append(torch.matmul(attn_probs, v_chunk))  # [B, heads, chunk, head_dim]

        # Concatenate chunks along the sequence axis and merge heads
        attn_out = torch.cat(attn_outputs, dim=2)  # [B, heads, L, head_dim]
        attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, L, D)
        return self.proj(attn_out)

# Global Attention with Inference-Time Temperature Scaling
class GlobalAttention(nn.Module):
    """Full-sequence multi-head self-attention with position-based query scaling.

    No position embeddings are used. Instead, the query at position ``p`` is
    multiplied by ``1 + beta * log(floor(p / alpha) + 1)``, a factor that is
    1 for the first ``alpha`` positions and then grows logarithmically in
    steps of ``alpha`` positions. No attention mask is applied.

    Args:
        dim: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        alpha: Number of positions per scaling step; must be positive.
        beta: Weight of the logarithmic scaling term.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads`` or ``alpha``
            is not positive.
    """

    def __init__(self, dim, num_heads, alpha, beta):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        # alpha == 0 divides by zero and alpha < 0 feeds non-positive values to log.
        if alpha <= 0:
            raise ValueError(f"alpha must be positive, got {alpha}")

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim)
        self.alpha = alpha
        self.beta = beta

        self.qkv = nn.Linear(dim, dim * 3)  # Query, Key, Value projections
        self.proj = nn.Linear(dim, dim)  # Output projection

    def forward(self, x):
        """Run global self-attention.

        Args:
            x: Input of shape [batch, length, dim].

        Returns:
            Tensor of shape [batch, length, dim].
        """
        B, L, D = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [B, heads, L, head_dim]

        # Temperature scaling for q, computed in float32 for accuracy
        positions = torch.arange(L, device=x.device, dtype=torch.float32)
        scaling_factor = 1 + torch.log(torch.floor(positions / self.alpha) + 1) * self.beta
        # Cast back to q's dtype: otherwise a half-precision q is promoted to
        # float32 and the matmul with k below fails on mismatched dtypes.
        q = q * scaling_factor.to(q.dtype).view(1, 1, L, 1)

        # Global attention (no position embeddings)
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)  # [B, heads, L, head_dim]
        attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, L, D)
        return self.proj(attn_out)

# Feed-Forward Network (FFN)
class FFN(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    Args:
        dim: Size of the input and output features.
        hidden_dim: Size of the intermediate (expanded) features.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)  # expand: dim -> hidden_dim
        self.fc2 = nn.Linear(hidden_dim, dim)  # contract: hidden_dim -> dim
        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        expanded = self.fc1(x)
        activated = self.gelu(expanded)
        return self.fc2(activated)

# iRoPE Transformer Layer
class iRoPELayer(nn.Module):
    """One pre-norm transformer block with either local or global attention.

    Args:
        dim: Model dimension.
        num_heads: Number of attention heads.
        chunk_size: Chunk length, used only by the local attention variant.
        alpha: Temperature-scaling step, used only by the global variant.
        beta: Temperature-scaling weight, used only by the global variant.
        use_local: Build ``LocalAttention`` when truthy, else ``GlobalAttention``.
    """

    def __init__(self, dim, num_heads, chunk_size, alpha, beta, use_local=True):
        super().__init__()
        self.use_local = use_local
        if use_local:
            self.attn = LocalAttention(dim, num_heads, chunk_size)
        else:
            self.attn = GlobalAttention(dim, num_heads, alpha, beta)
        self.ffn = FFN(dim, dim * 4)  # hidden width is 4x the model width
        self.norm1 = nn.LayerNorm(dim)  # normalises the attention input
        self.norm2 = nn.LayerNorm(dim)  # normalises the FFN input

    def forward(self, x):
        """Apply attention and FFN sub-layers, each wrapped in a residual connection."""
        attn_update = self.attn(self.norm1(x))
        x = x + attn_update
        ffn_update = self.ffn(self.norm2(x))
        return x + ffn_update

# Full iRoPE Model
class iRoPEModel(nn.Module):
    """Stack of ``iRoPELayer`` blocks that alternate local and global attention.

    Layers with an even index (0, 2, ...) use local attention; layers with an
    odd index use global attention.

    Args:
        dim: Model dimension.
        num_heads: Number of attention heads per layer.
        chunk_size: Chunk length for the local layers.
        alpha: Temperature-scaling step for the global layers.
        beta: Temperature-scaling weight for the global layers.
        num_layers: Total number of layers in the stack.
    """

    def __init__(self, dim, num_heads, chunk_size, alpha, beta, num_layers):
        super().__init__()
        blocks = []
        for layer_idx in range(num_layers):
            is_local = layer_idx % 2 == 0  # even -> local, odd -> global
            blocks.append(
                iRoPELayer(dim, num_heads, chunk_size, alpha, beta, use_local=is_local)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final activations."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

# Instantiate and Test the Model
# Seed first so both the weight initialisation and the random input are
# identical on every Restart & Run All.
torch.manual_seed(0)
model = iRoPEModel(dim, num_heads, chunk_size, alpha, beta, num_layers)
x = torch.randn(1, total_seq_len, dim)  # [batch=1, total_seq_len, dim]
output = model(x)

# Every layer adds its result back onto its input, so the shape must not change.
assert output.shape == x.shape, f"shape changed: {tuple(x.shape)} -> {tuple(output.shape)}"

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

Input shape: torch.Size([1, 512, 64])
Output shape: torch.Size([1, 512, 64])
