# 7. Contrastive Search

Welcome to **Contrastive Search**, one of the strongest deterministic decoding methods for open-ended text generation.

In notebooks 1-6, we explored the tension between **Determinism (Greedy, Beam)** and **Stochasticity (Temperature, Top-P)**.
- **Greedy** was predictable and tended to trap itself in repetitive loops.
- **Top-P/Temperature** was far more creative, but ran the risk of hallucination and incoherence.

In 2022, researchers asked: *Can we get the high coherence of Greedy Decoding, without ever falling into a repetitive loop, but also without relying on random dice rolls?*

The answer was Contrastive Search.

## How it works (The Math)

Contrastive Search evaluates the top `K` most likely next tokens. For each candidate, it calculates a final score based on two competing factors:

1. **Model Confidence:** How probable is this token according to the model? (We want this to be HIGH).
2. **Degeneration Penalty:** How similar is this token's hidden-state representation to the representations of the tokens we have *already* written? It is measured as the maximum cosine similarity over all previous tokens. (We want this to be LOW).

$$ \text{Score}(v) = (1 - \alpha) \times \text{Confidence}(v) - \alpha \times \text{Similarity}(v, \text{past}) $$

The hyperparameter $\alpha$ (Alpha) balances the two. If $\alpha = 0.0$, this is identical to Greedy Decoding. If $\alpha = 1.0$, the model cares *only* about avoiding repetition.
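To make the trade-off concrete, here is a minimal sketch of the scoring rule in plain Python. The three candidates and their (probability, similarity) values are made up for illustration; they are not taken from the model used in this notebook.

```python
def contrastive_score(confidence: float, max_similarity: float, alpha: float) -> float:
    """Score(v) = (1 - alpha) * Confidence(v) - alpha * Similarity(v, past)."""
    return (1.0 - alpha) * confidence - alpha * max_similarity


# Made-up (probability, similarity-to-context) pairs for three candidate tokens.
candidates = {
    " cat": (0.70, 0.95),     # very likely, but almost identical to the context
    " mammal": (0.20, 0.30),  # less likely, but adds something new
    " dog": (0.10, 0.60),
}


def pick(alpha: float) -> str:
    """Return the candidate with the highest contrastive score."""
    return max(candidates, key=lambda tok: contrastive_score(*candidates[tok], alpha))


for alpha in (0.0, 0.6):
    print(f"alpha={alpha}: picks {pick(alpha)!r}")
```

With `alpha=0.0` the most probable candidate (`' cat'`) wins, exactly as in greedy decoding. With `alpha=0.6` the penalty outweighs the probability gap and `' mammal'` wins instead.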

Because the penalty is computed with **Cosine Similarity** on the model's final-layer hidden states rather than on token IDs, it can catch near-repetition and not only exact repetition. If you say "The large dog", a candidate such as "big" will tend to have a hidden state close to that of "large", so it gets penalized, nudging the model toward a word that keeps the story moving forward.
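The degeneration penalty itself can be sketched without a model. The 3-dimensional vectors below are hypothetical stand-ins for hidden states (real ones have hundreds of dimensions and come from the network); the point is only to show that the penalty is the largest cosine similarity between the candidate and any past token.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def degeneration_penalty(candidate, past):
    """Largest cosine similarity between the candidate and any past state."""
    return max(cosine_similarity(candidate, h) for h in past)


# Hypothetical hidden states for the context "The large dog".
past_states = [
    [1.0, 0.0, 0.0],  # "The"
    [0.0, 1.0, 0.2],  # " large"
    [0.0, 0.1, 1.0],  # " dog"
]

# Hypothetical hidden states for two candidate next tokens.
candidate_states = {
    " big": [0.0, 0.9, 0.3],      # points almost the same way as " large"
    " barked": [0.5, -0.5, 0.5],  # not close to any single past token
}

for token, state in candidate_states.items():
    print(f"{token!r}: penalty = {degeneration_penalty(state, past_states):.3f}")
```

In this toy setup `' big'` receives a much larger penalty than `' barked'`, even though neither token appears in the context.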

In [5]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
model_id = "Qwen/Qwen2.5-0.5B"

print(f"Loading {model_id} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model.eval()
print("Ready!")

Loading Qwen/Qwen2.5-0.5B on mps...


Ready!


### HuggingFace Implementation

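HuggingFace exposes Contrastive Search through `model.generate()`. To select it, you must provide two parameters:
1. `penalty_alpha` (> 0.0)
2. `top_k` (> 1)

In the `transformers` version used here, the implementation is downloaded from the `transformers-community/contrastive-search` repository on the Hub, which is why the calls below also pass `trust_remote_code=True` and why the log output mentions `custom_generate/generate.py`.

Let's see it in action against a prompt that reliably sends greedy decoding into a loop.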

In [6]:
prompt = "A cat is a cat is a"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

print("--- Greedy Decoding (alpha=0.0) ---")
greedy_out = model.generate(**inputs, max_new_tokens=40, do_sample=False, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(greedy_out[0]))

print("\n--- Contrastive Search (alpha=0.6, top_k=4) ---")
# Notice we are NOT using do_sample=True! Contrastive search is deterministic.
contrast_out = model.generate(
    **inputs, 
    max_new_tokens=40, 
    penalty_alpha=0.6, 
    top_k=4,
    pad_token_id=tokenizer.eos_token_id,
    trust_remote_code=True
)
print(tokenizer.decode(contrast_out[0]))

--- Greedy Decoding (alpha=0.0) ---
A cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat is a cat

--- Contrastive Search (alpha=0.6, top_k=4) ---


A new version of the following files was downloaded from https://huggingface.co/transformers-community/contrastive-search:
- custom_generate/generate.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Passing `generation_config` together with generation-related arguments=({'penalty_alpha', 'top_k', 'max_new_tokens', 'cache_implementation', 'pad_token_id'}) is deprecated and will be removed in future versions. Please pass either a `generation_config` object OR all generation parameters explicitly, but not both.
Both `max_new_tokens` (=40) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
An assistant model is provided, using a dynamic cache instead of a cache of type='dynamic_full'.


A cat is a cat is a mammal.
This justifies what answer for what question? Q & A:
Question: Is a cat a mammal?

Answer: Yes, a cat is a mammal.
You are an AI


### 🔬 Experimentation Ideas

1. **Sweep Alpha Values:** 
   * *Run generation with `penalty_alpha` set to 0.1, 0.5, and 0.9. Watch how the text changes from repetitive to almost overly-verbose as it desperately avoids similar ideas.*
2. **Combine with Top-K:**
   * *What happens if `top_k = 2` vs `top_k = 50`? Contrastive search only evaluates the K most likely tokens. If K is too high, does it pick a terrible, grammatically incorrect word just because it is highly dissimilar to the context?*
3. **Compare to Repetition Penalty (Notebook 5):**
   * *Contrastive search penalizes similarity between hidden states, so it can catch near-repetition. Repetition Penalty only penalizes token IDs that already appear in the text. Which one produces better creative writing?*
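For the third idea, it helps to see what a repetition penalty does to the logits. The sketch below assumes the commonly used rule (divide positive logits of already-seen token IDs by the penalty, multiply negative ones by it); the 4-token vocabulary and the logit values are hypothetical.

```python
def apply_repetition_penalty(logits, seen_token_ids, penalty):
    """Make already-seen token IDs less likely; all other logits are untouched."""
    adjusted = list(logits)
    for token_id in set(seen_token_ids):
        if adjusted[token_id] > 0:
            adjusted[token_id] /= penalty  # shrink positive logits
        else:
            adjusted[token_id] *= penalty  # push negative logits further down
    return adjusted


# Hypothetical vocabulary and next-token logits.
vocab = ["large", "big", "dog", "ran"]
logits = [3.0, 2.9, 1.0, -1.0]
seen = [0, 3]  # "large" and "ran" already appear in the context

adjusted = apply_repetition_penalty(logits, seen, penalty=1.5)
for token, before, after in zip(vocab, logits, adjusted):
    print(f"{token:>5}: {before:+.2f} -> {after:+.2f}")
```

Only the IDs for "large" and "ran" change. The near-synonym "big" keeps its logit and becomes the top candidate, which is the kind of near-repetition that a hidden-state similarity penalty is meant to discourage.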

In [7]:
print("=== Experiment 1: Sweep Alpha Values ===")
# We use a narrative prompt to see how creatively it avoids repeating itself
prompt = "The spaceship landed on the alien planet and the crew"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Alpha = 0.1 (Very little penalty) vs 0.9 (Massive penalty for similarity)
for alpha in [0.1, 0.5, 0.9]:
    print(f"\n--- Alpha = {alpha} (Top-K = 4) ---")
    out = model.generate(
        **inputs, 
        max_new_tokens=40, 
        penalty_alpha=alpha, 
        top_k=4, 
        pad_token_id=tokenizer.eos_token_id,
        trust_remote_code=True
    )
    print(tokenizer.decode(out[0]))

print("\n\n=== Experiment 2: Combine with Top-K ===")
print("Observe what happens when Top-K is huge. The model might pick a grammatically terrible word just because it is highly mathematically dissimilar to the context.")
for k_val in [2, 5, 50]:
    print(f"\n--- Top-K = {k_val} (Alpha = 0.6) ---")
    out = model.generate(
        **inputs, 
        max_new_tokens=40, 
        penalty_alpha=0.6, 
        top_k=k_val, 
        pad_token_id=tokenizer.eos_token_id,
        trust_remote_code=True
    )
    print(tokenizer.decode(out[0]))

print("\n\n=== Experiment 3: Contrastive Search vs Repetition Penalty ===")
creative_prompt = "The old wizard walked up to the magical glowing orb and"
creative_inputs = tokenizer(creative_prompt, return_tensors="pt").to(device)

print("\n--- Contrastive Search (Semantic similarity penalty) ---")
contrast_out = model.generate(
    **creative_inputs, 
    max_new_tokens=50, 
    penalty_alpha=0.6, 
    top_k=4, 
    pad_token_id=tokenizer.eos_token_id,
    trust_remote_code=True
)
print(tokenizer.decode(contrast_out[0]))

print("\n--- Repetition Penalty (Exact token match penalty) ---")
# To make it a fair fight for creative writing, we give the repetition penalty some random sampling
rep_out = model.generate(
    **creative_inputs, 
    max_new_tokens=50, 
    repetition_penalty=1.5,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(rep_out[0]))


=== Experiment 1: Sweep Alpha Values ===

--- Alpha = 0.1 (Top-K = 4) ---




The spaceship landed on the alien planet and the crew decided to count the number of alien species. They found that there were 120 alien species in total. If they counted 30 alien species in the first hour and 20 more

--- Alpha = 0.5 (Top-K = 4) ---




The spaceship landed on the alien planet and the crew decided to count the number of stars in the night sky. They found that there were 120 stars visible from the spaceship. If each star is represented by a point on a coordinate grid,

--- Alpha = 0.9 (Top-K = 4) ---




The spaceship landed on the alien planet and the crew was divided into two teams to explore a new area. Team 1 consisted 1/3 of the crew, which is 40% of the total crew. How many people are part of


=== Experiment 2: Combine with Top-K ===
Observe what happens when Top-K is huge. The model might pick a grammatically terrible word just because it is highly mathematically dissimilar to the context.

--- Top-K = 2 (Alpha = 0.6) ---




The spaceship landed on the alien planet and the crew decided to count the number of aliens they saw. They saw 12 aliens on the first day and 15 aliens on the second day. How many aliens did they see in total over the

--- Top-K = 5 (Alpha = 0.6) ---




The spaceship landed on the alien planet and the crew was divided into two teams. Team $A$ consisted of 10 scientists and 5 mathematicians, while Team $B$ consisted of 8 scientists and 6 mathematicians. After a

--- Top-K = 50 (Alpha = 0.6) ---




The spaceship landed on the alien planet and the crew went ashore.
Generate a new sentence that is, on a scale from 0 to 5, a 4 in textual similarity to the above sentence altough it is nonsensical or fl


=== Experiment 3: Contrastive Search vs Repetition Penalty ===

--- Contrastive Search (Semantic similarity penalty) ---




The old wizard walked up to the magical glowing orb and said "I am the source of all magic. I can turn any object into anything else, and I can create a portal to any other world." The wizard then disappeared into the orb, and a portal opened up in the sky.  Given the

--- Repetition Penalty (Exact token match penalty) ---
The old wizard walked up to the magical glowing orb and cast it on a tiny, silver sphere.
In his absence he heard two young elves whispering about their new home in this realm of magic.

[img src="https://i.stack.imgur.com/kZ8rR.jpg" alt="" border="_


### Manual Implementation

In [10]:
import torch
import torch.nn.functional as F

print("=== Manual Implementation of Contrastive Search ===")

# Hyperparameters
alpha = 0.6
k = 4
max_tokens = 40

prompt = "A cat is a cat is a"

# The tokenizer returns a dict-like object; the manual loop only needs the .input_ids tensor
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

print(f"Prompt: {prompt}")
print("Generating: ", end="")

for step in range(max_tokens):
    with torch.no_grad():
        # output_hidden_states=True is the secret to contrastive search!
        outputs = model(input_ids, output_hidden_states=True)
    
    # 1. Get the raw logits and convert to probabilities
    next_token_logits = outputs.logits[0, -1, :]
    next_token_probs = F.softmax(next_token_logits, dim=-1)
    
    # 2. Find the Top-K most likely candidates (Model Confidence)
    top_k_probs, top_k_indices = torch.topk(next_token_probs, k)
    
    # We need the past hidden states to compare against.
    # outputs.hidden_states is a tuple with one tensor per layer, each of shape
    # [batch, sequence_length, hidden_dimension]. We take the last layer [-1] and
    # the first batch item [0], leaving shape [sequence_length, hidden_dimension].
    past_hidden_states = outputs.hidden_states[-1][0] 
    
    best_score = -float('Inf')
    best_token = None
    
    # 3. Evaluate each of the K candidates
    for i in range(k):
        candidate_token = top_k_indices[i].unsqueeze(0).unsqueeze(0)
        candidate_prob = top_k_probs[i].item()
        
        # To get the hidden state of the candidate, we append it to the context
        # and run another forward pass over the whole sequence (no KV cache is
        # reused here, which is slow but keeps the code simple)
        test_input = torch.cat([input_ids, candidate_token], dim=-1)
        with torch.no_grad():
            test_outputs = model(test_input, output_hidden_states=True)
        
        # The final-layer hidden state at the candidate's position
        candidate_hidden_state = test_outputs.hidden_states[-1][0, -1, :]
        
        # 4. Calculate Cosine Similarity against ALL past tokens in one vectorized call
        # (the candidate is broadcast against every row of past_hidden_states).
        # We want the single past token that is MOST similar to our candidate
        similarities = F.cosine_similarity(
            past_hidden_states, candidate_hidden_state.unsqueeze(0), dim=-1
        )
        max_similarity = similarities.max().item()
        
        # 5. THE CONTRASTIVE SEARCH EQUATION
        score = (1.0 - alpha) * candidate_prob - (alpha * max_similarity)
        
        if score > best_score:
            best_score = score
            best_token = candidate_token
            
    # Append the winning token and continue the loop!
    input_ids = torch.cat([input_ids, best_token], dim=-1)
    print(tokenizer.decode(best_token[0]), end="", flush=True)

print("\n\nDone! Notice how our manual implementation breaks out of the 'cat is a' loop.")


=== Manual Implementation of Contrastive Search ===
Prompt: A cat is a cat is a
Generating:  mammal.
This justifies what answer for what question? Q & A: Question: What is a mammal?

Answer: Mammals are a group of animals that have a common characteristic,

Done! Notice how our manual implementation breaks out of the 'cat is a' loop.
