In [None]:
import time
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import *
from transformers.cache_utils import *
from fasta import * 
from typing import *


class ModifiedLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with modifications.

    The query/key/value/output projections are sized from the supplied
    ``LlamaConfig``; the attention weights themselves are produced by
    ``fasta_attn`` rather than being computed inline.

    NOTE(review): ``fasta_attn`` is defined outside this file. This class assumes
    it returns attention weights that can be laid out as
    ``(batch * heads, query_len, key_len)`` -- confirm against the ``fasta``
    package. ``attention_mask``, ``self.scaling`` and ``self.attention_dropout``
    are not passed to ``fasta_attn``; whether masking, scaling and softmax are
    handled inside it cannot be determined from this file.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Build the four projection layers described by ``config``.

        Args:
            config: Model configuration; ``hidden_size``, ``num_attention_heads``,
                ``num_key_value_heads``, ``attention_dropout`` and
                ``attention_bias`` are read (``head_dim`` is used when present).
            layer_idx: Index of the decoder layer this module belongs to; used
                to address this layer's entry in the key/value cache.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to the conventional hidden_size / num_heads when the config
        # does not carry an explicit head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def _merge_with_cache(self, cache, key_states, value_states, cos, sin, cache_position):
        """Store the new key/value states in ``cache`` and return the full sequences.

        Two cache flavours are accepted:

        * objects exposing ``update(key, value, layer_idx, cache_kwargs)``
          (the Hugging Face ``Cache`` interface), and
        * objects exposing ``get_key`` / ``get_value`` (plus ``add``), which is
          the interface this module originally required.

        Raises:
            AttributeError: If ``cache`` offers neither interface.
        """
        if hasattr(cache, "update"):
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            return cache.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if hasattr(cache, "get_key") and hasattr(cache, "get_value"):
            previous_key = cache.get_key(self.layer_idx)
            previous_value = cache.get_value(self.layer_idx)
            if previous_key is not None:
                key_states = torch.cat([previous_key, key_states], dim=2)
            if previous_value is not None:
                value_states = torch.cat([previous_value, value_states], dim=2)
            # The keys are stored in the same (.., seq, head_dim) layout they
            # are read back in, so the concatenation above stays valid on the
            # next call.
            # NOTE(review): assumes add() replaces the stored tensors rather
            # than appending to them -- confirm against the cache implementation.
            cache.add(key_states, value_states)
            return key_states, value_states

        raise AttributeError(
            "Cache object lacks required methods: expected 'update', or 'get_key' and 'get_value'."
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional["Cache"] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Input whose last dimension is ``config.hidden_size``;
                assumed to be ``(batch, seq_len, hidden_size)`` -- TODO confirm.
            position_embeddings: ``(cos, sin)`` pair handed to
                ``apply_rotary_pos_emb``.
            attention_mask: Accepted for interface compatibility; not used here.
            past_key_value: Optional key/value cache; updated in place when given.
            cache_position: Forwarded to the cache's ``update`` call.
            **kwargs: Ignored.

        Returns:
            ``(attn_output, None, past_key_value)`` -- the projected attention
            output, no attention weights, and the cache that was passed in
            (``None`` when no cache was supplied).

        Raises:
            AttributeError: If ``past_key_value`` has no supported cache interface.
            RuntimeError: If ``fasta_attn`` fails; the original error is chained.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads and move the head axis in front of the
        # sequence axis: (..., seq, heads, head_dim) -> (..., heads, seq, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            key_states, value_states = self._merge_with_cache(
                past_key_value, key_states, value_states, cos, sin, cache_position
            )

        # With grouped-query attention there are fewer key/value heads than
        # query heads; repeat each key/value head so the head axes line up
        # before batch and head dimensions are flattened together.
        attn_keys, attn_values = key_states, value_states
        if self.num_key_value_groups > 1:
            attn_keys = attn_keys.repeat_interleave(self.num_key_value_groups, dim=1)
            attn_values = attn_values.repeat_interleave(self.num_key_value_groups, dim=1)

        # Use fasta_attn for efficient computation
        try:
            attention_weights = fasta_attn(
                query_states.flatten(0, 1),  # Combine batch and head dimensions
                attn_keys.transpose(-1, -2).flatten(0, 1),
                block_size=128,
            )
        except Exception as e:
            raise RuntimeError(f"Error in fasta_attn: {e}") from e

        # Restore the separate batch and head axes: (batch, heads, q_len, kv_len).
        batch_size, num_heads, query_len = query_states.shape[:3]
        attention_weights = attention_weights.reshape(batch_size, num_heads, query_len, -1)

        # Weighted sum of value states: contract the key axis ``s`` of the
        # weights against the sequence axis of the values.
        attn_output = torch.einsum("bhls,bhsd->bhld", attention_weights, attn_values)

        # Merge heads and project back to hidden size
        attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


