In [1]:
import art
from art.local import LocalBackend
from dotenv import load_dotenv
from openpipe.client import AsyncOpenPipe
import random

from rollout import rollout

# Load variables from a .env file into the process environment (python-dotenv).
# NOTE(review): presumably this is where OPENAI_API_KEY, read later in this
# notebook via os.getenv, comes from — confirm.
load_dotenv()

# Seed Python's global `random` module with a fixed value.
# NOTE(review): presumably for reproducible rollouts; `rollout` is imported from
# another module, so whether it draws from `random` is not visible here.
random.seed(42)

# Create the local ART backend. Later cells register models against it and use
# it for the experimental S3 pull/push calls.
backend = LocalBackend()

In [None]:
# Declare the model
model = art.TrainableModel(
    name="013",
    project="2048",
    base_model="Qwen/Qwen2.5-3B-Instruct",
    # To run on a T4, we need to override some config defaults.
    _internal_config=art.dev.InternalModelConfig(
        init_args=art.dev.InitArgs(
            max_seq_length=8192,
        ),
        engine_args=art.dev.EngineArgs(
            enforce_eager=True,
            gpu_memory_utilization=0.8,
            num_scheduler_steps=1,
        ),
    ),
)
await backend._experimental_pull_from_s3(
    model,
    verbose=True,
)
await model.register(backend)

for i in range(await model.get_step(), 50):
    train_groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(
                rollout(model, i, is_validation=False) for _ in range(4)
            )
            for _ in range(1)
        ),
        pbar_desc="gather",
        max_exceptions=1,
    )
    await model.delete_checkpoints()
    await backend._experimental_push_to_s3(
        model,
    )
    await model.train(
        train_groups,
        config=art.TrainConfig(learning_rate=3e-5),
        # Lowering the logprob_calculation_chunk_size is a memory saving measure
        # to allow longer sequences (up to 4096 tokens) to be processed on a T4.
        _config={"logprob_calculation_chunk_size": 8},
    )

In [None]:
import asyncio
import os


async def log_comparison_model(comparison_model: art.Model):
    """Gather validation rollouts for ``comparison_model`` and log them.

    Runs one trajectory group of 12 ``rollout`` calls with
    ``is_validation=True`` (passing ``0`` as the second argument, where the
    training loop passes its step index), then logs the gathered groups on the
    model under ``split="val"``.

    Args:
        comparison_model: The model to roll out and to log the results on.
    """

    def single_validation_group():
        # Lazily yield exactly one group; its rollout coroutines are created
        # only when the group is consumed by gather_trajectory_groups.
        yield art.TrajectoryGroup(
            rollout(comparison_model, 0, is_validation=True) for _ in range(12)
        )

    val_groups = await art.gather_trajectory_groups(
        single_validation_group(),
        pbar_desc=f"gather {comparison_model.name}",
        max_exceptions=1,
    )
    await comparison_model.log(val_groups, split="val")

gpt_4o_mini = art.Model(
    name="gpt-4o-mini",
    project="2048",
    inference_model_name="gpt-4o-mini",
    inference_base_url="https://api.openai.com/v1",
    inference_api_key=os.getenv("OPENAI_API_KEY"),
)
await gpt_4o_mini.register(backend)

gpt_4o = art.Model(
    name="gpt-4o",
    project="2048",
    inference_model_name="gpt-4o",
    inference_base_url="https://api.openai.com/v1",
    inference_api_key=os.getenv("OPENAI_API_KEY"),
)
await gpt_4o.register(backend)

gpt_4_1 = art.Model(
    name="gpt-4.1",
    project="2048",
    inference_model_name="gpt-4.1",
    inference_base_url="https://api.openai.com/v1",
    inference_api_key=os.getenv("OPENAI_API_KEY"),
)
await gpt_4_1.register(backend)

await backend._experimental_push_to_s3(
    gpt_4o_mini,
)
await backend._experimental_push_to_s3(
    gpt_4o,
)
await backend._experimental_push_to_s3(
    gpt_4_1,
)


# Optional logging client
op_client = AsyncOpenPipe()

promises = []

for comparison_model in [gpt_4o_mini, gpt_4o, gpt_4_1]:
    promises.append(log_comparison_model(comparison_model))

await asyncio.gather(*promises)
