In [1]:
# Automatically re-import edited local modules (e.g. rollout, instances, sandboxes)
# before each cell executes, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
%%html
<style>
/* Make ipywidget output backgrounds transparent and map the Jupyter widget
   color / font-size variables onto the VS Code editor theme variables, so
   widget output follows the editor's theme when run inside VS Code. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [3]:
from dotenv import load_dotenv
import nest_asyncio
import os

# PREPARE ENVIRONMENT
# These variables are set *before* load_dotenv(); with its default override=False,
# python-dotenv does not replace variables that already exist, so these values win
# over any matching entries in .env.
os.environ.update(
    {
        "MODAL_IMAGE_BUILDER_VERSION": "2024.10",
        "SWE_AGENT_CONFIG_DIR": ".",
        "SWE_AGENT_TOOLS_DIR": "tools",
        "SWE_AGENT_TRAJECTORY_DIR": "trajectories",
        "WEAVE_PRINT_CALL_LINK": "False",
        "WEAVE_LOG_LEVEL": "CRITICAL",
    }
)

# Output directories; creating them is a no-op when they already exist.
for _output_dir in ("replays", "trajectories"):
    os.makedirs(_output_dir, exist_ok=True)

load_dotenv()
# Allow nested event loops (lets synchronous helpers run asyncio code inside the notebook's loop).
nest_asyncio.apply()

In [None]:
import art
from art.local import LocalBackend
from rollout import ModelConfig
import torch

# INITIALIZE MODEL
backend = LocalBackend()
model = art.TrainableModel(
    name="031",
    project="sweagent",
    config=ModelConfig(
        max_input_tokens=40_960,
        system_prompt_suffix="\n/no_think",
        xml_function_calling=True,
    ),
    base_model="Qwen/Qwen3-32B",
    _internal_config=art.dev.InternalModelConfig(
        engine_args=art.dev.EngineArgs(
            tensor_parallel_size=torch.cuda.device_count(), gpu_memory_utilization=0.85
        ),
        torchtune_args=art.dev.TorchtuneArgs(
            model="qwen3_32b", model_type="QWEN3", async_weight_syncing=True
        ),
    ),
)
await model.register(backend)

In [5]:
from sandboxes import terminate_sandboxes

# Terminate existing sandboxes before training (presumably leftovers from earlier runs).
# NOTE(review): confirm terminate_sandboxes() only targets sandboxes owned by this project.
await terminate_sandboxes()

In [6]:
# import weave

from instances import as_instances_iter, get_filtered_swe_smith_instances_df
from rollout import rollout

# weave.init(project_name=model.project)
# rollout = weave.op(rollout)

# LOAD TRAINING INSTANCES
# Shuffle every filtered SWE-smith row (fraction=1.0) with a fixed seed, so the batch
# order is deterministic for a given input frame, then materialize the instances once.
instances = list(
    as_instances_iter(
        get_filtered_swe_smith_instances_df().sample(
            fraction=1.0, shuffle=True, seed=42
        )
    )
)

In [None]:
GROUPS_PER_BATCH = 16
ROLLOUTS_PER_GROUP = 16
# for i in range(await model.get_step(), 1_000):
for i in [0] * 1_000:
    trajectory_groups = await art.gather_trajectory_groups(
        (
            art.TrajectoryGroup(
                rollout(model, instance, reward_power=1.33, timeout=60 * 15)
                for _ in range(ROLLOUTS_PER_GROUP)
            )
            for instance in instances[
                i * GROUPS_PER_BATCH % len(instances) : (i + 1)
                * GROUPS_PER_BATCH
                % len(instances)
            ]
        )
    )
    await model.train(
        trajectory_groups,
        # config=art.TrainConfig(learning_rate=7e-6),
        _config=art.dev.TrainConfig(allow_training_without_logprobs=True),
        verbose=True,
    )

In [None]:
import asyncio
from typing import AsyncIterable, Any, Coroutine, Iterable, TypeVar

T = TypeVar("T")


# Scratch definitions (presumably used to inspect a coroutine's type). `d` is never
# awaited, so close it right away to avoid a "coroutine 'test' was never awaited" warning.
async def test() -> int: ...


d = test()
d.close()


async def buffer(
    iterable: Iterable[Coroutine[Any, Any, T]], *, max_concurrent: int
) -> AsyncIterable[list[T]]:
    """Run coroutines as tasks with bounded concurrency and yield results in order.

    Coroutines are pulled lazily from ``iterable`` and scheduled so that at most
    ``max_concurrent`` tasks are *running* at once. Finished tasks waiting behind a
    still-running earlier task do not count toward the limit. Whenever the task at
    the front of the queue is done, the contiguous run of finished tasks at the front
    is popped and their results are yielded together as one list, preserving order.

    Args:
        iterable: Coroutines to run; consumed lazily, one per free slot.
        max_concurrent: Maximum number of simultaneously running tasks (must be >= 1).

    Yields:
        Non-empty lists of results, in the order the coroutines were produced.

    Raises:
        ValueError: If ``max_concurrent`` is less than 1.
        Exception: The first exception found on any finished task (fail fast).
            Outstanding tasks are cancelled in that case, and also when the consumer
            closes the generator early.
    """
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

    tasks: list[asyncio.Task[T]] = []
    iterator = iter(iterable)
    iterator_exhausted = False
    try:
        while tasks or not iterator_exhausted:
            # Top up to `max_concurrent` running tasks.
            while not iterator_exhausted and (
                sum(not task.done() for task in tasks) < max_concurrent
            ):
                try:
                    coroutine = next(iterator)
                except StopIteration:
                    iterator_exhausted = True
                else:
                    tasks.append(asyncio.create_task(coroutine))
            if not tasks:
                break

            # Block only while the head is still running, and only on tasks that are
            # still running. asyncio.wait returns immediately if any awaited task is
            # already done, which previously turned this loop into a busy-wait whenever
            # a later task finished before the head.
            if not tasks[0].done():
                await asyncio.wait(
                    [task for task in tasks if not task.done()],
                    return_when=asyncio.FIRST_COMPLETED,
                )

            # Fail fast on the first error from any finished task.
            for task in tasks:
                if task.done() and (exception := task.exception()) is not None:
                    raise exception

            # Yield the contiguous run of finished tasks at the front, in order.
            results: list[T] = []
            while tasks and tasks[0].done():
                results.append(tasks.pop(0).result())
            if results:
                yield results
    finally:
        # Don't leave orphaned tasks running after an error or an early `break`.
        for task in tasks:
            task.cancel()


GROUPS_PER_BATCH = 4
ROLLOUTS_PER_GROUP = 4

# Calculate the number of digits needed for batch numbering
total_batches = (len(instances) + GROUPS_PER_BATCH - 1) // GROUPS_PER_BATCH
num_digits = len(str(total_batches - 1))

async for trajectory_groups in buffer(
    (
        art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    rollout(model, instance) for _ in range(ROLLOUTS_PER_GROUP)
                )
                for instance in instances[start : start + GROUPS_PER_BATCH]
            ),
            pbar_desc=f"gather/{start // GROUPS_PER_BATCH:0{num_digits}d}",
        )
        for start in range(0, len(instances), GROUPS_PER_BATCH)
    ),
    max_concurrent=3,
):
    await model.train(
        [g for gs in trajectory_groups for g in gs],
        _config=art.dev.TrainConfig(allow_training_without_logprobs=True),
    )


# async for trajectory_groups in art.trajectory_group_batches(
#     (
#         art.TrajectoryGroup(rollout(model, instance) for _ in range(4))
#         for instance in instances
#     ),
#     batch_size=4,
#     max_concurrent_batches=3,
#     skip_batches=await model.get_step(),
# ):
#     await model.train(
#         trajectory_groups,
#         _config=art.dev.TrainConfig(allow_training_without_logprobs=True),
#     )