<a href="https://colab.research.google.com/github/CodeAlgorilla/SpeedyInference/blob/feature%2Fcolab_notebook/OutputObserver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Created by: [Mostafa Elhoushi](https://huggingface.co/melhoushi)

## Install requirements

First, run the cells below to install the requirements:

In [None]:
!pip install git+https://github.com/huggingface/transformers.git
!pip install accelerate

In [None]:
# Authenticate with the Hugging Face Hub. Presumably required because the
# meta-llama/* checkpoints used below are gated -- confirm access on the Hub.
!huggingface-cli login

Now, let's import some libraries and classes that we will need in this demo.

In [None]:
import gc
import time

import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Allow up to 500 characters of decoded text per DataFrame cell before truncating.
pd.set_option('display.max_colwidth', 500)

# Models and tokenizers start unloaded; the loading cells below populate them.
orig_model = orig_tokenizer = None
layerskip_model = layerskip_tokenizer = None

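Define a helper that, for every layer, decodes the token the model would predict at each position if it exited at that layer (by applying the LM head to that layer's hidden states):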

In [None]:
def get_early_exit_predictions(hidden_states, lm_head, tokenizer, placeholder=""):
  """Decode the greedy token each layer would predict at every position.

  For every layer output, applies `lm_head` to that layer's hidden states at
  every generation step, takes the argmax token at each position and decodes
  the resulting token ids with `tokenizer.batch_decode`.

  Args:
    hidden_states: the `hidden_states` of `generate(..., output_hidden_states=True)`:
      one tuple per generation step, each holding one tensor per layer output
      (index 0 is the embedding output). Step 0 covers the whole prompt, later
      steps a single token. Only batch element 0 is decoded.
    lm_head: callable mapping a hidden-state tensor to vocabulary logits.
    tokenizer: object whose `batch_decode` turns a 1-D tensor of ids into a
      list of strings (one per id).
    placeholder: text for column 0, which has no prediction because nothing
      precedes the first prompt token. It keeps every row aligned so that entry
      `j` is the layer's prediction for token `j` of the full sequence.

  Returns:
    dict mapping layer index -> list of decoded tokens (length
    prompt_len + num_steps); an empty dict if `hidden_states` is empty.
  """
  layer2text = dict()
  if len(hidden_states) == 0:
    return layer2text

  num_layer_outputs = len(hidden_states[0])
  # NOTE(review): intermediate hidden states go straight into `lm_head` without
  # the model's final norm layer; confirm whether early-exit predictions
  # should apply it first.
  # No gradients are needed; avoid building an autograd graph through lm_head.
  with torch.no_grad():
    for layer_idx in range(num_layer_outputs):
      step_preds = []
      for step_hidden in hidden_states:
        logits = lm_head(step_hidden[layer_idx])
        # softmax is monotonic, so argmax over the raw logits selects the same
        # token without the extra (fp16-rounded) softmax pass.
        # Move to CPU so steps living on different devices (device_map="auto")
        # can be concatenated.
        step_preds.append(torch.argmax(logits, dim=-1).cpu())
      pred_ids = torch.cat(step_preds, dim=-1)
      layer2text[layer_idx] = [placeholder] + tokenizer.batch_decode(pred_ids[0])
  return layer2text

And let's set some default generation configuration for this demo:

In [None]:
# Greedy decoding settings shared by both models: sampling off, with
# temperature/top_p explicitly unset.
generation_config = dict(
    do_sample=False,
    temperature=None,
    top_p=None,
    max_new_tokens=12,  # tokens generated after the prompt
)

Choose the original model and the matching LayerSkip model (the comments list other available pairs in the same order):

In [None]:
# Pick a matching (original, LayerSkip) checkpoint pair. Other available pairs:
#   meta-llama/Llama-2-7b-hf    <-> facebook/layerskip-llama2-7B
#   meta-llama/Llama-2-13b-hf   <-> facebook/layerskip-llama2-13B
#   meta-llama/Meta-Llama-3-8B  <-> facebook/layerskip-llama3-8B
#   meta-llama/Llama-3.2-1B     <-> facebook/layerskip-llama3.2-1B
#   meta-llama/CodeLlama-7b-hf  <-> facebook/layerskip-codellama-7B
orig_checkpoint = "meta-llama/Llama-3.2-1B"
layerskip_checkpoint = "facebook/layerskip-llama3.2-1B"

Load the original model:

In [None]:
# Load the original checkpoint in fp16; `device_map="auto"` places the weights
# on the available device(s).
orig_model = AutoModelForCausalLM.from_pretrained(
    orig_checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
)
orig_tokenizer = AutoTokenizer.from_pretrained(orig_checkpoint)

# Use the EOS token id as the padding id during generation
# (presumably because the tokenizer defines no pad token -- confirm).
orig_model.generation_config.pad_token_id = orig_tokenizer.eos_token_id

In [None]:
# Print the module structure of the original model.
print(orig_model)


Load the LayerSkip model:

In [None]:
# Load the LayerSkip checkpoint with the same settings as the original model:
# fp16 weights, placed on the available device(s) by `device_map="auto"`.
layerskip_model = AutoModelForCausalLM.from_pretrained(
    layerskip_checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
)
layerskip_tokenizer = AutoTokenizer.from_pretrained(layerskip_checkpoint)

# Use the EOS token id as the padding id during generation
# (presumably because the tokenizer defines no pad token -- confirm).
layerskip_model.generation_config.pad_token_id = layerskip_tokenizer.eos_token_id

In [None]:
# Print the module structure of the LayerSkip model.
print(layerskip_model)

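Delete the loaded models to free GPU memory. **Do not run this cell as part of "Run All"**: the cells below still need the models. Run it only when you are finished with them, or before loading a different checkpoint.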

In [None]:
# Free the memory held by both models and their tokenizers.
# WARNING: the generation cells below need these objects. Run this cell only
# when you are done with them, not as part of a top-to-bottom "Run All".
# Rebinding to None drops this notebook's references; tensors still referenced
# elsewhere (e.g. generation outputs holding hidden states) stay alive.
if orig_model is not None:
  orig_model = None
  orig_tokenizer = None

if layerskip_model is not None:
  layerskip_model = None
  layerskip_tokenizer = None

# Collect reference cycles first so their tensors are actually released,
# then return cached, now-unused CUDA memory blocks.
gc.collect()
torch.cuda.empty_cache()

Let's create a prompt:

In [None]:
# Prompt fed to both models. `code_prompt` is not used by the visible cells
# (presumably a slot for a code-completion prompt).
prompt, code_prompt = "Once upon a time", ""

Generate tokens using the original model:

In [None]:
# Greedy generation with the original model, keeping every layer's hidden
# states for the early-exit analysis below.
orig_inputs = orig_tokenizer(prompt, return_tensors="pt").to(orig_model.device)

orig_outputs = orig_model.generate(
    **orig_inputs,
    **generation_config,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

# Decode the full sequence (prompt + generated tokens) back into text.
orig_text = orig_tokenizer.decode(orig_outputs.sequences[0], skip_special_tokens=True)
print(orig_text)

In [None]:

# Sanity-check the layout of the hidden states returned by `generate`:
# hidden_states[step][layer] is a tensor; step 0 covers the prompt and every
# later step a single newly generated token.
hidden_states = orig_outputs["hidden_states"]
print("hidden_states:")

print(f"  type(hidden_states): {type(hidden_states)}")
print(f"  len(hidden_states): {len(hidden_states)}")

print(f"    type(hidden_states[0]): {type(hidden_states[0])}")
print(f"    len(hidden_states[0]): {len(hidden_states[0])}")

print(f"      type(hidden_states[0][0]): {type(hidden_states[0][0])}")
print(f"      hidden_states[0][0].shape: {hidden_states[0][0].shape}")
# Step 1 only exists when more than one token was generated.
if len(hidden_states) > 1:
  print(f"      hidden_states[1][0].shape: {hidden_states[1][0].shape}")

batch_size, input_seq_len = orig_inputs["input_ids"].shape
# `sequences` (prompt + generated tokens) must have the same batch size as the input.
assert orig_outputs["sequences"].shape[0] == batch_size
total_seq_len = orig_outputs["sequences"].shape[1]

prompt_len = input_seq_len
# Derived from the output, since generation may stop before max_new_tokens.
num_steps = total_seq_len - input_seq_len
emb_dim = orig_model.config.hidden_size
# Read from the config rather than `model.model.layers` so other model layouts work too.
num_layers = orig_model.config.num_hidden_layers

print(f"batch_size: {batch_size}\n"
      f"prompt_len: {prompt_len}\n"
      f"num_steps: {num_steps}\n"
      f"emb_dim: {emb_dim}\n"
      f"num_layers: {num_layers}")

assert len(hidden_states) == num_steps
assert len(hidden_states[0]) == num_layers + 1  # add 1 to count embedding layer
# Tensors of step 0 process prompt
assert hidden_states[0][0].shape == (batch_size, prompt_len, emb_dim)
# Tensors of each remaining step process a single token
if num_steps > 1:
  assert hidden_states[1][0].shape == (batch_size, 1, emb_dim)

Now, let's print a table that shows the full predicted text when exiting at each layer of the original model:

In [None]:
# Early-exit predictions of the original model: one row per layer output
# (row 0 = embedding output), one column per token position.
batch_size, input_seq_len = orig_inputs["input_ids"].shape
batch_size, total_seq_len = orig_outputs["sequences"].shape

orig_layer_2_text = get_early_exit_predictions(
    orig_outputs["hidden_states"],
    orig_model.lm_head,
    orig_tokenizer
)

orig_df = pd.DataFrame.from_dict(orig_layer_2_text, orient="index", columns=np.arange(total_seq_len))
# Escape text for the LaTeX export (`to_latex(escape=False)`): newlines become
# "\textbackslash n" and "#" becomes "\#". With regex=True the values are regex
# *replacement templates*, where "\t" would turn into a TAB character, so the
# backslashes must be doubled.
orig_df = orig_df.replace({"\n": r"\\textbackslash n", "#": r"\\#"}, regex=True)


In [None]:
# Display the per-layer early-exit table of the original model.
orig_df

Export the table to LaTeX:

In [None]:
# Export the table to LaTeX. escape=False keeps the manual LaTeX escapes already
# applied to `orig_df`; write as UTF-8 since decoded tokens may be non-ASCII.
orig_df_latex = orig_df.to_latex(escape=False)
with open('original_output_table.tex', 'w', encoding='utf-8') as f:
    f.write(orig_df_latex)

Show the full predicted text when exiting at each layer of the LayerSkip model:

In [None]:
# Greedy generation with the LayerSkip model (same prompt and settings as the
# original model), then decode every layer's early-exit predictions.
layerskip_inputs = layerskip_tokenizer(prompt, return_tensors="pt").to(layerskip_model.device)

layerskip_outputs = layerskip_model.generate(
    **layerskip_inputs,
    **generation_config,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

layerskip_layer_2_text = get_early_exit_predictions(
    layerskip_outputs.hidden_states,
    layerskip_model.lm_head,
    layerskip_tokenizer,
)

# Full text: prompt + generated tokens.
layerskip_text = layerskip_tokenizer.decode(layerskip_outputs.sequences[0], skip_special_tokens=True)
print(layerskip_text)

In [None]:
# One row per layer output, one column per token position of the full sequence.
batch_size, input_seq_len = layerskip_inputs["input_ids"].shape
batch_size, total_seq_len = layerskip_outputs["sequences"].shape

layerskip_df = pd.DataFrame.from_dict(
    layerskip_layer_2_text,
    orient="index",
    columns=np.arange(total_seq_len),
)

Display the LayerSkip table, then export it to LaTeX:

In [None]:
# Display the per-layer early-exit table of the LayerSkip model.
layerskip_df

In [None]:
# Escape text for the LaTeX export (`to_latex(escape=False)`): newlines become
# "\textbackslash n" and "#" becomes "\#". With regex=True the values are regex
# replacement templates, where "\t" would turn into a TAB, so backslashes are
# doubled. The lookbehind skips "#" that is already escaped, so re-running this
# cell does not double-escape.
layerskip_df = layerskip_df.replace({"\n": r"\\textbackslash n", r"(?<!\\)#": r"\\#"}, regex=True)
layer_skip_latex_code = layerskip_df.to_latex(escape=False)

# Write as UTF-8 since decoded tokens may be non-ASCII.
with open('layerskip_output_table.tex', 'w', encoding='utf-8') as f:
    f.write(layer_skip_latex_code)