In [1]:
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
from tqdm import tqdm

from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/10623Project

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/10623Project


In [3]:

# Load the model and tokenizer
model_name = "google/gemma-7b"  # Instruction-tuned Gemma
hf_token = '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'
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.bfloat16,
                                             device_map="cuda:0",
                                             token=hf_token)
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3072, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear(in_features=24576, out_features=3072, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm((3072,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((3072,), eps=1e-06)
      )
    )
    (norm): GemmaRMSNorm((3072,)

In [6]:
# Load test data
with open("gsm8k_reasoning_test.jsonl", "r") as f:  # change to your actual file path
    test_data = [json.loads(line) for line in f]

# Step 1: Extract the first answer block (robust to repeated Q/A)
def extract_first_answer_block(text):
    parts = text.split("A:", 1)
    if len(parts) < 2:
        return text.strip()
    answer_section = parts[1].strip()
    stop_index = answer_section.find("Q:")
    if stop_index != -1:
        answer_section = answer_section[:stop_index].strip()
    return answer_section

# Step 2: Extract final answer
def extract_final_answer(text):
    match = re.search(r"####\s*([-\d.,]+)", text)
    if match:
        return match.group(1).replace(",", "")
    numbers = re.findall(r"[-+]?\d*\.?\d+", text)
    return numbers[-1] if numbers else None

# Step 3: Normalize answers
def normalize_answer(ans):
    if ans is None:
        return None
    ans = ans.replace(",", "").replace("$", "").strip()
    try:
        return str(int(float(ans)))
    except:
        return ans

# Inference in batches
batch_size = 16
results = []

for i in tqdm(range(0, len(test_data), batch_size)):
# for i in tqdm(range(0, 4, batch_size)):
    batch = test_data[i:i+batch_size]
    # prompts = []
    # for item in batch:
    #     question = item["prompt"].replace("Q:", "").replace("A:", "").strip()
    #     new_prompt = f"Q: {question}\n\nWrite the final answer like this: #### [answer].\n\nA:"
    #     prompts.append(new_prompt)

    prompts = [item["prompt"] for item in batch]
    # print("Prompts: ", prompts)
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            temperature=0.0,
            pad_token_id=tokenizer.eos_token_id
        )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    for j, output in enumerate(decoded):
        prediction = extract_first_answer_block(output)
        pred_ans = extract_final_answer(prediction)
        true_ans = extract_final_answer(batch[j]["completion"])
        correct = normalize_answer(pred_ans) == normalize_answer(true_ans)
        # print(f"Question: {prompts[j]}")
        # print(f"Output: {output}")
        # print(f"Prediction: {prediction}")
        # print(f"True Answer: {true_ans}")
        # print(f"Predicted Answer: {pred_ans}")

        results.append({
            "question": batch[j]["prompt"],
            "prediction": prediction,
            "pred_ans": pred_ans,
            "true_ans": true_ans,
            "correct": correct
        })

# Accuracy report
accuracy = sum(r["correct"] for r in results) / len(results)
print(f"\n✅ Baseline Accuracy: {accuracy:.2%}")

100%|██████████| 83/83 [13:00<00:00,  9.41s/it]


✅ Baseline Accuracy: 16.53%



