In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
%%html
<style>
/* Give ipywidget output areas a transparent background instead of a fixed color. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Point the Jupyter widget color and font-size variables at the VS Code editor's values. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import asyncio
import json

from dotenv import load_dotenv
from openai.types.chat.chat_completion import ChatCompletion

import art
from art.local import LocalBackend

load_dotenv()

MODEL_NAME = "001"
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
TRAINING_STEPS = 1_000

model = art.TrainableModel(
    name=MODEL_NAME, project="rock-paper-tool-use", base_model=BASE_MODEL
)
await model.register(LocalBackend())
client = model.openai_client()


def get_tool_call_id_and_move(chat_completion: "ChatCompletion") -> tuple[str, str]:
    """Extract the id and move of the first tool call in a chat completion.

    Only ``choices[0]`` and its first tool call are inspected; any further
    tool calls are ignored.

    Args:
        chat_completion: A completion whose first choice may contain a
            ``play_move`` tool call with JSON arguments ``{"move": ...}``.

    Returns:
        ``(tool_call_id, move)`` where ``move`` is one of ``"rock"``,
        ``"paper"`` or ``"scissors"``. Returns ``("n/a", "nothing")`` when
        there is no tool call at all, and ``(tool_call_id, "nothing")`` when a
        tool call exists but its arguments are not valid JSON, are not a JSON
        object, lack a ``"move"`` key, or name anything other than the three
        legal moves.
    """
    tool_calls = chat_completion.choices[0].message.tool_calls
    if not tool_calls:
        return "n/a", "nothing"
    tool_call = tool_calls[0]
    try:
        move = json.loads(tool_call.function.arguments)["move"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers non-object JSON (e.g. a list or null) and a
        # missing/None arguments payload.
        return tool_call.id, "nothing"
    # The schema's enum is not guaranteed to be respected by the model, so
    # anything other than a legal move is reported as "nothing"; callers
    # index dicts by this value and would otherwise raise KeyError.
    if move not in ("rock", "paper", "scissors"):
        return tool_call.id, "nothing"
    return tool_call.id, move


# Move -> the move it defeats. "nothing" (no valid move) defeats nothing, and
# nothing maps to "nothing", so a missing move neither wins nor loses a round.
_BEATS = {
    "rock": "scissors",
    "paper": "rock",
    "scissors": "paper",
    "nothing": None,
}


def _round_feedback(tool_call_id: str, opponent_move: str) -> list[dict]:
    """Build the messages that report the opponent's move and ask for the next one.

    When the player made a tool call, the result is sent as a ``tool`` message
    answering that call, followed by a user prompt. When it made none
    (``tool_call_id`` is the ``"n/a"`` sentinel), there is no call to answer, so
    a ``tool`` message would be out of protocol; the same information is sent
    as a single user message instead.
    """
    result = f"The other player played {opponent_move}."
    if tool_call_id == "n/a":
        return [{"role": "user", "content": f"{result} What will your next move be?"}]
    return [
        {"role": "tool", "tool_call_id": tool_call_id, "content": result},
        {"role": "user", "content": "What will your next move be?"},
    ]


async def rollout() -> art.Trajectory:
    """Play one rock-paper-scissors match: MODEL_NAME vs. BASE_MODEL.

    Each round both players are queried concurrently, so neither sees the
    other's move before committing. A missing or illegal move counts as
    "nothing", which neither wins nor loses a round. The match lasts at most
    10 rounds and stops early once either player has more than 2 wins.

    Returns:
        The MODEL_NAME player's trajectory. Its ``reward`` is the number of
        rounds it won; ``metrics`` count rounds played and how often each move
        (including "nothing") was chosen. The opponent's trajectory is
        discarded.
    """
    tools: art.Tools = [
        {
            "type": "function",
            "function": {
                "name": "play_move",
                "description": "Play a move in rock-paper-scissors",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "move": {
                            "type": "string",
                            "enum": ["rock", "paper", "scissors"],
                            "description": "The move to play",
                        }
                    },
                    "required": ["move"],
                },
            },
        }
    ]
    trajectories = [
        art.Trajectory(
            messages_and_choices=[
                {
                    "role": "system",
                    "content": "You are a rock-paper-scissors playing agent. Use the play_move function tool to declare your moves.",
                },
                {
                    "role": "user",
                    "content": "What will your first move be?",
                },
            ],
            tools=tools,
            reward=0,
            metrics={
                "num_rounds": 0,
                "rock": 0,
                "paper": 0,
                "scissors": 0,
                "nothing": 0,
            },
        )
        for _ in range(2)
    ]
    # trajectories[0] is played by MODEL_NAME, trajectories[1] by BASE_MODEL.
    players = (MODEL_NAME, BASE_MODEL)
    for _ in range(10):  # at most 10 rounds per match
        chat_completions = await asyncio.gather(
            *[
                client.chat.completions.create(
                    messages=trajectory.messages(),
                    model=player,
                    tools=tools,
                    max_completion_tokens=100,
                )
                for trajectory, player in zip(trajectories, players)
            ]
        )
        plays = []
        for trajectory, chat_completion in zip(trajectories, chat_completions):
            trajectory.messages_and_choices.append(chat_completion.choices[0])
            tool_call_id, move = get_tool_call_id_and_move(chat_completion)
            # The model is not guaranteed to respect the schema's enum; treat
            # anything that isn't a known move as "nothing" rather than
            # raising KeyError on the _BEATS / metrics lookups below.
            if not isinstance(move, str) or move not in _BEATS:
                move = "nothing"
            plays.append((tool_call_id, move))
        (id0, move0), (id1, move1) = plays
        if _BEATS[move0] == move1:
            trajectories[0].reward += 1
        elif _BEATS[move1] == move0:
            trajectories[1].reward += 1
        for trajectory, (_, move) in zip(trajectories, plays):
            trajectory.metrics["num_rounds"] += 1
            trajectory.metrics[move] += 1
        # First player to 3 round wins ends the match.
        if max(t.reward for t in trajectories) > 2:
            break
        # Tell each player what its opponent played and ask for the next move.
        for trajectory, tool_call_id, opponent_move in (
            (trajectories[0], id0, move1),
            (trajectories[1], id1, move0),
        ):
            trajectory.messages_and_choices.extend(
                _round_feedback(tool_call_id, opponent_move)
            )
    return trajectories[0]


for i in range(await model.get_step(), TRAINING_STEPS):
    trajectories = await art.gather_trajectories(
        (rollout() for _ in range(64)), max_exceptions=64
    )
    await model.train(
        [art.TrajectoryGroup(trajectories)],
        config=art.TrainConfig(learning_rate=5e-5),
    )