In [1]:
# Automatically re-import edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
%%html
<!-- VS Code notebook styling: make ipywidget output backgrounds transparent and
     make widget text follow the editor's foreground colour and font size. -->
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

In [None]:
import art
import asyncio
from dotenv import load_dotenv
import json
import random
import re
from typing import TypedDict

from art.local import LocalBackend

# Load environment variables from a local .env file. NOTE(review): presumably this
# supplies credentials (e.g. for W&B logging / model downloads) — confirm.
load_dotenv()


# Schema of one entry in puzzles.json:
#   num_clues -- presumably the number of clues in the prompt (not read in this notebook)
#   prompt    -- the user message sent to the model
#   solution  -- maps each question key to its expected answer string
TemporalCluePuzzle = TypedDict(
    "TemporalCluePuzzle",
    {
        "num_clues": int,
        "prompt": str,
        "solution": dict[str, str],
    },
)


# Load the puzzle set and carve out fixed splits: the first 64 puzzles are the
# validation set, the next 64 the test set, and everything after is training data.
puzzles_path = "../data/temporal-clue/puzzles.json"
# Use a context manager so the file handle is closed, and decode explicitly as
# UTF-8 (the JSON standard encoding) rather than the platform default.
with open(puzzles_path, encoding="utf-8") as puzzles_file:
    puzzles: list[TemporalCluePuzzle] = json.load(puzzles_file)
val_puzzles = puzzles[:64]
test_puzzles = puzzles[64:128]
train_puzzles = puzzles[128:]
# Seed the global RNG so the shuffled training order is reproducible. The shuffle
# acts on the slice copy, so `puzzles` itself keeps its original order.
random.seed(42)
random.shuffle(train_puzzles)


async def rollout(model: art.Model, puzzle: TemporalCluePuzzle) -> art.Trajectory:
    """Sample one completion for ``puzzle`` and score it against its solution.

    For each solution key, answers are extracted from the completion with a regex
    matching ``"<key>. <answer>"``. Only the last match per key is scored,
    compared case-insensitively after stripping surrounding whitespace.

    Args:
        model: Model whose OpenAI-compatible client (``model.openai_client()``)
            is used to sample the completion.
        puzzle: Puzzle providing the user prompt and the ``solution`` mapping
            from question key to expected answer.

    Returns:
        A trajectory holding the prompt message plus the sampled choice. Its
        ``reward`` and ``metrics["acc"]`` are both the fraction of solution keys
        answered correctly (0.0 if the solution is empty).

    Raises:
        AssertionError: if the completion has no text content.
    """
    messages: art.Messages = [{"role": "user", "content": puzzle["prompt"]}]
    client = model.openai_client()
    chat_completion = await client.chat.completions.create(
        messages=messages, model=model.name
    )
    choice = chat_completion.choices[0]
    content = choice.message.content
    assert isinstance(content, str)
    solution = puzzle["solution"]
    num_correct = 0
    for key, value in solution.items():
        # Escape the key so any regex metacharacters in it match literally.
        matches = re.findall(rf"{re.escape(key)}\. ([A-Za-z \.:-]+)", content)
        # Only the last answer given for a key counts.
        if matches and matches[-1].strip().lower() == value.lower():
            num_correct += 1
    # Guard against an empty solution, which would otherwise divide by zero.
    reward = acc = num_correct / len(solution) if solution else 0.0
    return art.Trajectory(
        messages_and_choices=[*messages, choice], reward=reward, metrics={"acc": acc}
    )


model = art.TrainableModel(
    name="001",
    project="temporal-clue",
    base_model="Qwen/Qwen2.5-14B-Instruct",
    _internal_config={
        "engine_args": {
            "tensor_parallel_size": 2,
            "gpu_memory_utilization": 0.6,
            "max_num_seqs": 512,
        },
        "torchtune_args": {
            "model": "qwen2_5_14b_instruct",
            "model_type": "QWEN2",
            "async_weight_syncing": True,
        },
    },
)
backend = LocalBackend()
await model.register(backend)

stride = 8
for i in range(await model.get_step(), 1_000):
    val_groups, train_groups = await asyncio.gather(
        art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(rollout(model, puzzle) for _ in range(2))
                for puzzle in val_puzzles
            ),
            pbar_desc="val",
            pbar_total_completion_tokens=False,
        ),
        art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(rollout(model, puzzle) for _ in range(50))
                for puzzle in train_puzzles[i * stride : (i + 1) * stride]
            ),
            pbar_desc="train",
            pbar_total_completion_tokens=False,
        ),
    )
    await model.log(val_groups)
    await model.delete_checkpoints()
    await model.train(
        train_groups,
        config=art.TrainConfig(learning_rate=5e-6),
    )

