In [None]:
import json

from dotenv import load_dotenv

import art
from art.local import LocalBackend

# Run python-dotenv's loader up front, before any backend or model is created
# in the cells below (presumably picks up settings from a local .env file).
load_dotenv()

In [None]:
async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
    """Sample one reply to ``prompt`` from ``model`` and score it.

    The prompt is sent as a single user message through the client returned
    by ``model.openai_client()`` (``max_tokens=100``, ``timeout=100``, and the
    chat template's ``enable_thinking`` flag set to ``False``). The reply text
    is printed and then rewarded by exact string match:

    * ``"maybe"``   -> 1.0
    * ``"no"``      -> 0.75
    * ``"yes"``     -> 0.5
    * anything else -> 0.0

    Args:
        model: Model whose OpenAI-style client and ``name`` are used for the
            completion request.
        prompt: Text placed in the single user message.

    Returns:
        An ``art.Trajectory`` whose ``messages_and_choices`` is the user
        message followed by the first returned choice, with the reward above.

    Raises:
        AssertionError: If the reply content is not a string.
    """
    # Exact-match scoring table; any reply not listed here scores 0.0.
    reward_by_answer = {"yes": 0.5, "no": 0.75, "maybe": 1.0}
    conversation: art.Messages = [{"role": "user", "content": prompt}]

    response = await model.openai_client().chat.completions.create(
        messages=conversation,
        model=model.name,
        max_tokens=100,
        timeout=100,
        extra_body={"chat_template_kwargs": {"enable_thinking": False}},
    )
    first_choice = response.choices[0]
    reply = first_choice.message.content
    print(reply)
    assert isinstance(reply, str)

    return art.Trajectory(
        messages_and_choices=[*conversation, first_choice],
        reward=reward_by_answer.get(reply, 0.0),
    )

In [None]:
# Load the prompts from prompts.json in the working directory.
# The encoding is pinned to UTF-8 so that non-ASCII prompts decode the same
# way on every platform instead of depending on the locale's default encoding.
with open("prompts.json", "r", encoding="utf-8") as prompts_file:
    prompts = json.load(prompts_file)
print(prompts)

In [None]:
# Backend instance that the models below are registered with.
backend = LocalBackend()

In [None]:
# Trainable model "004" in project "yes-no-maybe-s", built on
# Qwen/Qwen2.5-0.5B-Instruct and registered with `backend`.
qwen2 = art.TrainableModel(
    name="004",
    project="yes-no-maybe-s",
    base_model="Qwen/Qwen2.5-0.5B-Instruct",
)
await qwen2.register(backend)

In [None]:
# Single rollout with qwen2 on the prompt at index 4 (needs at least five
# prompts); the returned trajectory is shown as the cell output.
await rollout(qwen2, prompts[4])

In [None]:
# Trainable model "005" in the same project, built on Qwen/Qwen3-0.6B and
# registered with `backend`. This is the model trained in the loop below.
qwen3 = art.TrainableModel(
    name="005",
    project="yes-no-maybe-s",
    base_model="Qwen/Qwen3-0.6B",
    # Alternative base model (the one used for `qwen2`):
    # base_model="Qwen/Qwen2.5-0.5B-Instruct",
)
await qwen3.register(backend)

In [None]:
# Single rollout with qwen3 on the prompt at index 4 (needs at least five
# prompts); the returned trajectory is shown as the cell output.
await rollout(qwen3, prompts[4])

In [None]:
for _ in range(await qwen3.get_step(), 1_000):
    train_groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(rollout(qwen3, prompt) for _ in range(32))
            for prompt in prompts
        ),
        pbar_desc="gather",
    )
    await qwen3.train(
        train_groups,
        config=art.TrainConfig(learning_rate=1e-4),
    )