In [None]:
!pip install transformers torch accelerate bitsandbytes ipywidgets huggingface_hub

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.m

In [None]:

import math

import ipywidgets as widgets
import torch
import torch.nn as nn
from IPython.display import clear_output, display
from transformers import AutoTokenizer, BitsAndBytesConfig, LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, LlamaRMSNorm

# Step 1: Authenticate with Hugging Face
import os
from getpass import getpass

from huggingface_hub import login

# Never hard-code access tokens in a notebook: they leak through version control and
# shared copies. Read the token from the HF_TOKEN environment variable, or prompt for it.
# Tokens can be created at https://huggingface.co/settings/tokens
huggingface_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(huggingface_token)

# Step 2: Hyperparameters for iRoPE
chunk_size = 2048  # Window length (tokens) for the chunked local-attention layers
alpha = 8192  # α for temperature scaling (scaled for 10M context)
beta = 0.1  # β for temperature scaling
gamma = 0.5  # Exponent used by the "power" temperature-scaling type
scaling_type = "log"  # Temperature scaling: "log", "linear", "exp", "sigmoid" or "power"
# Tokens per forward pass. NOTE: the global-attention layers materialise a
# [batch, num_heads, max_seq_len, max_seq_len] score tensor, so this bounds GPU memory.
max_seq_len = 16384
simulated_context_length = 10_000_000  # Simulate 10M token context
rope_theta = 500000.0  # RoPE base frequency used by LLaMA 3.2 (from the model config)

