In [1]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import numpy as np
from collections import Counter
import ipywidgets as widgets
from IPython.display import display, clear_output

# Function to load models (same as before)
def load_models(
    base_model_dir="google/gemma-2b",
    lora_adapter_dir="models/gemma-math-reasoning-lora-final",
    scorer_dir="models/checkpoint-2000",
):
    """Load the LoRA-merged generator, the verifier (scorer) model and the tokenizer.

    The LoRA adapter in ``lora_adapter_dir`` is attached to the base model
    and merged into its weights (``merge_and_unload``). Generation therefore
    runs on a plain causal-LM module. The scorer is loaded as a separate
    causal LM. Both models are switched to eval mode.

    Args:
        base_model_dir: Hugging Face model id or local path of the base model.
            The tokenizer is loaded from here as well.
        lora_adapter_dir: Path of the PEFT/LoRA adapter for the generator.
        scorer_dir: Model id or path of the verifier checkpoint.

    Returns:
        A ``(generator, scorer, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_dir,
        trust_remote_code=True,
        device_map="auto"
    )
    # Fold the adapter into the base weights so generate() needs no PEFT wrapper.
    generator = PeftModel.from_pretrained(base_model, lora_adapter_dir)
    generator = generator.merge_and_unload()
    generator.eval()

    # NOTE(review): the scorer is driven with the base model's tokenizer by the
    # callers in this notebook; this assumes the verifier checkpoint shares
    # that vocabulary — confirm.
    scorer = AutoModelForCausalLM.from_pretrained(
        scorer_dir,
        trust_remote_code=True,
        device_map="auto"
    )
    scorer.eval()

    return generator, scorer, tokenizer

# Enhanced function to extract final answer from reasoning
def extract_answer(reasoning):
    """Extract the final numeric answer from a generated reasoning string.

    Concluding phrases are tried in a fixed priority order, with
    "The answer is" first. For the first phrase that is present and
    followed by a number, the text after that phrase's *last* occurrence is
    scanned and the first number in it is returned. Comma thousands
    separators are understood, so "1,000" parses as 1000.0.

    Args:
        reasoning: Generated text. Non-string input (e.g. ``None``) yields
            ``None`` instead of raising.

    Returns:
        The answer as a ``float``, or ``None`` if no marker is followed by a
        number.
    """
    import re

    if not isinstance(reasoning, str):
        return None

    # Signed integer/decimal, optionally with comma thousands separators
    # ("1,234.5"). The (?!\d) lookahead stops "1,2345" from being read as "1,234".
    number_pattern = re.compile(
        r"[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+)"
    )
    markers = ["The answer is", "Therefore,", "Thus,", "So,", "Finally,"]
    for marker in markers:
        if marker not in reasoning:
            continue
        # Earlier occurrences of a marker tend to be intermediate steps,
        # so only the text after the last one is considered.
        answer_part = reasoning.rsplit(marker, 1)[-1]
        match = number_pattern.search(answer_part)
        if match:
            return float(match.group().replace(",", ""))
    return None

# Enhanced generation function with answer extraction
def generate_answers(model, tokenizer, question, num_samples=3, max_length=1024):
    """Sample several step-by-step solutions for ``question`` and parse each answer.

    Args:
        model: Causal LM used for generation (``model.generate`` is called).
        tokenizer: Tokenizer matching ``model``.
        question: Question text inserted into the prompt.
        num_samples: Number of independent samples to draw.
        max_length: Forwarded to ``generate`` as ``max_length``, i.e. the
            total token budget with the prompt included.

    Returns:
        ``(answers, answer_values)``:
        - ``answers``: the generated continuations, with the prompt removed.
        - ``answer_values``: the number parsed from each continuation by
          ``extract_answer``, or ``None`` where nothing could be parsed.
    """
    answers = []
    answer_values = []
    prompt = f"Question: {question}\nLet's solve this step by step:"

    # Place inputs on the model's own device rather than assuming CUDA;
    # with device_map="auto" the model need not sit on the default GPU.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    prompt_length = input_ids.shape[-1]

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    for _ in range(num_samples):
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pad_token_id,
            num_return_sequences=1,
        )
        # For decoder-only models, generate() returns prompt + continuation.
        # Decode only the new tokens so the prompt is not:
        #   - displayed as "reasoning",
        #   - fed back into the verifier prompt, or
        #   - searched for answer markers.
        answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        answers.append(answer)
        answer_values.append(extract_answer(answer))

    return answers, answer_values

# Scoring function (same as before)
def score_answer(model, tokenizer, question, answer):
    """Score one reasoning trace with the verifier model.

    The question and reasoning are formatted into a rating prompt and run
    through a single forward pass. The score is the largest softmax
    probability over the vocabulary at the final position.

    Args:
        model: Causal LM used as the verifier; called as ``model(input_ids)``.
        tokenizer: Tokenizer used to encode the rating prompt.
        question: The original question text.
        answer: The reasoning text to be rated.

    Returns:
        A float in (0, 1].
    """
    prompt = f"Question: {question}\nReasoning: {answer}\nRate how likely this reasoning is correct:"
    # Use the model's own device instead of hard-coding CUDA.
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)

    with torch.no_grad():
        logits = model(input_ids).logits

    # NOTE(review): this is the model's confidence in whichever token it ranks
    # first, not the probability of a specific "correct" label token. Confirm
    # that this matches how the verifier checkpoint was trained.
    final_logits = logits[0, -1].float()  # upcast so softmax is stable under fp16/bf16
    return torch.softmax(final_logits, dim=0).max().item()

# New function to compute self-consistency score
def get_self_consistency_score(answer_values):
    """Return the majority-vote answer and the fraction of samples agreeing with it.

    ``None`` entries (samples with no parsable answer) do not vote, but they
    still count in the denominator, so unparsable samples lower the score.
    On a tie, the answer encountered first wins.

    Args:
        answer_values: Sequence of parsed answers (floats or ``None``).

    Returns:
        ``(answer, score)``. Returns ``(None, 0)`` when ``answer_values`` is
        empty or holds only ``None``.
    """
    parsed = [value for value in answer_values if value is not None] if answer_values else []
    if not parsed:
        return None, 0

    [(winner, votes)] = Counter(parsed).most_common(1)
    return winner, votes / len(answer_values)

# Load models once when the cell runs; the button handler reuses these globals.
print("Loading models... This may take a few minutes.")
generator, scorer, tokenizer = load_models()
print("Models loaded successfully!")

# Enhanced interface widgets
question_input = widgets.Textarea(
    placeholder="Type your question here...",
    description="Question:",
    layout=widgets.Layout(width='100%', height='100px')
)

num_samples_input = widgets.IntSlider(
    value=5, min=3, max=20, step=1, description="Samples:"
)

# "initial" sizes the label to its text. Under the default fixed description
# width, "Consistency Weight:" is cut off with an ellipsis.
consistency_weight = widgets.FloatSlider(
    value=0.5, min=0, max=1, step=0.1,
    description="Consistency Weight:",
    tooltip="Weight for self-consistency vs verifier score (0=verifier only, 1=consistency only)",
    style={"description_width": "initial"},
)

generate_button = widgets.Button(description="Generate Solutions")
output_area = widgets.Output()

# Enhanced button click handler
def on_generate_clicked(b):
    """Handle a click on the "Generate Solutions" button.

    Steps:
    1. Sample ``num_samples_input.value`` solutions for the question in
       ``question_input``.
    2. Score each solution with the verifier.
    3. Blend the verifier score with self-consistency, using
       ``consistency_weight.value`` as the weight.
    4. Print the results to ``output_area``.

    Per-solution combined score::

        (1 - weight) * verifier_score + weight * [answer == majority answer]

    Solutions without a parsable answer receive only the verifier term.

    Args:
        b: The clicked button (unused).
    """
    output_area.clear_output()
    question = question_input.value.strip()
    num_samples = num_samples_input.value
    weight = consistency_weight.value

    if not question:
        with output_area:
            print("Please enter a valid question!")
        return

    with output_area:
        print("Generating solutions... Please wait.")

    # Disable the button while working so repeated clicks don't queue
    # duplicate runs; re-enabled in `finally` even if something fails.
    generate_button.disabled = True
    try:
        answers, answer_values = generate_answers(generator, tokenizer, question, num_samples)
        consistent_answer, consistency_score = get_self_consistency_score(answer_values)
        verifier_scores = [score_answer(scorer, tokenizer, question, ans) for ans in answers]
    except Exception as err:  # UI boundary: show the failure instead of a stuck "Please wait"
        output_area.clear_output()
        with output_area:
            print(f"Generation failed: {type(err).__name__}: {err}")
        return
    finally:
        generate_button.disabled = False

    combined_scores = []
    for ans_value, verifier_score in zip(answer_values, verifier_scores):
        combined_score = (1 - weight) * verifier_score
        if ans_value is not None and ans_value == consistent_answer:
            combined_score += weight
        combined_scores.append(combined_score)

    # Replace the progress message with the results.
    output_area.clear_output()
    with output_area:
        print("\nGenerated Solutions:\n")
        for i, (ans, ans_value, v_score, c_score) in enumerate(zip(answers, answer_values, verifier_scores, combined_scores), 1):
            print(f"Solution {i}:")
            print(f"Extracted Answer: {ans_value}")
            print(f"Verifier Score: {v_score:.2f}")
            print(f"Combined Score: {c_score:.2f}")
            print(f"Reasoning:\n{ans}\n")

        print("\nSelf-Consistency Analysis:")
        print(f"Most consistent answer: {consistent_answer}")
        print(f"Consistency score: {consistency_score:.2f}")

        # Find best solution according to combined score
        if combined_scores:
            best_idx = int(np.argmax(combined_scores))
            print(f"\nBest solution (combined scoring) is #{best_idx + 1}")

# Route button clicks to the handler, then render the controls stacked vertically.
generate_button.on_click(on_generate_clicked)

display(
    widgets.VBox(
        children=[question_input, num_samples_input, consistency_weight, generate_button, output_area]
    )
)

Loading models... This may take a few minutes.


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Models loaded successfully!


VBox(children=(Textarea(value='', description='Question:', layout=Layout(height='100px', width='100%'), placeh…

### Model Performance on a Question
#### Out-of-Domain Example: MultiArith Question

**Question**:  
There are 64 students trying out for the school's trivia teams. If 36 of them didn't get picked for the team and the rest were put into 4 groups, how many students would be in each group?  

---
The **ground truth** for this question is **7**, since (64 − 36) / 4 = 7. The model generated **5 solutions**; the table below lists each solution's score and whether its final answer is correct:

| Solution      | Score | Correctness |
|---------------|-------|-------------|
| **Solution 1** | 0.59  | Incorrect   |
| **Solution 2** | 0.61  | Incorrect   |
| **Solution 3** | 0.71  | Correct     |
| **Solution 4** | 0.69  | Correct     |
| **Solution 5** | 0.71  | Correct     |

---

**Key Observations**:
- Solutions 3, 4, and 5 are correct, and each scores at least 0.69.
- The incorrect Solutions 1 and 2 receive lower scores (0.59 and 0.61).

**Potential Applications**:

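- On this example, a score threshold of **~0.65** separates correct from incorrect solutions. The margin is small (0.61 vs. 0.69), and it comes from a single question, so the threshold should be validated on a larger held-out set before being used as a decision boundary.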
- Multiple solutions with **high scores** increase confidence in the answer.
- The **model + verifier** combination shows promise, but broader evaluation across many questions is needed before drawing conclusions about **real-world applications**.
