In [None]:
%%capture
# Silence the output of this setup cell (must stay the first line of the cell).
# autoreload mode 2 picks up edits to imported project modules without a kernel restart.
%load_ext autoreload
%autoreload 2
# Later cells use paths relative to the project root (e.g. "./outputs/loras/..."),
# so the working directory is switched here.
# NOTE(review): this is an absolute, machine-specific path — adjust it when running elsewhere.
%cd /home/ubuntu/projects/hyper-sloth

In [None]:
from fastcore.all import *
from speedy_utils.all import *
from llm_utils import *

#### Create an LLM model

In [None]:
import gc
import os

import torch
from datasets import load_dataset
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Re-running this cell must release the previous engine before a new one is
# built; otherwise both copies of the model compete for GPU memory.
if "llm" in locals():
    del llm
    # `del` only drops the name. Collect garbage first so objects kept alive by
    # reference cycles are destroyed and their CUDA blocks become reclaimable.
    gc.collect()
    torch.cuda.empty_cache()

llm = LLM(
    # model="outputs/lora/Qwen2.5-1.5B-Instruct-LORA-MATH-merge",
    model="./outputs/loras/qwen1.5-openr1-1card/checkpoint-2500-merged",
    tensor_parallel_size=1,
    task="generate",
    enforce_eager=True,
    dtype=torch.bfloat16,
    max_model_len=16384,
    # A LoRARequest is passed to `llm.generate` in the generation cell below.
    enable_lora=True,
    # quantization="bitsandbytes", load_format="bitsandbytes",gpu_memory_utilization=0.95
)

# The tokenizer is reused to render chat-template prompts.
tokenizer = llm.get_tokenizer()

# Load the GSM8K dataset; only the held-out test split is used for evaluation.
gsm8k = load_dataset("gsm8k", "main")
test = gsm8k["test"]


In [None]:
# The post-mortem debugger is intentionally not invoked here: `%debug` opens an
# interactive prompt whenever a previous traceback exists, which stalls
# "Restart Kernel -> Run All". Run `%debug` by hand right after a failing cell.

In [None]:
from itertools import islice

# Size of the evaluated subset and the seed that makes sampling repeatable.
NUM_EVAL_QUESTIONS = 100
SAMPLING_SEED = 0

# Prepare prompts for GSM8K evaluation.
# islice stops after the first NUM_EVAL_QUESTIONS rows instead of walking the
# entire test split and discarding everything past the slice.
all_questions = [item["question"] for item in islice(test, NUM_EVAL_QUESTIONS)]
standardized_prompts = [
    tokenizer.apply_chat_template(
        [
            {
                "role": "user",
                "content": f"{question}\nSolve step by step and put your final numerical answer inside \\boxed{{}}",
            }
        ],
        tokenize=False,  # keep the rendered prompt as text; vLLM tokenizes it itself
        add_generation_prompt=True,
    )
    for question in all_questions
]

# Sampling is stochastic (temperature > 0), so a fixed seed is supplied to make
# reruns comparable on the same model and software stack.
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    top_k=64,
    max_tokens=10000,
    seed=SAMPLING_SEED,
)

# Generate responses for all questions with the LoRA adapter applied.
outputs = llm.generate(
    standardized_prompts,
    sampling_params,
    lora_request=LoRARequest(
        "math", 1, "./outputs/loras/qwen1.5-openr1/checkpoint-732/"
    ),
)
# One completion is requested per prompt, so keep the text of the first candidate.
all_outputs = [output.outputs[0].text for output in outputs]

In [None]:
# Display the first evaluated question as a quick sanity check of the data.
all_questions[0]

In [None]:
def get_final_output(response):
    """Extract the integer answer from the last ``\\boxed{...}`` in a response.

    The prompt asks the model to put its final numerical answer inside
    ``\\boxed{}``. A step-by-step solution may box intermediate values too, so
    the last occurrence is treated as the final answer.

    Args:
        response: Generated text. Any non-string value yields ``None``.

    Returns:
        The boxed value as an ``int``, or ``None`` when there is no box or its
        content is not an integer. Thousands separators (``1,000``), dollar
        signs (``$42`` / ``\\$42``) and integral decimals (``18.0``) are
        accepted; genuinely fractional values such as ``2.5`` are rejected.
    """
    if not isinstance(response, str):
        return None

    _, marker, tail = response.rpartition("\\boxed{")
    if not marker:
        return None

    # Everything up to the first closing brace; a truncated, unclosed box still
    # yields its content.
    candidate = tail.split("}", 1)[0]
    for noise in (",", "\\$", "$"):
        candidate = candidate.replace(noise, "")
    candidate = candidate.strip()

    try:
        return int(candidate)
    except ValueError:
        pass

    # Fall back to float parsing so that "18.0" is recognised as the integer 18.
    try:
        as_float = float(candidate)
    except ValueError:
        return None
    return int(as_float) if as_float.is_integer() else None

In [None]:
# Score each generated answer against its GSM8K reference.
# Results: `accs` holds one correctness flag per evaluated question and
# `num_error` counts questions that could not be scored (no integer extracted
# from the response, or an unparseable reference); those are marked incorrect.
final_outputs = [get_final_output(response) for response in all_outputs]
accs = []
num_error = 0
# zip stops at the shorter sequence, so only the generated subset is scored.
for pred, gt in zip(final_outputs, test):
    try:
        # The reference solution ends with "#### <number>"; drop thousands
        # separators so a value like "1,000" is not mistaken for a failure.
        target = int(gt["answer"].split("####")[1].replace(",", "").strip())
    except (IndexError, ValueError):
        target = None

    if pred is None or target is None:
        num_error += 1
        accs.append(False)
        continue

    accs.append(pred == target)

In [None]:
# Displays (accuracy, error rate): the mean of the per-question correctness
# flags, and the share of questions that could not be scored.
# NOTE(review): `np` is never imported explicitly in this notebook; presumably
# it is provided by one of the star imports — confirm.
np.mean(accs), num_error/len(final_outputs)

In [None]:
# Print the raw generation for the first question to inspect its format.
print(all_outputs[0])