In [None]:
import os
# Credential for the aiXplain SDK. The string below is only a placeholder:
# export TEAM_API_KEY in the environment (preferred) or replace the fallback
# before running. setdefault keeps a key that is already set instead of
# overwriting it with the placeholder.
os.environ.setdefault("TEAM_API_KEY", "TEAM_API_KEY")

In [None]:
import re
import json
import time
from datasets import load_dataset
from aixplain.factories import ModelFactory, AgentFactory

In [None]:
# --- Run configuration ---
MODEL_ID = "669a63646eb56306647e1091"  # used for ModelFactory.get and as the agent's llm_id
NUM_SAMPLES = 1319  # how many leading rows of the test split are evaluated
SAVE_DIR = "results"  # one JSON result file per sample is written here

# Create the output directory up front; an already-existing directory is fine.
os.makedirs(name=SAVE_DIR, exist_ok=True)


In [None]:
# One numeric token: an optional minus sign (only when it is not glued to a
# preceding word character or ")", so "3-1" reads as 3 and 1), an integer part
# with optional thousands separators, and an optional decimal fraction.
_NUMBER_PATTERN = re.compile(
    r"(?:(?<![\w)])-)?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"
)


def extract_final_number(text):
    """Return the last number that appears in ``text``.

    Numbers are parsed as written, so the sign, thousands separators and the
    decimal point are respected: "$18.00" yields 18 (not 1800), "-5" yields -5
    and "1,000" yields 1000.

    Args:
        text: Free-form text, e.g. an agent's final answer.

    Returns:
        ``None`` when ``text`` is empty, is not a string, or contains no
        number. Otherwise the last number found: an ``int`` when the value is
        integral (a fraction of only zeros is dropped), else a ``float``.
    """
    if not text or not isinstance(text, str):
        return None

    tokens = _NUMBER_PATTERN.findall(text)
    if not tokens:
        return None

    # Thousands separators are formatting only; drop them before converting.
    token = tokens[-1].replace(",", "")
    whole, _, fraction = token.partition(".")

    # Keep integral values as exact ints (avoids float rounding on large
    # numbers); only genuinely fractional values become floats.
    if fraction and int(fraction) != 0:
        return float(token)
    return int(whole)

def extract_ground_truth(answer_text):
    """Extract the reference answer from a GSM8K-style solution string.

    The answer is expected after a ``####`` marker. Thousands separators are
    removed and trailing fractional zeros are trimmed, so "#### 1,000" gives
    "1000" and "#### 18.00" gives "18".

    Args:
        answer_text: The solution text of one sample.

    Returns:
        The number after ``####`` as a string when the marker is present.
        Without a marker, whatever ``extract_final_number`` finds in the text.
        ``None`` when ``answer_text`` is not a string.
    """
    if not isinstance(answer_text, str):
        return None

    match = re.search(r"####\s*(-?\d[\d,]*(?:\.\d+)?)", answer_text)
    if match is None:
        # No explicit marker: fall back to the last number in the text.
        return extract_final_number(answer_text)

    token = match.group(1).replace(",", "")
    whole, _, fraction = token.partition(".")
    fraction = fraction.rstrip("0")
    return f"{whole}.{fraction}" if fraction else whole

In [None]:
# === LOAD DATA ===
# Load the "main" configuration's test split of openai/gsm8k, then keep the
# first NUM_SAMPLES rows in their original order.
dataset = load_dataset(path="openai/gsm8k", name="main", split="test")
samples = dataset.select(indices=range(NUM_SAMPLES))

In [None]:
# === LOAD MODEL + TOOLS ===
# MODEL_ID is used twice: for the model wrapped as a tool and for the agent's
# own LLM (llm_id below).
model = ModelFactory.get(MODEL_ID)

# Tool 1: the model itself, exposed to the agent as a tool.
model_tool = AgentFactory.create_model_tool(
    description="You MUST use this tool to check the answer after the code tool.",
    model=model,
)

# Tool 2: the Python interpreter tool.
code_tool = AgentFactory.create_python_interpreter_tool()