# Step 3: Simplified RoPE (Rotary Position Embeddings) - Replicate LLaMA's RoPE
def apply_rotary_pos_emb(q, k, seq_len, head_dim, rope_theta=rope_theta, position_ids=None):
    """Apply rotary position embeddings (RoPE) to queries and keys.

    Uses the "rotate half" convention of the Hugging Face LLaMA implementation: every
    vector ``x`` is rotated as ``x * cos + concat(-x2, x1) * sin`` where ``x1``/``x2``
    are its first/second halves. Q and K are rotated independently, so they may have
    different head counts (grouped-query attention).

    Args:
        q: queries, [B, num_heads, seq_len, head_dim].
        k: keys, [B, num_kv_heads, seq_len, head_dim].
        seq_len: number of positions; 0..seq_len-1 are used when position_ids is None.
        head_dim: per-head dimension (even).
        rope_theta: RoPE base frequency.
        position_ids: optional absolute positions, shape [seq_len] or [B, seq_len]
            (e.g. to offset a chunk inside a longer sequence).

    Returns:
        (q_rot, k_rot) with the same shapes as q and k, in q's dtype.
    """
    device = q.device
    if position_ids is None:
        position_ids = torch.arange(seq_len, device=device)
    positions = position_ids.to(device=device, dtype=torch.float32)
    if positions.dim() == 1:
        positions = positions.unsqueeze(0)  # [1, seq_len]

    indices = torch.arange(head_dim // 2, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (rope_theta ** (2 * indices / head_dim))  # [head_dim // 2]
    angles = positions.unsqueeze(-1) * inv_freq  # [B or 1, seq_len, head_dim // 2]
    angles = torch.cat([angles, angles], dim=-1)  # [B or 1, seq_len, head_dim]

    # Angles are computed in float32 for precision, then cast to the activation dtype so
    # fp16 activations are not silently upcast (which would break later matmuls with V).
    cos = angles.cos().to(q.dtype).unsqueeze(1)  # [B or 1, 1, seq_len, head_dim]
    sin = angles.sin().to(q.dtype).unsqueeze(1)

    def rotate_half(x):
        half = x.shape[-1] // 2
        return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot

# Step 4: Local Attention with RoPE (chunked attention)
class LocalAttentionWithRoPE(LlamaAttention):
    """Chunked (block-local) causal self-attention with rotary position embeddings.

    The sequence is split into consecutive, non-overlapping windows of ``chunk_size``
    tokens and each query attends only to keys inside its own window. The Q/K/V/O
    projection modules come from ``LlamaAttention`` so pretrained LLaMA weights can be
    reused.

    Args:
        config: LLaMA config; num_attention_heads and num_key_value_heads are read.
        layer_idx: decoder-layer index, forwarded to ``LlamaAttention``.
        chunk_size: window length in tokens.
    """

    def __init__(self, config, layer_idx, chunk_size):
        super().__init__(config, layer_idx)
        self.chunk_size = chunk_size
        # Set the head counts from the config explicitly: some transformers releases do
        # not define num_heads / num_key_value_heads on LlamaAttention.
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.scale = (self.head_dim ** -0.5)

    @staticmethod
    def _rotate(x, cos, sin):
        """Apply RoPE to ``x`` ([B, H, L, D]) with broadcastable ``cos``/``sin`` (rotate-half form)."""
        half = x.shape[-1] // 2
        rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
        return x * cos + rotated * sin

    def forward(self, hidden_states, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False, **kwargs):
        """Run windowed causal attention.

        Args:
            hidden_states: [B, L, hidden_size].
            attention_mask: optional additive mask, indexed as [B, 1, L, L]; when None a
                causal mask is applied inside each window.
            position_ids: unused (positions 0..L-1 are assumed).
            output_attentions: also return attention probabilities of the LAST window.
            use_cache: unused; no KV cache is kept.
            **kwargs: may contain ``position_embeddings`` = (cos, sin) from the parent model.

        Returns:
            (attn_out [B, L, hidden_size], last-window probabilities [B, H, c, c] or None).
        """
        B, L, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L, D]
        k = self.k_proj(hidden_states).view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        position_embeddings = kwargs.get("position_embeddings")
        if position_embeddings is not None:
            # Prefer the parent model's (cos, sin), which presumably already account for
            # config.rope_scaling. NOTE(review): assumes each is [B or 1, L, head_dim].
            cos, sin = (t.to(q.dtype).unsqueeze(1) for t in position_embeddings)
            q, k = self._rotate(q, cos, sin), self._rotate(k, cos, sin)
        else:
            q, k = apply_rotary_pos_emb(q, k, L, self.head_dim)

        # Expand K/V heads to match the query heads (grouped-query attention).
        n_rep = self.num_heads // self.num_key_value_heads
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)

        chunk_outputs = []
        attn_probs = None
        for start in range(0, L, self.chunk_size):
            end = min(start + self.chunk_size, L)
            q_chunk = q[:, :, start:end, :]
            k_chunk = k[:, :, start:end, :]
            v_chunk = v[:, :, start:end, :]

            scores = torch.matmul(q_chunk, k_chunk.transpose(-1, -2)) * self.scale
            if attention_mask is not None:
                scores = scores + attention_mask[:, :, start:end, start:end].to(scores.dtype)
            else:
                # No mask supplied: still forbid attending to future tokens in the window.
                n = end - start
                future = torch.ones(n, n, dtype=torch.bool, device=scores.device).triu(1)
                scores = scores.masked_fill(future, float("-inf"))

            # Softmax in float32 for numerical stability, then back to the activation dtype.
            attn_probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(v_chunk.dtype)
            chunk_outputs.append(torch.matmul(attn_probs, v_chunk))  # [B, H, c, D]

        attn_out = torch.cat(chunk_outputs, dim=2)  # [B, H, L, D]
        attn_out = attn_out.transpose(1, 2).reshape(B, L, self.num_heads * self.head_dim)
        attn_out = self.o_proj(attn_out)

        return attn_out, (attn_probs if output_attentions else None)