[34m[1mwandb[0m: Currently logged in as: [33mbradhilton[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO 07-09 16:31:05 [__init__.py:244] Automatically detected platform cuda.
INFO 07-09 16:31:17 [__init__.py:244] Automatically detected platform cuda.
/root/.cache/huggingface/hub/models--Qwen--Qwen2.5-14B-Instruct/snapshots/cf98f3b3bbb457ad9e2bb7baf9a0125b6b88caa8
INFO 07-09 16:31:33 [config.py:823] This model supports multiple tasks: {'generate', 'classify', 'score', 'reward', 'embed'}. Defaulting to 'generate'.
INFO 07-09 16:31:33 [config.py:1946] Defaulting to use mp for distributed inference
INFO 07-09 16:31:33 [config.py:2195] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 07-09 16:31:38 [__init__.py:244] Automatically detected platform cuda.
INFO 07-09 16:31:43 [core.py:455] Waiting for init message from front-end.
INFO 07-09 16:31:43 [core.py:70] Initializing a V1 LLM engine (v0.9.1) with config: model='Qwen/Qwen2.5-14B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-14B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, ov

Loading safetensors checkpoint shards:   0% Completed | 0/8 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  12% Completed | 1/8 [00:11<01:17, 11.05s/it]
Loading safetensors checkpoint shards:  25% Completed | 2/8 [00:21<01:04, 10.71s/it]
Loading safetensors checkpoint shards:  38% Completed | 3/8 [00:31<00:52, 10.58s/it]
Loading safetensors checkpoint shards:  50% Completed | 4/8 [00:42<00:42, 10.65s/it]
Loading safetensors checkpoint shards:  62% Completed | 5/8 [00:53<00:31, 10.55s/it]
Loading safetensors checkpoint shards:  75% Completed | 6/8 [01:01<00:19,  9.78s/it]
Loading safetensors checkpoint shards:  88% Completed | 7/8 [01:04<00:07,  7.54s/it]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [01:14<00:00,  8.42s/it]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [01:14<00:00,  9.32s/it]
[1;36m(VllmWorker rank=0 pid=53610)[0;0m 


[1;36m(VllmWorker rank=0 pid=53610)[0;0m INFO 07-09 16:33:09 [default_loader.py:272] Loading weights took 74.67 seconds
[1;36m(VllmWorker rank=1 pid=53611)[0;0m INFO 07-09 16:33:09 [default_loader.py:272] Loading weights took 74.43 seconds
[1;36m(VllmWorker rank=0 pid=53610)[0;0m INFO 07-09 16:33:10 [gpu_model_runner.py:1624] Model loading took 13.9282 GiB and 75.593227 seconds
[1;36m(VllmWorker rank=1 pid=53611)[0;0m INFO 07-09 16:33:10 [gpu_model_runner.py:1624] Model loading took 13.9282 GiB and 75.682700 seconds
[1;36m(VllmWorker rank=0 pid=53610)[0;0m [1;36m(VllmWorker rank=1 pid=53611)[0;0m INFO 07-09 16:33:25 [backends.py:462] Using cache directory: /root/.cache/vllm/torch_compile_cache/b49082caa6/rank_0_0 for vLLM's torch.compile
INFO 07-09 16:33:25 [backends.py:462] Using cache directory: /root/.cache/vllm/torch_compile_cache/b49082caa6/rank_1_0 for vLLM's torch.compile
[1;36m(VllmWorker rank=1 pid=53611)[0;0m INFO 07-09 16:33:25 [backends.py:472] Dynamo bytecode

val:   0%|          | 0/128 [00:00<?, ?it/s]

train:   0%|          | 0/400 [00:00<?, ?it/s]

Packed 381 trajectories into 92 sequences of length 6144


train:   0%|          | 0/46 [00:00<?, ?it/s]

ERROR 07-09 16:46:12 [multiproc_executor.py:140] Worker proc VllmWorker-0 died unexpectedly, shutting down executor.


[rank1]:[W709 16:46:12.010522873 TCPStore.cpp:125] [c10d] recvValue failed on SocketImpl(fd=77, addr=[localhost]:47192, remote=[localhost]:51839): Connection reset by peer
Exception raised from recvBytes at /pytorch/torch/csrc/distributed/c10d/Utils.hpp:675 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7ff95e9785e8 in /root/sky_workdir/.venv/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x5ba8afe (0x7ff947836afe in /root/sky_workdir/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x5baaecf (0x7ff947838ecf in /root/sky_workdir/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x5bab74a (0x7ff94783974a in /root/sky_workdir/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so)
frame #4: c10d::TCPStore::check(std::vector<std::__cxx11::basic_st