# The description and instructions are passed to the agent as-is, so the text
# is kept verbatim (including the missing space in "integer.Make").
agent = AgentFactory.create(
    name="SingleAgent",
    llm_id=MODEL_ID,
    tools=[model_tool, code_tool],
    description=(
        "Math agent for GSM8K, use the code tool initially, "
        "and use the llm tool to revise the answer."
    ),
    instructions=(
        "You are an expert AI in solving complex math problems. "
        "Solve math problems and return only an integer."
        "Make sure to use the python tool, when you use the python tool, "
        "the last line in the code MUST be a print() statement of the final answer."
    ),
)

In [None]:
# === TRACKING ===
# Running totals that the evaluation loop adds to.
correct = 0
total_cost = total_time = 0.0

In [None]:
# Prompt sent to the agent for every sample; `{question}` is filled in with
# str.format. Everything inside the literal is runtime text and reaches the
# agent verbatim: the leading newline, the indentation of each line, the
# explicit "\n" before "Question:", and the "sovle" misspelling.
PROMPT_TEMPLATE = """
    Return the answer as [ unit, final integer]. Example: 3 apples minus 1 = 2 apples.\nQuestion: {question}
        Make sure to use the python tool, when you use the python tool, the last line in the code MUST be a print() statement of the final answer.
        You must output whatever final answer the tool gives, even if it is an error, garbage, or incomplete. DO NOT return code, only the final answer.
        DO NOT use your internal knowledge to sovle the question, ONLY depend on tool output.
    """

In [None]:
# === PROCESS EACH SAMPLE ===
# Reset the running totals here so that re-running this cell starts from zero
# instead of adding to totals left over from a previous run.
total_cost = 0.0
total_time = 0.0
correct = 0

for idx, sample in enumerate(samples):
    question = sample["question"]
    gt_raw = sample["answer"]
    gt_answer = extract_ground_truth(gt_raw)
    query = PROMPT_TEMPLATE.format(question=question)

    # perf_counter is a monotonic clock meant for measuring elapsed intervals;
    # it is not affected by system clock adjustments.
    start_time = time.perf_counter()
    response = agent.run(query=query)
    elapsed = time.perf_counter() - start_time

    # Read everything from the response through the same guard, so a missing
    # response (or missing data) yields placeholders instead of an
    # AttributeError further down.
    data = response.data if response else None
    output = data.output if data is not None else "No response"
    raw_steps = getattr(data, "intermediate_steps", None) or []

    pred_answer = extract_final_number(output)
    # Compare as strings because the two extractors may return different types.
    # A missing prediction is never correct, even if the ground truth could not
    # be extracted either (str(None) == str(None) would otherwise match).
    is_correct = pred_answer is not None and str(pred_answer) == str(gt_answer)

    # Treat a missing or None credit value as zero so the running sum stays numeric.
    cost = getattr(response, "used_credits", 0) or 0
    total_cost += cost
    total_time += elapsed
    correct += int(is_correct)

    processed = idx + 1
    accuracy_so_far = correct / processed * 100
    avg_cost = total_cost / processed
    avg_time = total_time / processed

    result_data = {
        "index": idx,
        "question": question,
        "ground_truth": gt_answer,
        "prediction": pred_answer,
        "is_correct": is_correct,
        "intermediate_steps": [str(step) for step in raw_steps],
        "output": output,
        "time": elapsed,
        "cost": cost,
        "accuracy_so_far": accuracy_so_far,
        "avg_cost": avg_cost,
        "avg_time": avg_time
    }

    # One file per sample (1-based name). default=str keeps the dump from
    # failing if the agent output is not a JSON-native type.
    result_path = os.path.join(SAVE_DIR, f"sample_{idx + 1}.json")
    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(result_data, f, indent=2, default=str)

In [None]:
# === FINAL SUMMARY ===
# Accuracy is the share of the NUM_SAMPLES requested samples counted as
# correct. The division is guarded so NUM_SAMPLES = 0 reports 0.00% instead of
# raising ZeroDivisionError.
final_accuracy = correct / NUM_SAMPLES * 100 if NUM_SAMPLES else 0.0
print(f"\nFinal Accuracy on {NUM_SAMPLES} GSM8K samples: {final_accuracy:.2f}%")