In [None]:
import asyncio
import re

from dotenv import load_dotenv

import art
from art.local import LocalBackend

# Load environment variables via python-dotenv before any backend/model
# objects are created (presumably from a local .env file — confirm).
load_dotenv()

# Name of the base model; used both as the trainable model's starting point
# and as the opponent model name in the "versus-base" games below.
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
# Number of rounds played in a single game (also interpolated into the prompt).
PRISONERS_DILEMMA_ROUNDS = 10
# Upper bound of the training loop's step range.
TRAINING_STEPS = 1_000

backend = LocalBackend()
model = art.TrainableModel(
    name="001", project="prisoners-dilemma", base_model=BASE_MODEL
)
# The model is registered with the backend before a client is requested from it.
await model.register(backend)

# Client used by `rollout_game` to request chat completions for both players.
client = model.openai_client()


# Opening instruction sent as the first user message to each player. It states
# the payoff matrix that `rollout_game` scores against and the number of rounds
# (PRISONERS_DILEMMA_ROUNDS) that will be played with the same opponent.
prompt = f"""
You are playing a game of prisoners' dilemma with another player.

You are given a choice between two actions:

1. Cooperate
2. Defect

The payoffs are as follows:

- If both players cooperate, you get 3 points and the other player gets 3 points.
- If one player cooperates and the other defects, the defector gets 5 points and the cooperator gets 0 points.
- If both players defect, you both get 1 point.

You will play this game {PRISONERS_DILEMMA_ROUNDS} times with the same player.

For your first turn, would you like to cooperate or defect?
""".strip()


async def rollout_game(
    models: tuple[str, str] = (model.name, model.name),
) -> tuple[art.Trajectory, art.Trajectory]:
    """Play one iterated prisoners' dilemma game between two models.

    Both players start from the same opening ``prompt`` and play
    ``PRISONERS_DILEMMA_ROUNDS`` rounds. In every round both completions are
    requested concurrently, the last "cooperate"/"defect" mentioned in each
    reply (case-insensitive) is taken as that player's action, rewards are
    updated, and each player is shown the other's reply plus the running
    scores.

    Args:
        models: Model names for player 0 and player 1, passed as ``model`` to
            the chat completions endpoint.

    Returns:
        One trajectory per player, in the same order as ``models``; each
        trajectory's ``reward`` is that player's accumulated score.
    """
    # (player 0 gain, player 1 gain) when both players named a valid action.
    payoffs = {
        ("cooperate", "cooperate"): (3, 3),
        ("cooperate", "defect"): (0, 5),
        ("defect", "cooperate"): (5, 0),
        ("defect", "defect"): (1, 1),
    }
    # Per-player fallback applied when one or both players named no action.
    default_rewards = {"cooperate": 3, "defect": 5, "none": 0}
    players = (0, 1)

    messages: tuple[art.Messages, art.Messages] = (
        [{"role": "user", "content": prompt}],
        [{"role": "user", "content": prompt}],
    )
    trajectories = (
        art.Trajectory(messages_and_choices=list(messages[0]), reward=0),
        art.Trajectory(messages_and_choices=list(messages[1]), reward=0),
    )

    for _ in range(PRISONERS_DILEMMA_ROUNDS):
        # Both players move simultaneously: neither sees the other's reply
        # until both completions have finished.
        completions = await asyncio.gather(
            *(
                client.chat.completions.create(
                    messages=messages[player],
                    model=models[player],
                    max_completion_tokens=512,
                )
                for player in players
            )
        )
        choices = [completion.choices[0] for completion in completions]
        replies = [choice.message.content for choice in choices]

        actions = []
        for player in players:
            messages[player].append(
                {"role": "assistant", "content": replies[player]}
            )
            trajectories[player].messages_and_choices.append(choices[player])
            # The last action word mentioned in the reply counts as the move.
            mentioned = re.findall(
                r"cooperate|defect", (replies[player] or "").lower()
            )
            actions.append(mentioned[-1] if mentioned else "none")

        outcome = (actions[0], actions[1])
        if outcome in payoffs:
            gains = payoffs[outcome]
        else:
            gains = (default_rewards[actions[0]], default_rewards[actions[1]])
        for player in players:
            trajectories[player].reward += gains[player]

        # Tell each player what the other said (as a quoted block) and the
        # current scores, then ask for the next move.
        for player in players:
            opponent = 1 - player
            quoted = "\n> ".join((replies[opponent] or "").splitlines())
            feedback = {
                "role": "user",
                "content": f"The other player responded as follows: \n\n> {quoted}\n\n"
                f"Your score is {trajectories[player].reward}. The other player's score is {trajectories[opponent].reward}.\n\n"
                "For the next round, would you like to cooperate or defect?",
            }
            messages[player].append(feedback)
            trajectories[player].messages_and_choices.append(feedback)

    return trajectories


for _ in range(await model.get_step(), TRAINING_STEPS):
    # Simultaneously rollout self-play games, and games versus the base model.
    self_play_trajectories, base_play_trajectories = await asyncio.gather(
        art.gather_trajectories(
            (rollout_game(models=(model.name, model.name)) for _ in range(8)),
            pbar_desc="versus-self",
        ),
        art.gather_trajectories(
            (rollout_game(models=(model.name, BASE_MODEL)) for _ in range(8)),
            pbar_desc="versus-base",
        ),
    )
    # Log performance versus self and the base model, as well as the base model's performance.
    await model.log(
        [t for ts in self_play_trajectories for t in ts], split="versus-self"
    )
    await model.log([ts[0] for ts in base_play_trajectories], split="versus-base")
    await model.log([ts[1] for ts in base_play_trajectories], split="base-model")
    # Train the model on self-play and base-play trajectories.
    await model.train(
        trajectory_groups=[
            # Since all self-play games have the same starting state and are symmetric, we can gather
            # trajectories from all self-play games into a single trajectory group.
            art.TrajectoryGroup(t for ts in self_play_trajectories for t in ts),
            # We can also gather all base-play _trained model_ trajectories into a single trajectory group.
            # We don't want to train on base model trajectories, because they are sampled from a different distribution.
            art.TrajectoryGroup(ts[0] for ts in base_play_trajectories),
        ],
        config=art.TrainConfig(learning_rate=5e-5),
    )