In [1]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Function to load models
def load_models():
    """Load the generator model, the scorer model and the shared tokenizer.

    The generator is the ``google/gemma-2b`` base model with the LoRA adapter
    from ``models/gemma-math-reasoning-lora-final`` merged into it. The scorer
    is a separate causal LM loaded from ``models/checkpoint-2000``. Both models
    are switched to eval mode before being returned.

    Returns:
        A ``(generator, scorer, tokenizer)`` tuple. The tokenizer is the one
        belonging to the base model.
    """
    base_id = "google/gemma-2b"
    adapter_path = "models/gemma-math-reasoning-lora-final"
    scorer_path = "models/checkpoint-2000"

    # Both causal-LM loads use the same options.
    load_kwargs = {"trust_remote_code": True, "device_map": "auto"}

    tok = AutoTokenizer.from_pretrained(base_id, trust_remote_code=True)

    # Generator: attach the adapter to the base model, then merge it so the
    # result is a plain model without the PEFT wrapper.
    backbone = AutoModelForCausalLM.from_pretrained(base_id, **load_kwargs)
    gen_model = PeftModel.from_pretrained(backbone, adapter_path).merge_and_unload()
    gen_model.eval()

    # Scorer: loaded directly from its own checkpoint directory.
    score_model = AutoModelForCausalLM.from_pretrained(scorer_path, **load_kwargs)
    score_model.eval()

    return gen_model, score_model, tok

# Function to generate answers
def generate_answers(model, tokenizer, question, num_samples=3, max_length=1024):
    """Sample step-by-step solutions to a question from a causal LM.

    Args:
        model: Model exposing a ``device`` attribute and a ``generate`` method.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        question: Question text inserted into the prompt template.
        num_samples: Number of independent samples to draw (one ``generate``
            call per sample).
        max_length: Forwarded to ``generate`` as ``max_length``.

    Returns:
        A list with ``num_samples`` decoded strings (empty when
        ``num_samples`` is 0).

    NOTE(review): decoder-only ``generate`` output usually contains the prompt
    followed by the continuation, so each returned string presumably starts
    with the prompt text - confirm this is what downstream scoring expects.
    """
    prompt = f"Question: {question}\nLet's solve this step by step:"

    # Send inputs to wherever the model actually lives instead of assuming
    # CUDA: the models are loaded with device_map="auto", and a hard-coded
    # .cuda() fails outright on CPU-only hosts.
    device = model.device
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)

    # Fall back to EOS when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generate_kwargs = {
        "max_length": max_length,
        "temperature": 0.7,
        "do_sample": True,
        "pad_token_id": pad_token_id,
        "num_return_sequences": 1,
    }

    # Pass the attention mask explicitly when the tokenizer provides one, so
    # generate() does not have to infer it from the pad token.
    attention_mask = getattr(encoded, "attention_mask", None)
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(device)

    answers = []
    for _ in range(num_samples):
        outputs = model.generate(input_ids, **generate_kwargs)
        answers.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return answers

# Function to score an answer
def score_answer(model, tokenizer, question, answer):
    """Score a candidate answer with the scorer model.

    The question and answer are inserted into a fixed rating prompt, the
    model is run once over it, and the score is the largest probability in
    the softmax over the logits at the final prompt position.

    Args:
        model: Callable model exposing a ``device`` attribute; calling it with
            input ids must return an object with a ``logits`` tensor.
        tokenizer: Tokenizer used to encode the prompt.
        question: Question text.
        answer: Candidate reasoning/answer text to rate.

    Returns:
        The maximum next-token probability as a Python float.

    NOTE(review): whether this maximum probability tracks correctness depends
    on how the scorer checkpoint was trained - confirm against the training
    setup.
    """
    prompt = f"Question: {question}\nReasoning: {answer}\nRate how likely this reasoning is correct:"

    # Use the model's own device instead of a hard-coded .cuda(), which fails
    # on CPU-only hosts and may not match the placement chosen at load time.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    with torch.no_grad():
        logits = model(input_ids).logits

    # Distribution over the vocabulary for the token following the prompt.
    next_token_probs = torch.softmax(logits[0, -1], dim=0)
    return next_token_probs.max().item()

# Load both models and the tokenizer once, when the cell runs, so the button
# handler below can reuse them on every click (loading may take a while).
print("Loading models... This may take a few minutes.")
generator, scorer, tokenizer = load_models()
print("Models loaded successfully!")

# Input controls and output pane for the interactive demo.

# Free-text box for the question.
question_input = widgets.Textarea(
    description="Question:",
    placeholder="Type your question here...",
    layout=widgets.Layout(height='100px', width='100%'),
)

# How many solutions to sample per click (1 to 10, default 3).
num_samples_input = widgets.IntSlider(
    description="Samples:",
    min=1,
    max=10,
    step=1,
    value=3,
)

# Trigger button and the pane that results are printed into.
generate_button = widgets.Button(description="Generate Solutions")
output_area = widgets.Output()

# Function to handle button click
def on_generate_clicked(b):
    """Handle a click on the "Generate Solutions" button.

    Clears the output pane, validates the question, then generates the
    requested number of solutions and prints each one with its score.

    Args:
        b: The button widget that was clicked. It is disabled while the
            handler runs and re-enabled afterwards.
    """
    output_area.clear_output()
    question = question_input.value
    num_samples = num_samples_input.value

    if not question.strip():
        with output_area:
            print("Please enter a valid question!")
        return

    # Generation can take a long time; disable the button so repeated clicks
    # do not queue up duplicate runs behind the current one.
    b.disabled = True
    try:
        # Do all the work inside the output context so that anything printed,
        # and any exception raised by generation or scoring, is shown in the
        # output pane instead of being lost outside the widget.
        with output_area:
            print("Generating solutions... Please wait.")

            answers = generate_answers(generator, tokenizer, question, num_samples)

            print("\nGenerated Solutions:\n")
            for i, ans in enumerate(answers, 1):
                score = score_answer(scorer, tokenizer, question, ans)
                print(f"Solution {i} (Score: {score:.2f}):\n{ans}\n")
    finally:
        # Always give the button back, even if generation failed.
        b.disabled = False

# Register the click handler on the button.
generate_button.on_click(on_generate_clicked)

# Render the controls stacked vertically: question box, sample slider,
# button, then the output pane.
display(
    widgets.VBox(
        children=(
            question_input,
            num_samples_input,
            generate_button,
            output_area,
        )
    )
)


Loading models... This may take a few minutes.


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Models loaded successfully!


VBox(children=(Textarea(value='', description='Question:', layout=Layout(height='100px', width='100%'), placeh…

### Model Performance on a Question

The **ground truth** answer for this question is **7**. The model generates **5 solutions**; the table below lists each solution's score and whether it is correct:

| Solution      | Score | Correctness |
|---------------|-------|-------------|
| **Solution 1** | 0.59  | Incorrect   |
| **Solution 2** | 0.61  | Incorrect   |
| **Solution 3** | 0.71  | Correct     |
| **Solution 4** | 0.69  | Correct     |
| **Solution 5** | 0.71  | Correct     |

---

**Key Observations**:
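- Solutions 3, 4, and 5 are correct and receive the highest scores (≥ 0.69).
- The two incorrect solutions (1 and 2) receive the lowest scores (0.59 and 0.61), so on this question the scorer ranks every correct solution above every incorrect one.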