In [7]:
# Initialize the secret for HuggingFace login
import os
try:
    from google.colab import userdata
    # We are in colab, so we should access it from userdata.get(...)
    assert userdata.get('HF_TOKEN'), 'Set up HuggingFace login secret properly in Colab!'
    print('Found HF_TOKEN in Colab secrets')
except ModuleNotFoundError:
    # Not in colab, so we have to setup the token manually reading from a file
    if os.getenv('HF_TOKEN'):
        print('Found HF_TOKEN in environment variables')
    else:
        # Read it from a file
        hf_token_file = '.hf_token'
        assert os.path.exists(hf_token_file), f'You must create a file in this working directory ({os.getcwd()}) called {hf_token_file}, containing the Huggingface personal secret access token'
        with open(hf_token_file, 'r') as f:
            os.environ['HF_TOKEN'] = f.read().strip()
            print('Found HF_TOKEN in file')

Found HF_TOKEN in Colab secrets


In [8]:
from dataclasses import dataclass
from collections.abc import Iterator
import seaborn as sns

In [9]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-3.2-1B-Instruct"

In [10]:
class LlamaPrompt:
  user_prompt: str
  system_prompt: str

  def __init__(self, user_prompt, system_prompt="You are a helpful AI assistant."):
    self.user_prompt = user_prompt
    self.system_prompt = system_prompt

  def __str__(self) -> str:
      # From: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#-instruct-model-prompt-
      return ''.join([
          "<|begin_of_text|>",
          f"<|start_header_id|>system<|end_header_id|>{self.system_prompt}<|eot_id|>",
          f"<|start_header_id|>user<|end_header_id|>{self.user_prompt}<|eot_id|>",
          "<|start_header_id|>assistant<|end_header_id|>"
      ])

In [11]:
@dataclass
class LlamaResponse:
    prompt: LlamaPrompt
    response: str

In [12]:
class LlamaInstruct:
    def __init__(self, model_name: str, model_args: dict = None, tokenizer_args: dict = None, pad_token: str = None):
        self.model_name = model_name
        self.model_args = model_args if model_args is not None else dict()
        self.tokenizer_args = tokenizer_args if tokenizer_args is not None else dict()

        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto", **self.model_args)
        self.model.eval()
        self.device = self.model.device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', **self.tokenizer_args)
        self.pad_token = self.tokenizer.eos_token if pad_token is None else pad_token
        self.tokenizer.pad_token = self.pad_token

        self.assistant_header = self.tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", return_tensors="pt").to(self.device)


    def tokenize(self, prompts: str | LlamaPrompt | list[str | LlamaPrompt]) -> tuple[dict, list[LlamaPrompt]]:
        # Make prompts a list anyway
        if not isinstance(prompts, list):
            prompts = [ prompts ]

        # Convert all prompts to LlamaPrompt
        prompts = [ prompt if isinstance(prompt, LlamaPrompt) else LlamaPrompt(prompt) for prompt in prompts ]

        inputs = self.tokenizer([
            str(prompt) for prompt in prompts
        ], truncation=True, padding=True, return_tensors="pt").to(self.device)

        return inputs, prompts


    def generate(self, inputs: dict, generate_args: dict = None) -> Iterator[LlamaResponse]:
        generate_args = generate_args if generate_args is not None else dict()
        default_args = {
            "max_length": 100,
            "num_return_sequences": 1,
            "temperature": 0.1,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Overwrite default_args with generate_args
        default_args.update(generate_args)

        return self.model.generate(
            **inputs,
            **default_args,
        )


    def extract_responses(self, input_ids: torch.Tensor, outputs: torch.Tensor, prompts: list[LlamaPrompt]) -> Iterator[LlamaResponse]:
        for input, output, prompt in zip(input_ids, outputs, prompts):
            # Remove the prompt from the output generated
            output = output[len(input):]

            # Remove another assistant_header, if present
            if torch.equal(output[:len(self.assistant_header)], self.assistant_header):
                output = output[len(self.assistant_header):]

            generated = self.tokenizer.decode(output, skip_special_tokens=True).strip()

            yield LlamaResponse(prompt, generated)


    def run(self, prompts: str | LlamaPrompt | list[str | LlamaPrompt], verbose: bool = False) -> Iterator[LlamaResponse]:
        # Optional logging function
        def _print(*args, **kwargs):
            if verbose:
                print(*args, **kwargs)

        inputs, prompts = self.tokenize(prompts)

        _print('Tokenized inputs:', inputs.input_ids.shape)
        _print('Last tokens:', inputs.input_ids[:, -1])

        outputs = self.generate(inputs)
        _print('Generated outputs:', outputs.shape)

        return self.extract_responses(inputs.input_ids, outputs, prompts)


    def _get_model_num_heads(self) -> int:
        return self.model.config.num_attention_heads

    def _get_model_hidden_layers(self) -> int:
        return self.model.config.num_hidden_layers


In [13]:
llama = LlamaInstruct(model_name, model_args={"attn_implementation": "eager"})
llama.device

device(type='cpu')

In [14]:
def pretty_print_output(output: LlamaResponse):
    print(f"\n\n==================================")
    print(output.prompt.user_prompt)
    print("=============")
    print(output.response)

In [15]:
# Sample prompts
prompts = [
    "Explain quantum computing in simple terms.",
    "Write a short poem about artificial intelligence.",
    "Describe the future of renewable energy.",
    "Discuss the impact of machine learning on healthcare."
]

for output in llama.run(prompts, verbose=True):
    pretty_print_output(output)

Tokenized inputs: torch.Size([4, 29])
Last tokens: tensor([128007, 128007, 128007, 128007])


KeyboardInterrupt: 

In [None]:
list(llama.model.children())[0].layers

In [None]:
class LlamaAttentionExtractor:
    def __init__(self, llama: LlamaInstruct):
        """
        Initialize Llama model with hooks to capture attention maps
        """
        self.llama = llama


    def extract_attention_maps(self, prompt, max_length=64, num_return_sequences=5):
        print('Max len:', max_length)
        print('Return sequences:', num_return_sequences)

        # Get the number of head
        print('Llama num_heads:', self.llama._get_model_num_heads())
        print('Llama hidden_layers:', self.llama._get_model_hidden_layers())

        inputs, _ = self.llama.tokenize(prompt)
        print(f'Inputs: {inputs.input_ids.shape}')
        outputs = self.llama.generate(
            inputs,
            generate_args={
                "max_length": max_length,
                "num_return_sequences": num_return_sequences,
                "output_attentions": True,
                "return_dict_in_generate": True,
            }
        )

        return outputs.attentions

# Jupyter Notebook Usage
extractor = LlamaAttentionExtractor(llama)
prompt = "Hello, how are you?"
attentions = extractor.extract_attention_maps(prompt)

for i, attn in enumerate(attentions):
    print(f'attentions[{i:>2d}][{len(attn)} items]: {type(attn[0])}[{attn[0].shape}]')

In [None]:
# Show the attentions on the input, for the first input sequence only, for the last hidden_layer
input_attn = attentions[0][0, :, :, :]
print(f'input_attn: {type(input_attn[0])}[{input_attn[0].shape}]')


# num_layers = len(attention_maps)
# print('Found maps:', num_layers)
# plt.figure(figsize=(15, 4 * num_layers))

# for i, attn_map in enumerate(attention_maps):
#     plt.subplot(num_layers, 1, i+1)
#     avg_attn = attn_map.mean(dim=0).numpy()
#     sns.heatmap(avg_attn, cmap='viridis', cbar=True)
#     plt.title(f'Layer {i+1} Attention Map')

# plt.tight_layout()
# plt.show()