In [8]:
import json, random, os
import torch
import pandas as pd
from tqdm import tqdm
import transformers

from constants import *
from power_samp_utils import format_prompt, AutoregressiveSampler, mcmc_power_samp

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
cache_dir = "/home/tim/xai-hackathon-12-06-25/assets/huggingface"

device = "cuda" if torch.cuda.is_available() else "cpu"
model_str = "Qwen/Qwen2.5-7B"

json_file = "/home/tim/xai-hackathon-12-06-25/assets/data/MATH500.json"

with open(json_file) as f:
    dataset = json.load(f)

tokenizer = transformers.AutoTokenizer.from_pretrained(model_str, cache_dir=cache_dir, trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    model_str, dtype="auto", device_map="auto", trust_remote_code=True
)
autoreg_sampler = AutoregressiveSampler(hf_model, tokenizer, device)

def run_qwen(dataset, start=0, end=100, cot=True, temp=0.25, mcmc_steps=10, max_new_tokens=128, seed=0):
    random.seed(seed)
    results = []
    for _, data in tqdm(list(enumerate(dataset[start:end])), desc="MATH"):
        res = {}
        question = data["prompt"]
        answer = data["answer"]
        res["question"] = question
        res["answer"] = answer

        input_text = format_prompt(question, "qwen", tokenizer, cot)
        inputs = tokenizer(input_text, return_tensors="pt", padding=False, truncation=False)
        inputs = {k: v.to(hf_model.device) for k, v in inputs.items()}
        mcmc_ids_, _, _, acc = mcmc_power_samp(
            autoreg_sampler, inputs["input_ids"][0].tolist(), temp, mcmc_steps, max_new_tokens=max_new_tokens
        )
        mcmc_ids = torch.tensor(mcmc_ids_, dtype=torch.long, device="cpu")
        mcmc_completion = tokenizer.decode(mcmc_ids, skip_special_tokens=True)
        res["mcmc_completion"] = mcmc_completion
        results.append(res)
        return res

    return pd.DataFrame(results)


res = run_qwen(dataset, start=0, end=1, cot=True, temp=0.25, max_new_tokens=1<<8, mcmc_steps=1)

print(res)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

MATH:   0%|          | 0/1 [00:00<?, ?it/s]

alpha: 4.0
max new tokens: 256, jump size: 16



100%|██████████| 1/1 [00:00<00:00, 16.24it/s]

100%|██████████| 1/1 [00:00<00:00, 10.88it/s]

[A
100%|██████████| 1/1 [00:00<00:00,  1.65it/s]

[A
100%|██████████| 1/1 [00:00<00:00,  2.44it/s]

[A
100%|██████████| 1/1 [00:00<00:00,  4.85it/s]

[A
100%|██████████| 1/1 [00:00<00:00,  2.22it/s]

[A
100%|██████████| 1/1 [00:01<00:00,  1.37s/it]

[A
100%|██████████| 1/1 [00:00<00:00,  1.29it/s]

[A
100%|██████████| 1/1 [00:00<00:00,  3.40it/s]

[A
100%|██████████| 1/1 [00:00<00:00,  1.10it/s]

[A
100%|██████████| 1/1 [00:00<00:00,  2.76it/s]

[A
100%|██████████| 1/1 [00:01<00:00,  1.79s/it]

[A
100%|██████████| 1/1 [00:01<00:00,  1.03s/it]

[A
100%|██████████| 1/1 [00:02<00:00,  2.43s/it]

[A
100%|██████████| 1/1 [00:02<00:00,  2.06s/it]
 88%|████████▊ | 14/16 [00:16<00:02,  1.17s/it]
MATH:   0%|          | 0/1 [00:16<?, ?it/s]

{'question': 'Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$', 'answer': '\\left( 3, \\frac{\\pi}{2} \\right)', 'mcmc_completion': 'Can you solve the following math problem? Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$ Please reason step by step, and put your final answer within \\boxed{{}}. To convert the point $(0,3)$ from rectangular coordinates to polar coordinates, we need to find the values of $r$ and $\\theta$ that satisfy the equations $x = r \\cos \\theta$ and $y = r \\sin \\theta$.\nIn this case, $x = 0$ and $y = 3$, so we have $0 = r \\cos \\theta$ and $3 = r \\sin \\theta$.\nFrom the first equation, we can see that $r = 0$ or $\\cos \\theta = 0$.\nSince $r > 0$, we must have $\\cos \\theta = 0$, which means $\\theta = \\frac{\\pi}{2}$.\nS


