<a href="https://colab.research.google.com/github/pwilczewski/twostep/blob/main/trouble_with_next_tokens.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install transformers accelerate bitsandbytes

from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast, BitsAndBytesConfig

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m50.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

This notebook uses the Llama 3.1 8B model, loaded in 8-bit, to demonstrate issues with next-token prediction.

In [3]:
from google.colab import userdata

# Single source of truth for the checkpoint id and the access token, so the
# model and the tokenizer are guaranteed to be loaded from the same repository.
model_name = "meta-llama/Llama-3.1-8B"
hf_token = userdata.get('HF_TOKEN')

# Load the weights in 8-bit through bitsandbytes.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token=hf_token, quantization_config=bnb_config)
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, token=hf_token)

# Reuse the end-of-sequence token for padding and register the same id with the
# model's generation config so generate() pads with it.
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

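Define a function that generates a continuation for each context in a list and returns the predictions together with the scores of the candidate sequences.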

In [70]:
import numpy as np
import torch

def generate_text(model, tokenizer, contexts, temperature=1.0, num_beams=1, num_tokens=10, do_sample=False):
    """Generate a continuation for every context and score the returned candidates.

    Each context is tokenized and generated on its own. ``num_beams`` candidate
    sequences are requested per context; the candidate with the highest
    length-averaged score is decoded and reported as the prediction.

    Args:
        model: Object exposing ``generate``, ``compute_transition_scores`` and
            ``device`` (e.g. a Hugging Face causal language model).
        tokenizer: Callable tokenizer with a ``decode`` method.
        contexts: Iterable of prompt strings.
        temperature: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``; also used as the number of
            sequences returned per context.
        num_tokens: Upper bound on newly generated tokens (``max_new_tokens``).
        do_sample: Forwarded to ``model.generate``.

    Returns:
        A ``(predictions, scores)`` tuple. ``predictions`` holds one decoded
        string per context. ``scores`` holds, per context, the list of
        per-candidate average token scores rounded to two decimals.
    """
    predictions = []
    scores = []

    # Settings are passed per call instead of being written onto
    # model.generation_config, so one call cannot silently change the decoding
    # behaviour of later calls (or of other code sharing the model).
    # Unused raw logits are no longer requested from generate().
    generation_kwargs = dict(
        max_new_tokens=num_tokens,
        do_sample=do_sample,
        temperature=temperature,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        return_dict_in_generate=True,
        output_scores=True,
    )
    use_beams = num_beams > 1

    for text in contexts:
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
        outputs = model.generate(**inputs, **generation_kwargs)

        if use_beams:
            # Beam-search scores are left unnormalized, as in the original call;
            # beam_indices is needed to follow each returned hypothesis.
            transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False)
            # Positions padded after a hypothesis finished carry a negative
            # beam index; only real generation steps are counted.
            step_counts = (outputs.beam_indices >= 0).sum(dim=-1)
        else:
            # Without beams the stored scores are unnormalized logits, so they
            # are normalized here to make them log-probabilities.
            transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
            # A single sequence is generated, so every stored step belongs to it.
            step_counts = transition_scores.new_full((transition_scores.shape[0],), transition_scores.shape[-1])

        # Average over the tokens actually generated (generation may stop at
        # EOS before num_tokens); clamp guards against a zero-length division.
        step_counts = step_counts.clamp(min=1)
        reconstructed_scores = torch.sum(transition_scores, dim=-1) / step_counts

        top_sequence = torch.argmax(reconstructed_scores, dim=-1).item()
        generated_text = tokenizer.decode(outputs.sequences[top_sequence], skip_special_tokens=True)
        predictions.append(generated_text)
        scores.append([round(s, 2) for s in reconstructed_scores.tolist()])

    return predictions, scores


**Test #1:** predictions with greedy decoding (no sampling, no beam search)

In [71]:
# Prompts shared by the experiments in this notebook.
contexts = [
    "The capital of France is",
    "The capital of Germany is",
    "The capital of Spain is",
    "The capital of Ireland is",
]

# Sampling is switched off and num_beams keeps its default of 1.
predictions, scores = generate_text(model, tokenizer, contexts, do_sample=False)

# One completed prompt per line.
print(*predictions, sep='\n')



The capital of France is Paris. The capital of France is Paris. The
The capital of Germany is Berlin. It is the largest city in the country
The capital of Spain is Madrid, and the capital of the Basque Country
The capital of Ireland is Dublin, and the country is divided into four provinces


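**Test #2:** predictions sampled with a low temperature (0.1) and 8 beams; the scores of the 8 returned beams are printed below the predictions, one line per prompt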

In [74]:
# This experiment samples (do_sample=True); seeding torch's global RNG first
# is intended to make the drawn beams repeatable from one run to the next.
torch.manual_seed(0)

predictions, scores = generate_text(model, tokenizer, contexts, temperature=0.1, num_beams=8, do_sample=True)

# Best prediction per prompt, one per line.
print('\n'.join(predictions))
# Then one line per prompt listing the scores of its returned beams.
print('\n'.join(', '.join(str(s) for s in beam_scores) for beam_scores in scores))

The capital of France is Paris, which is located in the north of the
The capital of Germany is Berlin. It is the largest city in the country
The capital of Spain is Madrid. It is located in the center of the
The capital of Ireland is Dublin, which is located on the east coast of
-8.65, -8.74, -8.77, -8.77, -8.8, -8.81, -8.86, -9.0
-7.46, -7.69, -7.73, -7.92, -7.93, -8.52, -8.91, -9.02
-7.59, -7.61, -7.69, -8.03, -8.42, -8.64, -8.81, -9.08
-6.79, -7.1, -7.16, -7.73, -8.04, -8.09, -8.38, -8.48