# Step 5: Global Attention with Inference-Time Temperature Scaling (no position embeddings)
class GlobalAttentionWithTempScaling(nn.Module):
    """Global causal self-attention without position embeddings, with inference-time
    temperature scaling of the queries.

    Every query may attend to all keys allowed by the mask (causal by default). Instead
    of RoPE, the query at position ``p`` is multiplied by a factor ``s(p)`` chosen by
    ``scaling_type`` (see ``_temperature``).

    Args:
        config: LLaMA-style config; reads num_attention_heads, num_key_value_heads,
            hidden_size, and optionally head_dim and attention_bias.
        layer_idx: decoder-layer index (stored; not used in the computation).
        alpha, beta, gamma: temperature-scaling hyperparameters.
        scaling_type: "log", "linear", "exp", "sigmoid" or "power".
    """

    def __init__(self, config, layer_idx, alpha, beta, gamma, scaling_type):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Prefer an explicit config.head_dim: num_heads * head_dim need not equal hidden_size.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.scaling_type = scaling_type

        q_dim = self.num_heads * self.head_dim
        kv_dim = self.num_key_value_heads * self.head_dim
        bias = getattr(config, "attention_bias", False)

        # Same shapes as LLaMA's projections so pretrained weights/modules can be used.
        self.q_proj = nn.Linear(config.hidden_size, q_dim, bias=bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_dim, bias=bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_dim, bias=bias)
        self.o_proj = nn.Linear(q_dim, config.hidden_size, bias=bias)

    def _temperature(self, positions):
        """Return the per-position query scaling factor s(p) as a float32 tensor.

        Raises:
            ValueError: if ``scaling_type`` is not recognised.
        """
        p = positions.to(torch.float32)
        if self.scaling_type == "log":
            return 1 + torch.log(torch.floor(p / self.alpha) + 1) * self.beta
        if self.scaling_type == "linear":
            return 1 + (p / self.alpha) * self.beta
        if self.scaling_type == "exp":
            return 1 + torch.exp(p / self.alpha - 1) * self.beta
        if self.scaling_type == "sigmoid":
            return 1 + torch.sigmoid(p / self.alpha - 1) * self.beta
        if self.scaling_type == "power":
            return 1 + (p / self.alpha) ** self.gamma * self.beta
        raise ValueError(f"Unknown scaling type: {self.scaling_type}")

    def forward(self, hidden_states, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False, **kwargs):
        """Run global attention with temperature-scaled queries.

        Args:
            hidden_states: [B, L, hidden_size].
            attention_mask: optional additive mask broadcastable to [B, H, L, L]; when None
                a causal mask is applied.
            position_ids, use_cache, **kwargs: unused (positions 0..L-1 are assumed).
            output_attentions: also return the attention probabilities.

        Returns:
            (attn_out [B, L, hidden_size], probabilities [B, H, L, L] or None).
        """
        B, L, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Expand K/V heads to match the query heads (grouped-query attention).
        n_rep = self.num_heads // self.num_key_value_heads
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)

        # Cast the factor to q's dtype: a float32 factor would upcast fp16/bf16 queries
        # and make the matmul with the (still low-precision) keys fail on dtype mismatch.
        positions = torch.arange(L, device=hidden_states.device)
        q = q * self._temperature(positions).to(q.dtype).view(1, 1, L, 1)

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if attention_mask is not None:
            scores = scores + attention_mask.to(scores.dtype)
        else:
            # No mask supplied: still forbid attending to future tokens.
            future = torch.ones(L, L, dtype=torch.bool, device=scores.device).triu(1)
            scores = scores.masked_fill(future, float("-inf"))

        # Softmax in float32 for numerical stability, then back to the activation dtype.
        attn_probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
        attn_out = torch.matmul(attn_probs, v)  # [B, H, L, D]
        attn_out = attn_out.transpose(1, 2).reshape(B, L, self.num_heads * self.head_dim)
        attn_out = self.o_proj(attn_out)

        return attn_out, (attn_probs if output_attentions else None)

# Step 6: Modified LLaMA Layer with iRoPE
class LlamaLayerWithIRoPE(nn.Module):
    """Pre-norm decoder layer with interleaved iRoPE attention.

    Even-indexed layers use ``LocalAttentionWithRoPE`` (chunked, with RoPE); odd-indexed
    layers use ``GlobalAttentionWithTempScaling`` (global, no position embeddings). The
    MLP and norms are LLaMA's own ``LlamaMLP`` (gated SiLU with gate/up/down projections)
    and ``LlamaRMSNorm``, so their parameters line up one-to-one with a pretrained LLaMA
    decoder layer. A GELU MLP with LayerNorm would not be able to load those weights.
    """

    def __init__(self, config, layer_idx, chunk_size, alpha, beta, gamma, scaling_type):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_local = layer_idx % 2 == 0  # Interleave local and global layers
        if self.use_local:
            self.attn = LocalAttentionWithRoPE(config, layer_idx, chunk_size)
        else:
            self.attn = GlobalAttentionWithTempScaling(config, layer_idx, alpha, beta, gamma, scaling_type)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False, **kwargs):
        """Compute ``h = x + attn(norm1(x))`` followed by ``h + mlp(norm2(h))``.

        Extra keyword arguments are forwarded to the attention module.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is True.
        """
        residual = hidden_states
        attn_output, attn_weights = self.attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + attn_output

        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

