In [22]:
import tensorflow as tf
import numpy as np
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer

# Load the GPT-2 tokenizer and TensorFlow model
# Both are created from the 'gpt2' pretrained checkpoint name and are kept as
# module-level objects because the functions below reference them directly.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = TFGPT2LMHeadModel.from_pretrained('gpt2')

# Define the allowed letters and get their corresponding token IDs.
# Only the first token id of each letter's encoding is kept ([0]).
# NOTE(review): this assumes every letter encodes to exactly one token — confirm.
# NOTE(review): add_prefix_space=False presumably selects the bare-letter token
# (no leading space), which pairs with a prompt that already ends in a space —
# verify against the tokenizer documentation.
allowed_letters = ['A', 'B', 'C', 'D']
allowed_token_ids = [tokenizer.encode(letter, add_prefix_space=False)[0] for letter in allowed_letters]

def create_prompt_with_content(question_obj):
    """
    Build the multiple-choice prompt text for a single question.

    Reads three keys from ``question_obj``: "content" (optional, treated as an
    empty string when absent), "options" and "question" (both required; a
    missing one raises KeyError). Options are lettered A, B, C, ... in the
    order given. The closing instruction always names the letters A-D,
    regardless of how many options were supplied.

    Returns the assembled prompt string, which ends with "Your Answer: ".
    """
    context_text = question_obj.get("content", "")
    choices = question_obj['options']

    # 65 is the code point of 'A'; one consecutive letter per option.
    letters = map(chr, range(65, 65 + len(choices)))
    lettered_lines = [f"{letter}. {choice}" for letter, choice in zip(letters, choices)]
    options_block = "\n".join(lettered_lines)

    sections = [
        f"Content: {context_text}\n\n",
        f"Question: {question_obj['question']}\n\n",
        f"Options:\n{options_block}\n\n",
        "Please choose the correct option by outputting only one letter (A, B, C, or D) with no extra text.\n",
        "Your Answer: ",
    ]
    return "".join(sections)

def generate_one_allowed_token(prompt, allowed_token_ids, temperature=0.7, do_sample=True):
    """
    Generate one token after the prompt, restricting selection to allowed_token_ids.

    Args:
        prompt: Text to condition on; encoded with the module-level ``tokenizer``.
        allowed_token_ids: Iterable of vocabulary ids the result may take.
            Duplicates are ignored.
        temperature: Divisor applied to the allowed logits before sampling.
            A value of 0 selects the highest-scoring allowed token (greedy).
        do_sample: When True, sample from the temperature-scaled distribution
            over the allowed tokens; when False, pick the highest-scoring one.

    Returns:
        The chosen token id as a Python int; always one of allowed_token_ids.

    Raises:
        ValueError: If allowed_token_ids is empty or temperature is negative.
    """
    # np.unique both de-duplicates and sorts, so ties in argmax resolve to the
    # lowest token id, as they would when scanning the full vocabulary.
    allowed_ids = np.unique(np.asarray(list(allowed_token_ids), dtype=np.int64))
    if allowed_ids.size == 0:
        raise ValueError("allowed_token_ids must contain at least one token id")
    if do_sample and temperature < 0:
        raise ValueError("temperature must be non-negative")

    # Tokenize the prompt and run the model once.
    input_ids = tokenizer.encode(prompt, return_tensors='tf')
    outputs = model(input_ids, return_dict=True)

    # Logits are indexed (batch, position, vocab); keep the final position of
    # the single prompt in the batch — these score the next token.
    next_token_logits = outputs.logits[:, -1, :].numpy()[0]

    # Work only on the allowed entries instead of masking the whole vocabulary.
    allowed_logits = next_token_logits[allowed_ids].astype(np.float64)

    if not do_sample or temperature == 0:
        # Deterministic: take the highest-scoring allowed token.
        return int(allowed_ids[np.argmax(allowed_logits)])

    # Softmax with the maximum subtracted first: the largest exponent becomes
    # exp(0) == 1, so the normaliser can never underflow to zero and yield NaN.
    scaled = allowed_logits / temperature
    scaled -= scaled.max()
    weights = np.exp(scaled)
    probs = weights / weights.sum()

    return int(np.random.choice(allowed_ids, p=probs))

def evaluate_answer(question_obj, generated_letter):
    """
    Check a generated option letter against the question's correct answer.

    The correct answer is stored as option text under "correct_answer"; its
    position within "options" is converted to a letter (index 0 -> 'A',
    1 -> 'B', ...) and compared with ``generated_letter``.

    Returns a ``(is_correct, correct_letter)`` tuple. When the correct answer
    text does not appear among the options, returns ``(False, None)``.
    """
    target_text = question_obj["correct_answer"]
    choices = question_obj["options"]

    try:
        position = choices.index(target_text)
    except ValueError:
        # The answer key does not match any listed option.
        return False, None

    expected_letter = chr(ord("A") + position)
    return generated_letter == expected_letter, expected_letter

# Sample multiple-choice item used to exercise the pipeline end to end.
question_obj = {
    "id": "hp_004",
    "question": (
        "What are the animal mascots of Gryffindor, Ravenclaw, "
        "Hufflepuff, and Slytherin, respectively?"
    ),
    "options": [
        "Lion, Snake, Rat, Cow",
        "Lion, Eagle, Snake, Badger",
        "Sheep, Pig, Snake, Cow",
        "Lion, Hawk, Snake, Otter",
    ],
    "correct_answer": "Lion, Eagle, Snake, Badger",
    "content": (
        "Each house in Hogwarts has a corresponding animal: "
        "Gryffindor (Lion), Ravenclaw (Eagle), Hufflepuff (Badger), "
        "and Slytherin (Snake), as explained in "
        "'Harry Potter and the Philosopher’s Stone'."
    ),
}

# Build the prompt text (context + question + lettered options).
prompt = create_prompt_with_content(question_obj)

# Pick a single next token, limited to the allowed answer letters,
# then turn the token id back into text.
next_token_id = generate_one_allowed_token(prompt, allowed_token_ids)
answer_generated = tokenizer.decode([next_token_id]).strip()

# Compare the generated letter with the letter of the correct option.
is_correct, correct_letter = evaluate_answer(question_obj, answer_generated)

print("Prompt:\n", prompt)
print("Generated Answer:", answer_generated)
if is_correct:
    print("Evaluation:", "Correct")
else:
    print("Evaluation:", f"Incorrect (expected {correct_letter})")


All model checkpoint layers were used when initializing TFGPT2LMHeadModel.

All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at gpt2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.


Prompt:
 Content: Each house in Hogwarts has a corresponding animal: Gryffindor (Lion), Ravenclaw (Eagle), Hufflepuff (Badger), and Slytherin (Snake), as explained in 'Harry Potter and the Philosopher’s Stone'.

Question: What are the animal mascots of Gryffindor, Ravenclaw, Hufflepuff, and Slytherin, respectively?

Options:
A. Lion, Snake, Rat, Cow
B. Lion, Eagle, Snake, Badger
C. Sheep, Pig, Snake, Cow
D. Lion, Hawk, Snake, Otter

Please choose the correct option by outputting only one letter (A, B, C, or D) with no extra text.
Your Answer: 
Generated Answer: A
Evaluation: Incorrect (expected B)
