In [1]:
import time
import math
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import *
from fasta import *
from typing import *

class ModifiedLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = fasta_attn(query_states, key_states.transpose(2, 3),128) / math.sqrt(self.head_dim)
        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Replace LlamaAttention with ModifiedLlamaAttention in the model
def replace_attention_modules(model):
    for name, module in model.named_modules():
        if isinstance(module, nn.Module) and hasattr(module, "q_proj"):
            # Get parent module and replace the attention module
            parent_name = name.rsplit(".", 1)[0]
            parent = dict(model.named_modules())[parent_name]

            # Instantiate ModifiedLlamaAttention and copy weights
            modified_attention = ModifiedLlamaAttention(module.config, module.layer_idx)

            # Copy weights and biases for q_proj
            modified_attention.q_proj.weight.data = module.q_proj.weight.data.clone()
            if module.q_proj.bias is not None:
                modified_attention.q_proj.bias.data = module.q_proj.bias.data.clone()

            # Copy weights and biases for k_proj
            modified_attention.k_proj.weight.data = module.k_proj.weight.data.clone()
            if module.k_proj.bias is not None:
                modified_attention.k_proj.bias.data = module.k_proj.bias.data.clone()

            # Copy weights and biases for v_proj
            modified_attention.v_proj.weight.data = module.v_proj.weight.data.clone()
            if module.v_proj.bias is not None:
                modified_attention.v_proj.bias.data = module.v_proj.bias.data.clone()

            # Copy weights and biases for o_proj
            modified_attention.o_proj.weight.data = module.o_proj.weight.data.clone()
            if module.o_proj.bias is not None:
                modified_attention.o_proj.bias.data = module.o_proj.bias.data.clone()

            # Replace the module
            setattr(parent, name.split(".")[-1], modified_attention)

            print(f"Replaced {name} with ModifiedLlamaAttention")


# Function to compute perplexity
def compute_perplexity(model, tokenizer, text, device):
    """
    Computes the perplexity of the input text and returns the model's output.

    Args:
        model: The language model.
        tokenizer: The tokenizer for the model.
        text (str): The input text.
        device (torch.device): The device to run the computation on.

    Returns:
        tuple: (perplexity, model_output_text)
    """
    # Tokenize input text and move tensors to the specified device
    inputs = tokenizer(text, return_tensors="pt").to(device)  
    model.to(device)  # Move model to the specified device
    
    with torch.no_grad():
        # Forward pass with labels for computing loss
        outputs = model(**inputs, labels=inputs["input_ids"])
    
    # Compute perplexity from loss
    loss = outputs.loss
    perplexity = torch.exp(loss)
    
    # Generate output text using the model's logits
    logits = outputs.logits  # Shape: (batch_size, sequence_length, vocab_size)
    predicted_ids = torch.argmax(logits, dim=-1)  # Get token IDs with highest probability
    model_output_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
    
    return perplexity.item(), model_output_text


# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Clone the model to create a modified version
modified_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
replace_attention_modules(modified_model)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sample text
text = open("sample_data.txt","r").read()
# Measure original model performance
start_time = time.time()
original_perplexity,output = compute_perplexity(model, tokenizer, text, device)
original_time = time.time() - start_time

# Measure modified model performance
start_time = time.time()
modified_perplexity,output_modified = compute_perplexity(modified_model, tokenizer, text, device)
modified_time = time.time() - start_time

# Print results
print(f"Original Perplexity: {original_perplexity}, Time: {original_time:.4f}s Output: {output}")
print(f"Modified Perplexity: {modified_perplexity}, Time: {modified_time:.4f}s Output: {output_modified}")



AttributeError: 'NoneType' object has no attribute 'data'

In [None]:
output_modified