# Step 7: Modified LLaMA Model with iRoPE
class LlamaWithIRoPE(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose decoder layers are replaced by ``LlamaLayerWithIRoPE``.

    Even-indexed layers use chunked local attention with RoPE; odd-indexed layers use
    global attention with temperature scaling and no position embeddings.

    The final norm built by ``LlamaForCausalLM`` is deliberately kept (presumably LLaMA's
    RMSNorm, which has no bias). Replacing it with ``nn.LayerNorm`` would add a bias
    parameter that the pretrained checkpoint does not contain.
    """

    def __init__(self, config, chunk_size, alpha, beta, gamma, scaling_type):
        super().__init__(config)
        self.model.layers = nn.ModuleList([
            LlamaLayerWithIRoPE(config, layer_idx, chunk_size, alpha, beta, gamma, scaling_type)
            for layer_idx in range(config.num_hidden_layers)
        ])

# Step 8: Load the pretrained model with 4-bit quantization
model_name = "meta-llama/Llama-3.2-3B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load the original (quantized) model first; its modules are reused below.
original_model = LlamaForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16
)

# Create the iRoPE model skeleton with the same config.
# NOTE(review): this allocates a full, randomly initialised copy of the model in the
# default dtype on the CPU before its submodules are replaced below. Expect a large
# transient host-memory spike.
config = original_model.config
model = LlamaWithIRoPE(
    config,
    chunk_size=chunk_size,
    alpha=alpha,
    beta=beta,
    gamma=gamma,
    scaling_type=scaling_type
)

# Share (rather than copy) the pretrained submodules with the iRoPE model.
# load_state_dict() cannot be used here. The 4-bit bitsandbytes layers loaded above keep
# packed quantized weights plus quantization-state entries in their state dicts, and
# these do not fit the skeleton's plain nn.Linear parameters. Reusing the module objects
# keeps the quantized weights together with the forward logic that dequantizes them.
for layer_idx, irope_layer in enumerate(model.model.layers):
    original_layer = original_model.model.layers[layer_idx]
    for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        setattr(irope_layer.attn, proj_name, getattr(original_layer.self_attn, proj_name))
    irope_layer.mlp = original_layer.mlp
    irope_layer.input_layernorm = original_layer.input_layernorm
    irope_layer.post_attention_layernorm = original_layer.post_attention_layernorm

# Token embeddings, final norm and LM head must be pretrained too. Without the
# embeddings the model would operate on random input vectors.
model.model.embed_tokens = original_model.model.embed_tokens
model.model.norm = original_model.model.norm
model.lm_head = original_model.lm_head

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Set pad token to EOS token

# Move the model to the correct device and switch to inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

# Step 9: Simulate 10M token context by processing in chunks
def create_causal_attention_mask(batch_size, seq_length, device):
    """Build an additive causal attention mask.

    Entry ``[b, 0, i, j]`` is 0.0 when query ``i`` may attend to key ``j`` (``j <= i``)
    and ``-inf`` otherwise, so the mask can be added directly to attention scores.

    Returns:
        A float tensor of shape [batch_size, 1, seq_length, seq_length]. It is an
        expanded view, so all batch entries share storage.
    """
    strictly_upper = torch.full((seq_length, seq_length), float('-inf'), device=device).triu(diagonal=1)
    return strictly_upper[None, None, :, :].expand(batch_size, 1, seq_length, seq_length)

def process_long_context(input_text, model, tokenizer, max_seq_len, simulated_context_length, process_all_chunks=False):
    """Greedily predict one token after a prompt tiled to a very long length.

    The prompt is tokenized and repeated ``simulated_context_length // prompt_len`` times
    (at least once). The tiled sequence is cut into windows of ``max_seq_len`` tokens,
    and the final hidden state of the last window predicts the next token.

    NOTE: windows are independent forward passes, and no KV cache or memory is carried
    between them. Only the final window can therefore influence the prediction, so earlier
    windows are skipped unless ``process_all_chunks`` is True (e.g. for timing runs). The
    prediction is the same either way.

    Args:
        input_text: prompt string.
        model: model exposing ``.model``, ``.lm_head`` and ``.device`` (e.g. LlamaWithIRoPE).
        tokenizer: matching Hugging Face tokenizer.
        max_seq_len: tokens per forward pass.
        simulated_context_length: target length of the tiled token sequence.
        process_all_chunks: run every window instead of only the last one.

    Returns:
        The decoded original prompt followed by the single predicted token.

    Raises:
        ValueError: if the prompt tokenizes to zero tokens.
    """
    device = model.device
    encoded = tokenizer(input_text, return_tensors="pt", truncation=False)
    prompt_ids = encoded["input_ids"].to(device)
    prompt_mask = encoded["attention_mask"].to(device)
    prompt_len = prompt_ids.shape[1]
    if prompt_len == 0:
        raise ValueError("input_text produced no tokens")

    # Simulate a long context by repeating the prompt (for demonstration).
    num_repeats = max(1, simulated_context_length // prompt_len)
    input_ids = prompt_ids.repeat(1, num_repeats)
    attention_mask = prompt_mask.repeat(1, num_repeats)
    total_length = input_ids.shape[1]

    window_starts = range(0, total_length, max_seq_len)
    if not process_all_chunks:
        window_starts = window_starts[-1:]

    last_hidden = None
    with torch.no_grad():
        for start in window_starts:
            end = min(start + max_seq_len, total_length)
            chunk_ids = input_ids[:, start:end]
            batch_size, chunk_len = chunk_ids.shape
            causal_mask = create_causal_attention_mask(batch_size, chunk_len, chunk_ids.device)

            # Put -inf on padded key positions. masked_fill is used because the arithmetic
            # form (1 - mask) * -inf gives 0 * -inf = NaN for every real token.
            key_is_padding = (attention_mask[:, start:end] == 0)[:, None, None, :]
            combined_mask = causal_mask.masked_fill(key_is_padding, float("-inf"))

            # use_cache=False: windows are independent and the custom layers keep no cache.
            outputs = model.model(
                input_ids=chunk_ids,
                attention_mask=combined_mask,
                use_cache=False,
            )
            # Keep only the final position. Storing every window's hidden states, or computing
            # logits for every position, would need memory proportional to the full context.
            last_hidden = outputs.last_hidden_state[:, -1:, :]

        logits = model.lm_head(last_hidden)  # [B, 1, vocab]
        predicted_id = logits[:, -1, :].argmax(dim=-1)  # [B]

    # Append the prediction to the (un-tiled) prompt rather than overwriting its last token.
    output_ids = torch.cat([prompt_ids, predicted_id.unsqueeze(1)], dim=1)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Step 10: Create a simple UI with ipywidgets
input_box = widgets.Textarea(
    value="Tell me a story about a futuristic city.",
    placeholder="Type your input here...",
    description="Input:",
    layout={'width': '500px', 'height': '100px'}
)

output_box = widgets.Output()

button = widgets.Button(description="Generate", button_style="primary")

def on_button_clicked(b):
    """Run process_long_context on the textarea contents and show the result.

    The button is disabled while a run is in progress, so the user cannot start another
    long-context run from the UI meanwhile. It is re-enabled even if the run raises.
    """
    b.disabled = True
    try:
        with output_box:
            clear_output()
            print(f"Processing... (Simulating {simulated_context_length:,} token context)")
            output_text = process_long_context(
                input_box.value, model, tokenizer, max_seq_len, simulated_context_length
            )
            print("Output:")
            print(output_text)
    finally:
        b.disabled = False

button.on_click(on_button_clicked)

# Display the UI
display(input_box, button, output_box)