In [1]:
from typing import Union

from datasets import Dataset, load_dataset
from peft import LoraConfig  # type: ignore
from pydantic import BaseModel, Field

import xverify as xv

In [2]:
class Tools(BaseModel):
    """
    Run a tool.
    """

    # NOTE: the class docstring and the Field description are runtime text, not
    # just documentation -- pydantic exposes them in the model's schema, which
    # presumably feeds `guided_schema.doc` and therefore SYSTEM_PROMPT. Treat any
    # edit to them as a prompt change.
    # The tool call emitted by the model in XML form; the type parameters
    # presumably restrict the callable tools to xverify's calculator and search.
    tool_use: xv.XMLToolUse[xv.calculator, xv.search] = Field(
        ..., description="The tool call to use"
    )


class FinalAnswer(BaseModel):
    """
    Return a final answer.
    """

    # NOTE: docstring and description are part of the generated schema/prompt.
    # Typed as int: exact_answer_reward_func compares str(answer) against the
    # gold GSM8K answer string, so non-integer answers cannot be expressed here.
    answer: int = Field(..., description="Final answer to the question")


class Reason_and_Act(BaseModel):
    # One structured agent step: free-form notes, reasoning, then an action.
    # NOTE: field order and descriptions are part of the generated schema that is
    # rendered into the prompt, so changes here change what the model is told.
    scratchpad: str = Field(
        ...,
        description="Information from the Observation useful to answer the question",
    )
    reasoning: str = Field(
        ...,
        description="It describes your thoughts about the question you have been asked",
    )
    # Either a tool call (Tools) or the final answer (FinalAnswer); only a
    # FinalAnswer in the last assistant message is scored by
    # exact_answer_reward_func.
    # NOTE(review): unlike the fields above, `response` carries no Field
    # description. A previous variant was
    # `Union[Tools, FinalAnswer] = Field(..., description="Current step response")`;
    # confirm with xverify whether a described union is supported before restoring it.
    response: Tools | FinalAnswer


def tool_response_func(model: Reason_and_Act) -> dict | None:
    """Hand the parsed step's action to xverify's tool runner.

    ``model.response`` is either a ``Tools`` or a ``FinalAnswer`` instance;
    whatever ``xv.run_tools`` returns for it is passed straight back.
    """
    step_action = model.response
    return xv.run_tools(step_action)


# Guided XML schema for one Reason_and_Act step; tool_response_func is registered
# as the hook that turns a parsed step into a tool result.
guided_schema = xv.GuidedSchema(
    Reason_and_Act,
    tool_response_func=tool_response_func,
    schema="xml",
)

In [3]:
def format_prompt(
    prompt: str, system_prompt: str | None = None
) -> list[dict[str, str]]:
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages


def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Step budget advertised to the model in the system prompt below.
# NOTE(review): in this notebook the value is only interpolated into
# SYSTEM_PROMPT. It is not passed to the trainer, and it is unrelated to the
# GRPO `max_steps=250` training argument. Confirm the rollout loop enforces the
# same limit.
max_steps = 10

# System prompt: behavioural instructions followed by the rendered schema
# documentation for Reason_and_Act, so the model knows the exact XML to emit.
SYSTEM_PROMPT = f"""\
You are a helpful assistant, responding in XML structured output.

- Think step by step using the scratchpad and reasoning outputs. You have {max_steps - 1} steps to think before responding.
- Use the tools provided. DO NOT rely on your own knowledge when a tool is available to help you.
- Respond with a final answer only once you are absolutely sure you have the answer.

Respond with an XML object, following the schema below:

{guided_schema.doc}
"""

# Load the GSM8K training split and convert each row into:
#   - "prompt": system + user chat messages (via format_prompt),
#   - "answer": the gold final answer following "####" (via extract_hash_answer).
dataset: Dataset = load_dataset("openai/gsm8k", "main", split="train")  # type: ignore
dataset = dataset.map(
    lambda row: {
        # Reuse the helper instead of duplicating its message-building logic;
        # SYSTEM_PROMPT is non-empty, so the system message is always included.
        "prompt": format_prompt(row["question"], system_prompt=SYSTEM_PROMPT),
        "answer": extract_hash_answer(row["answer"]),
    }
)

In [None]:
def exact_answer_reward_func(completions, answer, **_) -> list[float]:
    """Reward function that checks if the final answer matches the expected answer."""

    def _check_answer(trajectory: list[dict[str, str]], answer: str) -> float:
        """Extract the last answer from a trajectory."""
        last_message = trajectory[-1]
        assert last_message["role"] == "assistant", "should be assistant"
        parsed: Reason_and_Act | None = guided_schema.parse(last_message["content"])  # type: ignore
        if parsed is None or not isinstance(parsed.response, FinalAnswer):
            return 0.0
        return 1.0 if str(parsed.response.answer) == answer else 0.0

    return [_check_answer(c, a) for c, a in zip(completions, answer)]


model_name = "Qwen/Qwen2.5-3B-Instruct"
model, tokenizer = xv.get_model_and_tokenizer(model_name)


run_name = "gsm8k-calculator-peft_" + model_name.split("/")[-1].lower()
training_args = xv.get_default_grpo_config(
    run_name,
    num_gpus=1,
    ### GRPO params ###
    learning_rate=5e-6,
    num_generations=6,
    per_device_train_batch_size=6,  # 1 prompts per batch
    gradient_accumulation_steps=1,  # 1 prompt per forward
    logging_steps=1,
    num_iterations=2,  # 1 on-policy + 1 off-policy
    max_steps=250,
    max_prompt_length=200,
    max_completion_length=512,
    vllm_gpu_memory_utilization=0.3,
    optim="paged_adamw_8bit",
    # checkpointing
    save_steps=50,
    save_only_model=False,
    save_total_limit=1,
)

# LoRA adapters on every attention and MLP projection of the decoder blocks.
# Effective LoRA scaling is lora_alpha / r = 64 / 16 = 4.
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0.05,
    # Declare the task so PEFT wraps the model as a causal LM
    # (PeftModelForCausalLM) and records it in the saved adapter config.
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)


# Build the guided GRPO trainer (model + LoRA config + exact-match reward over
# schema-parsed rollouts) and start training.
trainer = xv.GRPOGuidedTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,
    train_dataset=dataset,
    reward_funcs=[exact_answer_reward_func],
    guided_schema=guided_schema,
)
trainer.train()