In [1]:
# Automatically reload imported modules before each cell runs, so edits to
# library code (e.g. `art`) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<style>
/* Render ipywidget outputs on a transparent background and take the widget
   text color / font size from the VS Code editor theme variables. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import art
import asyncio
from dotenv import load_dotenv
import json
import openai
import random
import re
from typing import TypedDict

# Load environment variables from a local .env file (presumably credentials /
# settings consumed by `art`, `openai` or W&B — confirm).
load_dotenv()


class TemporalCluePuzzle(TypedDict, total=True):
    """A single temporal-clue puzzle record as loaded from ``puzzles.json``.

    ``prompt`` is sent verbatim as the user message, and ``solution`` maps each
    answer key to the expected answer that ``rollout`` grades the model's
    response against.
    """

    num_clues: int  # number of clues in the puzzle (per the field name)
    prompt: str  # full puzzle text shown to the model
    solution: dict[str, str]  # answer key -> expected answer


# Load every puzzle and split into fixed validation / test / train partitions.
# The `with` block closes the file handle as soon as parsing is done; JSON is
# read as UTF-8 rather than with the platform's locale-dependent default.
with open("./data/temporal-clue/puzzles.json", encoding="utf-8") as puzzles_file:
    puzzles: list[TemporalCluePuzzle] = json.load(puzzles_file)
val_puzzles = puzzles[:64]  # first 64 puzzles
test_puzzles = puzzles[64:128]  # next 64 puzzles
train_puzzles = puzzles[128:]  # all remaining puzzles
# Slicing produced a new list, so shuffling reorders only the training split
# (`puzzles` keeps its file order). Note this seeds the global `random` state.
random.seed(42)
random.shuffle(train_puzzles)


# Connect to the ART Unsloth backend (presumably logging runs to the given W&B
# project) and get-or-create the trainable model by name from the base model.
api = art.UnslothAPI(wandb_project="agent-reinforcement-training")
model = await api.get_or_create_model(
    name="temporal-clue-unsloth-001",
    base_model="unsloth/Qwen2.5-14B-Instruct",
)


async def rollout(
    client: openai.AsyncOpenAI, puzzle: TemporalCluePuzzle
) -> art.Trajectory:
    """Sample one completion for ``puzzle`` and score it against its solution.

    The puzzle prompt is sent as a single user message to the notebook-level
    ``model`` (by ``model.name``). For every ``key`` in ``puzzle["solution"]``
    the response is searched for ``"<key>. <answer>"``; the last such
    occurrence is stripped and compared case-insensitively with the expected
    value.

    Args:
        client: Async OpenAI-compatible client used for the chat completion.
        puzzle: The puzzle to solve.

    Returns:
        An ``art.Trajectory`` holding the prompt message and the sampled choice,
        with ``reward`` and ``metrics["acc"]`` both set to the fraction of
        solution entries answered correctly (0.0 if the solution is empty).
    """
    messages: art.Messages = [{"role": "user", "content": puzzle["prompt"]}]
    chat_completion = await client.chat.completions.create(
        messages=messages, model=model.name
    )
    choice = chat_completion.choices[0]
    content = choice.message.content
    assert isinstance(content, str)
    solution = puzzle["solution"]
    num_correct = 0
    for key, value in solution.items():
        # Escape the key so characters like "." or "(" are matched literally
        # instead of being interpreted as regex syntax.
        # NOTE(review): the answer character class has no digits, so answers
        # containing numbers can never match — confirm against the dataset.
        if matches := re.findall(rf"{re.escape(key)}\. ([A-Za-z \.:-]+)", content):
            # The last occurrence is treated as the model's final answer.
            if matches[-1].strip().lower() == value.lower():
                num_correct += 1
    # Guard against an empty solution, which would otherwise divide by zero.
    reward = acc = num_correct / len(solution) if solution else 0.0
    return art.Trajectory(
        messages_and_choices=[*messages, choice], reward=reward, metrics={"acc": acc}
    )


stride = 1
for i in range(await model.get_iteration(), 1_000):
    async with model.openai_client(
        estimated_completion_tokens=900, verbosity=2
    ) as openai_client:
        val_groups, train_groups = await asyncio.gather(
            art.gather_groups(
                (
                    (rollout(openai_client, puzzle) for _ in range(2))
                    for puzzle in val_puzzles[:32]
                ),
                pbar_desc="val",
                stream_chat_completions=8,
            ),
            art.gather_groups(
                (
                    (rollout(openai_client, puzzle) for _ in range(50))
                    for puzzle in train_puzzles[i * stride : (i + 1) * stride]
                ),
                pbar_desc="train",
            ),
        )
        break
    await model.save(val_groups)
    await model.clear_iterations()
    await model.tune(
        train_groups, config=art.TuneConfig(plot_tensors=True, verbosity=2)
    )

In [None]:
# Debugging: fetch the backend trainer and pack the `train_groups` collected
# above by hand, rather than going through `model.tune`.
# NOTE(review): these are private `api._get_*` helpers called with positional
# magic arguments (8192, 2, True) whose meaning is not visible here — confirm
# against the `art` source before changing them.
trainer = api._get_trainer(model)
packed_tensors = api._get_packed_tensors(model, train_groups, 8192, 2, True)

In [None]:
from art.unsloth.grpo import GRPO
import os
import torch

# Build the GRPO loss and wrap its per-chunk forward pass with torch.compile.
# The compile backend defaults to "inductor" and can be overridden through the
# TORCH_COMPILE_BACKEND environment variable.
loss_fn = GRPO()
compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
loss_fn._forward_chunk = torch.compile(
    loss_fn._forward_chunk, backend=compile_backend
)

In [None]:
# Display element 1 of `_get_model_and_tokenizer`'s return value (presumably
# the tokenizer of a (model, tokenizer) pair — confirm).
api._get_model_and_tokenizer(model)[1]

In [None]:
from art.unsloth.pack import PackedTensors
from transformers import PreTrainedModel


def compute_loss(
    model: PreTrainedModel,
    inputs: PackedTensors,
    return_outputs: bool = False,
    num_items_in_batch: int | None = None,
) -> torch.Tensor:
    """Compute the GRPO loss for one packed batch (replacement for ``trainer.compute_loss``).

    Runs ``model`` over ``inputs["tokens"]`` under CUDA autocast with a grouped
    causal attention mask, then evaluates the notebook-level ``loss_fn`` on the
    resulting logits.

    Args:
        model: Model to run forward.
        inputs: Packed batch. Reads the ``tokens``, ``group_ids``,
            ``parent_ids``, ``advantages``, ``logprobs``, ``assistant_mask`` and
            ``weights`` entries; ``tokens`` is 2-D ``(batch_size, seq_len)``.
        return_outputs: Not used (presumably kept to match the trainer's
            ``compute_loss`` signature).
        num_items_in_batch: Not used (same reason as above).

    Returns:
        ``loss_fn.forward(...).per_token().total_loss``.

    Note:
        Reads the notebook-level ``trainer`` (for its accelerator device) and
        caches an ``_autocast_dtype`` attribute on it on first call.
    """
    # Assume the first token in the batch is the bos token
    bos_id = int(inputs["tokens"].view(-1)[0].item())

    # Create grouped causal mask
    # Token i may attend to token j only if j <= i (causal) AND token j belongs
    # either to token i's own group or to token i's parent group.
    # NOTE(review): assumes group_ids / parent_ids are (batch_size, seq_len)
    # like tokens, giving a (batch_size, seq_len, seq_len) mask — confirm.
    batch_size, seq_len = inputs["tokens"].size()
    causal_mask = (
        torch.tril(
            torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=trainer.accelerator.device
            )
        )
        .unsqueeze(0)
        .expand(batch_size, seq_len, seq_len)
    )
    group_mask = inputs["group_ids"].unsqueeze(2) == inputs["group_ids"].unsqueeze(1)
    parent_mask = inputs["parent_ids"].unsqueeze(2) == inputs["group_ids"].unsqueeze(1)
    mask = causal_mask & (group_mask | parent_mask)

    # Resolve the autocast dtype once and cache it on the trainer:
    # fp16 unless ACCELERATE_MIXED_PRECISION selects something else (-> bf16).
    if not hasattr(trainer, "_autocast_dtype"):
        trainer._autocast_dtype = (
            torch.float16
            if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
            else torch.bfloat16
        )
        # NOTE(review): forcing float32 selects float16 autocast, which looks
        # counterintuitive; presumably mirrors Unsloth's own trainer — confirm.
        if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
            trainer._autocast_dtype = torch.float16

    with torch.amp.autocast_mode.autocast(
        device_type="cuda", dtype=trainer._autocast_dtype
    ):
        logits = model(input_ids=inputs["tokens"], attention_mask=mask).logits

    # No reference-model logprobs are supplied to the loss here.
    result = loss_fn.forward(
        logits=logits,
        tokens=inputs["tokens"],
        advantages=inputs["advantages"],
        logprobs=inputs["logprobs"],
        reference_logprobs=None,
        mask=inputs["assistant_mask"],
        weights=inputs["weights"],
        bos_id=bos_id,
    )
    return result.per_token().total_loss


async def train() -> None:
    """Run ``trainer.train()`` with the notebook's ``compute_loss`` patched in.

    The trainer's original ``compute_loss`` is restored afterwards, even if
    training raises or the task is cancelled, so the patch does not leak into
    later use of ``trainer``.

    NOTE(review): ``trainer.train()`` is called without ``await``, so it runs
    to completion on the event-loop thread; a timeout on the awaiting side
    cannot interrupt it mid-call.
    """
    original_compute_loss = trainer.compute_loss
    trainer.compute_loss = compute_loss
    try:
        trainer.train()
    finally:
        trainer.compute_loss = original_compute_loss


task = asyncio.create_task(train())
await api._packed_tensors_queue.put(
    {key: tensor[: trainer.args.per_device_train_batch_size] for key, tensor in packed_tensors.items()}  # type: ignore
)
await asyncio.wait([task], timeout=5.0)
task.cancel()
task.result()