In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%html
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import openai
from dotenv import load_dotenv

import art
from art.local import LocalBackend

# Load settings from a local .env file into the process environment
# (python-dotenv). Presumably this supplies credentials or paths used by the
# calls below — confirm against the project's setup notes.
load_dotenv()

# Backend object that the model is registered with at the end of this cell.
backend = LocalBackend()
# Handle for the model being trained: `name` and `project` label this run,
# `base_model` names the checkpoint it starts from.
model = art.TrainableModel(
    name="009",
    project="yes-no-maybe",
    base_model="Qwen/Qwen2.5-7B-Instruct",
    # Optional internal overrides, currently disabled:
    # _internal_config=art.dev.InternalModelConfig(
    #     _decouple_vllm_and_unsloth=True,
    #     engine_args=art.dev.EngineArgs(gpu_memory_utilization=0.7),
    # ),
)
# Top-level `await`: this only works in a notebook / async REPL context.
await model.register(backend)


async def rollout(client: openai.AsyncOpenAI, prompt: str) -> art.Trajectory:
    """Sample one completion for ``prompt`` and score it.

    The prompt is sent as a single user message to the module-level
    ``model`` (by ``model.name``) through ``client``, with ``max_tokens=100``
    and ``timeout=100``. Only the first returned choice is scored.

    Scoring is by exact string match on the reply text:
    ``"yes"`` -> 0.5, ``"no"`` -> 0.75, ``"maybe"`` -> 1.0, anything else
    (including extra whitespace, punctuation or different casing) -> 0.0.

    Args:
        client: Async OpenAI-compatible client used to request the completion.
        prompt: Text placed in the user message.

    Returns:
        An ``art.Trajectory`` holding the user message followed by the
        sampled choice, together with the reward.

    Raises:
        AssertionError: If the reply's message content is not a string.
    """
    reward_by_answer = {"yes": 0.5, "no": 0.75, "maybe": 1.0}

    messages: art.Messages = [{"role": "user", "content": prompt}]
    response = await client.chat.completions.create(
        messages=messages, model=model.name, max_tokens=100, timeout=100
    )

    choice = response.choices[0]
    answer = choice.message.content
    assert isinstance(answer, str)

    # Unlisted replies fall through to a reward of zero.
    return art.Trajectory(
        messages_and_choices=[*messages, choice],
        reward=reward_by_answer.get(answer, 0.0),
    )


def with_quotes(w):
    """Return ``w`` formatted as text and wrapped in single quotes."""
    return "'{}'".format(w)


# Answer orderings shown to the model. Three-word entries are rendered as a
# comma-separated list, shorter ones as "<first> or <second>".
_WORD_ORDERS = (
    ("yes", "no", "maybe"),
    ("maybe", "yes", "no"),
    ("no", "yes", "maybe"),
    ("yes", "maybe", "no"),
    ("yes", "no"),
    ("maybe", "no"),
    ("no", "maybe"),
    ("no", "yes"),
    # NOTE(review): ("yes", "no") is listed a second time here, so its prompts
    # are generated twice — confirm whether that weighting is intended.
    ("yes", "no"),
)


def _render_choices(words, use_quotes):
    """Render the answer options for one prompt.

    Exactly three words are joined with ", ", each passed through
    ``with_quotes`` when ``use_quotes`` is true. Any other length yields the
    first word, followed by " or <second word>" when a second word exists;
    ``use_quotes`` has no effect on that form.
    """
    if len(words) == 3:
        shown = [with_quotes(w) for w in words] if use_quotes else list(words)
        return ", ".join(shown)
    rendered = f"{words[0]}"
    if len(words) > 1:
        rendered += f" or {words[1]}"
    return rendered


# One prompt per (lead-in phrase, quoting flag, word ordering) combination,
# generated in that nesting order.
prompts = [
    f"{prefix} with {_render_choices(words, use_quotes)}"
    for prefix in ("respond", "just respond")
    for use_quotes in (True, False)
    for words in _WORD_ORDERS
]

openai_client = model.openai_client()
for _ in range(await model.get_step(), 1_000):
    train_groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(rollout(openai_client, prompt) for _ in range(32))
            for prompt in prompts
        ),
        pbar_desc="gather",
    )
    await model.train(
        train_groups,
        config=art.TrainConfig(learning_rate=1e-4),
        # _config=art.dev.TrainConfig(
        #     precalculate_logprobs=True,
        # ),
    )

Skipping tuning as there is no suitable data. This can happen when all the trajectories in the same group have the same reward and thus no advantage to train on.
Advanced step from 264 to 265 (no training occurred)


gather:   0%|          | 0/1152 [00:00<?, ?it/s]

Skipping tuning as there is no suitable data. This can happen when all the trajectories in the same group have the same reward and thus no advantage to train on.
Advanced step from 265 to 266 (no training occurred)


gather:   0%|          | 0/1152 [00:00<?, ?it/s]

Skipping tuning as there is no suitable data. This can happen when all the trajectories in the same group have the same reward and thus no advantage to train on.
Advanced step from 266 to 267 (no training occurred)


gather:   0%|          | 0/1152 [00:00<?, ?it/s]

Skipping tuning as there is no suitable data. This can happen when all the trajectories in the same group have the same reward and thus no advantage to train on.
Advanced step from 267 to 268 (no training occurred)


gather:   0%|          | 0/1152 [00:00<?, ?it/s]