In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# are picked up on the next cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
%%html
<style>
/* Force the ipywidget output background to be transparent. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Point the Jupyter-widgets color and font-size custom properties at the
   VS Code editor theme variables (presumably so widget text follows the
   editor theme; confirm in the target front end). */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
from dotenv import load_dotenv
import nest_asyncio
import os

# PREPARE ENVIRONMENT
# Everything is set in one call, before the .env file is loaded, in the same
# key order as before.
os.environ.update(
    {
        # Modal image builder version to use.
        "MODAL_IMAGE_BUILDER_VERSION": "2024.10",
        # SWE-agent config, tools and trajectory locations (relative to the
        # notebook's working directory).
        "SWE_AGENT_CONFIG_DIR": ".",
        "SWE_AGENT_TOOLS_DIR": "tools",
        "SWE_AGENT_TRAJECTORY_DIR": "trajectories",
        # Weave output settings (presumably to keep cell output quiet).
        "WEAVE_PRINT_CALL_LINK": "False",
        "WEAVE_LOG_LEVEL": "CRITICAL",
    }
)

# Create the output directories up front; existing ones are left untouched.
for output_dir in ("replays", "trajectories"):
    os.makedirs(output_dir, exist_ok=True)

# Read variables from a .env file, then patch asyncio to allow nested event
# loops inside the notebook's already-running loop.
load_dotenv()
nest_asyncio.apply()

In [None]:
import art
from art.local import LocalBackend
from rollout import ModelConfig
import torch

# INITIALIZE MODEL
backend = LocalBackend()
model = art.TrainableModel(
    name="025",
    project="sweagent",
    config=ModelConfig(
        max_input_tokens=40_960,
        system_prompt_suffix="\n/no_think",
        xml_function_calling=True,
    ),
    # inference_api_key="default",
    # inference_base_url="http://0.0.0.0:8000/v1",
    # inference_model_name="Qwen/Qwen3-32B",
    base_model="Qwen/Qwen3-32B",
    _internal_config=art.dev.InternalModelConfig(
        engine_args=art.dev.EngineArgs(
            tensor_parallel_size=torch.cuda.device_count(), gpu_memory_utilization=0.85
        ),
        torchtune_args=art.dev.TorchtuneArgs(model="qwen3_32b", model_type="QWEN3", async_weight_syncing=True),
    ),
)
await model.register(backend)

In [5]:
from sandboxes import terminate_sandboxes

# Manual cleanup hook, intentionally left disabled. Presumably it tears down
# sandboxes left over from earlier runs (TODO: confirm against sandboxes.py).
# Uncomment the line below and run this cell to invoke it.
# await terminate_sandboxes()

In [None]:
import weave

from instances import as_instances_iter, get_filtered_swe_smith_instances_df
from rollout import rollout

weave.init(project_name=model.project)
rollout = weave.op(rollout)

# TRAIN MODEL
instances = (
    get_filtered_swe_smith_instances_df()
    .sample(fraction=1.0, shuffle=True, seed=42)
    .pipe(as_instances_iter)
)

async for trajectory_groups in art.trajectory_group_batches(
    (
        art.TrajectoryGroup(rollout(model, instance) for _ in range(4))
        for instance in instances
    ),
    batch_size=4,
    max_concurrent_batches=3,
    skip_batches=await model.get_step(),
):
    await model.train(
        trajectory_groups,
        _config=art.dev.TrainConfig(allow_training_without_logprobs=True),
    )