# Text Generation with GPT Models

This notebook demonstrates how to generate text with pre-trained GPT models from Hugging Face. It prints the most likely next-token probabilities at every step and shows how the main sampling parameters shape the output.

## Overview
- We use Hugging Face's `transformers` library to load a pre-trained GPT-style model (GPT-2 by default)
- The notebook shows how to control text generation with the parameters temperature, top-k, and top-p
- The generation loop prints each sampled token together with the five most likely candidates and their probabilities, so you can follow the process token by token

## 1. Import Required Libraries

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

## 2. Define Text Generation Function

The function below implements a token-by-token sampling loop controlled by three parameters:
- `temperature`: Controls randomness by dividing the logits before the softmax (lower = more deterministic)
- `top_k`: Limits token selection to the k highest-probability tokens
- `top_p`: Nucleus sampling; keeps only the smallest set of most likely tokens whose cumulative probability exceeds the threshold p


In [None]:
def generate_text(
    prompt,
    model_name="gpt2",  # GPT-2 stands in for GPT-3, whose weights are not published on the Hugging Face Hub
    max_length=100,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    num_return_sequences=1
):
    """
    Generate text token by token with a causal language model from Hugging Face,
    printing the top-5 next-token probabilities at every step.

    Args:
        prompt (str): The input text to start generation from
        model_name (str): Name of the model to load from the Hugging Face Hub
        max_length (int): Maximum total length in tokens, including the prompt
        temperature (float): Controls randomness; lower is more deterministic,
            and 0 selects the most likely token at every step (greedy decoding)
        top_k (int): Number of highest-probability tokens to keep (0 disables top-k)
        top_p (float): Cumulative probability cutoff for nucleus sampling (1.0 disables top-p)
        num_return_sequences (int): Number of independent sequences to generate

    Returns:
        list[str]: The generated text sequences (prompt included)
    """
    # Load model and tokenizer
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids

    # Store the prompt length so we know how many new tokens we may add
    input_length = input_ids.shape[1]

    print(f"\nGenerating text with temperature={temperature}, top_k={top_k}, top_p={top_p}")
    print(f"Prompt: {prompt}")

    results = []
    for seq_idx in range(num_return_sequences):
        print(f"\nGeneration process (sequence {seq_idx + 1}/{num_return_sequences}):")

        # Start each sequence from the prompt
        current_ids = input_ids.clone()

        # Generate one token at a time
        for _ in range(max_length - input_length):
            with torch.no_grad():
                # Logits for the next token, shape (1, vocab_size)
                outputs = model(current_ids)
                next_token_logits = outputs.logits[:, -1, :]

                if temperature == 0:
                    # Greedy decoding: dividing by 0 is undefined, so take the argmax directly
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                else:
                    # Apply temperature
                    next_token_logits = next_token_logits / temperature

                    # Top-k filtering: mask every logit below the k-th largest one
                    if top_k > 0:
                        k = min(top_k, next_token_logits.size(-1))
                        kth_largest = torch.topk(next_token_logits, k=k)[0][..., -1, None]
                        next_token_logits[next_token_logits < kth_largest] = -float("inf")

                    # Top-p (nucleus) filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                        # Remove tokens once the cumulative probability exceeds the threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the token that crosses the threshold is kept
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = False

                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        next_token_logits[0, indices_to_remove] = -float("inf")

                    # Convert logits to probabilities and sample from them
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                # Show the top 5 candidates for this step
                top_probs, top_indices = torch.topk(probs, k=5)
                top_probs = top_probs.squeeze(0).cpu().numpy()
                top_indices = top_indices.squeeze(0).cpu().numpy()

                next_token_text = tokenizer.decode(next_token[0])
                print(f"\nGenerated token: '{next_token_text}'")
                print("Top 5 token probabilities:")
                for i in range(5):
                    token = tokenizer.decode([int(top_indices[i])])
                    print(f"  '{token}': {top_probs[i]:.6f}")

                # Append the new token to the running sequence
                current_ids = torch.cat([current_ids, next_token], dim=1)

                # Stop if we generate an EOS token
                if next_token[0, 0].item() == tokenizer.eos_token_id:
                    break

        # Decode and record the complete sequence
        full_text = tokenizer.decode(current_ids[0], skip_special_tokens=True)
        print(f"\nFinal generated text: {full_text}")
        results.append(full_text)

    return results

Let's run our function with a sample prompt to see it in action.

In [None]:
prompt = "In a world where AI has become sentient,"
generate_text(
    prompt=prompt,
    temperature=0.7,
    max_length=50
)

## 3. Understanding the Generation Parameters

### Temperature
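- Divides the logits by the temperature before the softmax, which controls how peaked the next-token distribution is
- Lower temperature (e.g., 0.2) sharpens the distribution, so output becomes more predictable and more repetitive
- Higher temperature (e.g., 1.5) flattens the distribution, so output becomes more varied but also more prone to incoherence
- As the temperature approaches 0, sampling approaches greedy decoding (always the most likely next token); a temperature of exactly 0 would divide by zero, so it has to be handled as a separate greedy case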
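
To see what temperature does to a distribution, the sketch below applies a softmax to one made-up set of four logits at several temperatures. It uses plain Python rather than PyTorch so the arithmetic is visible without loading a model; the `softmax` helper and the logit values are illustrative assumptions, not part of the notebook.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax of logits divided by temperature (temperature must be > 0)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical next-token logits for a 4-token vocabulary
logits = [2.0, 1.0, 0.5, -1.0]

for t in (0.2, 0.7, 1.0, 1.5):
    probs = softmax(logits, t)
    print(f"T={t}: " + ", ".join(f"{p:.3f}" for p in probs))

# Temperature 0 is the greedy limit: take the argmax instead of dividing by zero
greedy_index = max(range(len(logits)), key=lambda i: logits[i])
print("greedy choice:", greedy_index)
```

At low temperature almost all of the probability mass lands on the first token; at high temperature it spreads across the others.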

### Top-k Sampling
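- Only considers the k most likely next tokens; all other tokens get zero probability
- Prevents the model from sampling from the long tail of improbable tokens
- k is fixed, so the same number of candidates is kept whether the distribution is peaked or flat
- Values between 20 and 50 are common (this notebook defaults to 50)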
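
A minimal top-k filter over a plain list of logits, mirroring what the notebook's top-k step does with `torch.topk`: every logit below the k-th largest is replaced with negative infinity, so it receives zero probability after the softmax. The helper name `top_k_filter` and the example logits are hypothetical.

```python
import math

def top_k_filter(logits, k):
    """Keep logits >= the k-th largest value and set the rest to -inf (k <= 0 disables)."""
    if k <= 0 or k >= len(logits):
        return list(logits)
    kth_largest = sorted(logits, reverse=True)[k - 1]
    return [x if x >= kth_largest else -math.inf for x in logits]

logits = [2.0, 1.0, 0.5, -1.0, 0.8]
print(top_k_filter(logits, 2))  # [2.0, 1.0, -inf, -inf, -inf]
```

Like the notebook's `<` comparison, this keeps ties with the k-th value, so slightly more than k tokens can survive when logits are equal.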

### Top-p (Nucleus) Sampling
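- Keeps the smallest set of most likely tokens whose cumulative probability exceeds p, then renormalizes the probabilities over that set
- The number of candidates adapts to the distribution: a few tokens when the model is confident, many when it is uncertain
- Typical values range from 0.7 to 0.95 (this notebook defaults to 0.9)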
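
The sketch below computes the nucleus directly from a list of probabilities. It follows the same rule as the notebook's shifted-mask trick: sort in descending order and keep tokens until the cumulative probability first exceeds p, including the token that crosses the threshold. The function name `nucleus` and both example distributions are illustrative.

```python
def nucleus(probs, p):
    """Return indices of the smallest top set whose cumulative probability exceeds p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative > p:
            break
    return kept

peaked = [0.85, 0.08, 0.04, 0.02, 0.01]  # confident model
flat = [0.25, 0.22, 0.20, 0.18, 0.15]    # uncertain model

print(nucleus(peaked, 0.9))  # [0, 1]
print(nucleus(flat, 0.9))    # [0, 1, 2, 3, 4]
```

With the same p = 0.9, the confident distribution keeps two candidates while the flat one keeps all five, which is the adaptive behavior that distinguishes top-p from top-k.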

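These parameters can be combined. In `generate_text` they are applied in a fixed order: temperature scaling first, then top-k filtering, then top-p filtering, and finally sampling from the remaining probabilities. Tuning them together lets you trade coherence against creativity in the generated text.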
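
To make the combination concrete, here is a pure-Python sketch of a single decoding step that applies the filters in the same order as `generate_text`, with temperature 0 treated as greedy decoding. The function `sample_next_token`, its `rng` argument, and the use of `random.choices` for sampling are assumptions for illustration; the notebook itself operates on PyTorch tensors.

```python
import math
import random

def sample_next_token(logits, temperature=0.7, top_k=50, top_p=0.9, rng=random):
    """One decoding step over a list of logits; returns the chosen token index."""
    if temperature == 0:
        # Greedy decoding: the most likely token, no randomness
        return max(range(len(logits)), key=lambda i: logits[i])

    # 1. Temperature scaling
    scaled = [x / temperature for x in logits]

    # 2. Top-k: drop everything below the k-th largest scaled logit
    if 0 < top_k < len(scaled):
        kth_largest = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [x if x >= kth_largest else -math.inf for x in scaled]

    # Softmax over the surviving tokens (math.exp(-inf) == 0.0)
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # 3. Top-p: keep the most likely tokens until cumulative probability exceeds p
    if top_p < 1.0:
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        keep, cumulative = set(), 0.0
        for i in order:
            keep.add(i)
            cumulative += probs[i]
            if cumulative > top_p:
                break
        probs = [pr if i in keep else 0.0 for i, pr in enumerate(probs)]

    # 4. Sample; random.choices accepts unnormalized weights
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)  # seeded for repeatable draws
logits = [2.0, 1.0, 0.5, -1.0, 0.8]
print(sample_next_token(logits, temperature=0))  # greedy pick
print([sample_next_token(logits, 0.7, top_k=2, top_p=0.9, rng=rng) for _ in range(10)])
```

With `top_k=2`, only indices 0 and 1 can ever be drawn, whatever the temperature; lowering `top_p` far enough reduces the step to always picking index 0.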