# Replace LlamaAttention with ModifiedLlamaAttention in the model
def replace_attention_modules(model):
    """Swap every attention module in ``model`` for a ``ModifiedLlamaAttention``.

    A submodule is treated as an attention module when it has a ``q_proj``
    attribute. Each replacement is built from the original module's ``config``
    and ``layer_idx``, placed on the original's device and dtype, put in the
    same train/eval mode, and loaded with the original's weights so the
    pretrained projections are preserved.

    Args:
        model: The ``nn.Module`` whose attention submodules are replaced in place.

    Raises:
        RuntimeError: From ``load_state_dict`` if the original module's
            parameters do not match those of ``ModifiedLlamaAttention``.
    """
    # Collect the targets first so the module tree is not mutated while it is
    # being traversed. The root (empty name) has no parent to attach to.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name and hasattr(module, "q_proj")
    ]

    for name, module in targets:
        # rpartition yields an empty parent name for direct children of the
        # root, in which case the parent is the model itself.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model

        replacement = ModifiedLlamaAttention(module.config, module.layer_idx)
        reference = next(module.parameters(), None)
        if reference is not None:
            replacement.to(device=reference.device, dtype=reference.dtype)
        # Carry the pretrained weights over instead of keeping random init.
        replacement.load_state_dict(module.state_dict())
        replacement.train(module.training)

        setattr(parent, child_name, replacement)

# Function to compute perplexity
def compute_perplexity(model, tokenizer, text, device):
    """Return the perplexity of ``model`` on ``text`` as a Python float.

    The text is tokenized into PyTorch tensors, both the tensors and the model
    are moved to ``device``, and the model is run without gradient tracking
    using the input ids as labels. The perplexity is the exponential of the
    loss the model reports.
    """
    encoded = tokenizer(text, return_tensors="pt").to(device)
    model.to(device)

    with torch.no_grad():
        result = model(**encoded, labels=encoded["input_ids"])

    return torch.exp(result.loss).item()

# Load the tokenizer and the reference model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Load a second, independent copy and swap in the modified attention modules
modified_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
replace_attention_modules(modified_model)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move both models to the target device up front so the weight transfer that
# compute_perplexity would otherwise trigger is not included in the timings.
model.to(device)
modified_model.to(device)

# Sample text
text = "The quick brown fox jumps over the lazy dog."


def _timed_perplexity(candidate_model):
    """Return ``(perplexity, seconds)`` for one timed run of ``candidate_model``.

    An untimed warm-up run is performed first so one-time start-up costs are
    not attributed to whichever model happens to be measured first. On CUDA the
    device is synchronized around the timed region because kernels are launched
    asynchronously.
    """
    compute_perplexity(candidate_model, tokenizer, text, device)  # warm-up

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    started = time.perf_counter()
    value = compute_perplexity(candidate_model, tokenizer, text, device)
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return value, time.perf_counter() - started


# Measure original model performance
original_perplexity, original_time = _timed_perplexity(model)

# Measure modified model performance
modified_perplexity, modified_time = _timed_perplexity(modified_model)

# Print results
print(f"Original Perplexity: {original_perplexity}, Time: {original_time:.4f}s")
print(f"Modified Perplexity: {modified_perplexity}, Time: {modified_time:.4f}s")

