# In [2]:
import art
import asyncio
import json
import random
import re
from typing import TypedDict


# Schema of one puzzle record as stored in the puzzles JSON file, declared
# with the functional TypedDict syntax.
TemporalCluePuzzle = TypedDict(
    "TemporalCluePuzzle",
    {
        "num_clues": int,
        # Text sent to the model as the single user message.
        "prompt": str,
        # Maps each answer key to the expected answer text.
        "solution": dict[str, str],
    },
)


puzzles: list[TemporalCluePuzzle] = json.load(open("./data/temporal-clue/puzzles.json"))
val_puzzles = puzzles[:64]
test_puzzles = puzzles[64:128]
train_puzzles = puzzles[128:]
random.seed(42)
random.shuffle(train_puzzles)


model = await art.get_or_create_model(
    name="temporal-clue", base_model="Qwen/Qwen2.5-7B-Instruct"
)


async def rollout(puzzle: TemporalCluePuzzle) -> art.Trajectory:
    """Sample one completion for ``puzzle`` and score it against the solution.

    The puzzle prompt is sent as a single user message. For every key in
    ``puzzle["solution"]`` the completion text is searched for
    ``"<key>. <answer>"``; when that pattern occurs more than once, the last
    occurrence is taken as the model's final answer and compared
    case-insensitively (after stripping whitespace) with the expected value.

    Args:
        puzzle: The puzzle to attempt.

    Returns:
        A trajectory containing the prompt message plus the sampled choice,
        with the fraction of solution keys answered correctly as its reward
        (0.0 when the puzzle has no solution entries).

    Raises:
        AssertionError: If the completion has no string content.
    """
    messages: art.Messages = [{"role": "user", "content": puzzle["prompt"]}]
    chat_completion = await art.client.chat.completions.create(
        messages=messages,
        model=model,
    )
    choice = chat_completion.choices[0]
    messages.append(choice)
    content = choice.message.content
    assert isinstance(content, str)

    solution = puzzle["solution"]
    num_correct = 0
    for key, value in solution.items():
        # Escape the key so any regex metacharacter in it is matched
        # literally instead of being interpreted as pattern syntax.
        pattern = rf"{re.escape(key)}\. ([A-Za-z \.:-]+)"
        if matches := re.findall(pattern, content):
            # Use the last occurrence: earlier ones may be intermediate
            # guesses made while reasoning.
            match = matches[-1]
            if match.strip().lower() == value.lower():
                num_correct += 1

    # Guard against an empty solution, which would otherwise divide by zero.
    accuracy = num_correct / len(solution) if solution else 0.0
    return art.Trajectory(messages=messages, reward=accuracy)


stride = 32
for i in range(model.iteration, 1_000):
    val_groups, train_groups = await asyncio.gather(
        asyncio.gather(
            *(
                asyncio.gather(*(rollout(puzzle) for _ in range(2)))
                for puzzle in val_puzzles
            )
        ),
        asyncio.gather(
            *(
                asyncio.gather(*(rollout(puzzle) for _ in range(64)))
                for puzzle in train_puzzles[i * stride : (i + 1) * stride]
            )
        ),
    )
    _, _ = await asyncio.gather(
        model.save_eval(val_groups),
        model.tune(train_groups),
    )

'Hello from agent-reinforcement-training!'