In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import peft # type: ignore
import unsloth # type: ignore
from art.dev.model import get_model_config
from art.local.state import ModelState

# Resolve the ART dev config for Qwen2.5-72B-Instruct and build the model state
# that every later cell uses (`state.vllm`, `state.trainer`, `state.tokenizer`, ...).
config = get_model_config(
    "Qwen/Qwen2.5-72B-Instruct",
    config={
        # Engine options; these names match vLLM engine arguments and are
        # presumably forwarded to the vLLM engine — TODO confirm.
        "engine_args": {
            "enable_sleep_mode": True,
            "enforce_eager": True,
            "gpu_memory_utilization": 0.9,
        },
        # No PEFT overrides. To change a library default, add it here, e.g.
        # "use_gradient_checkpointing": False.
        "peft_args": {},
    },
    output_dir="./.art/models/test",
)
state = ModelState(config)

In [3]:
import asyncio
from art.local.vllm import set_vllm_log_file
import random
import torch
import vllm


# Base size for warmup requests: each request draws max_tokens from [0, 2 * num_tokens].
num_tokens = 16_384

# Presumably redirects vLLM's logging to ./vllm.log — TODO confirm.
set_vllm_log_file("./vllm.log")


async def warmup(request_id: str) -> None:
    """Send one random-length generation request through the vLLM engine and drain it.

    The prompt consists of uniformly random token IDs. Its length is this
    request's equal share of the KV cache
    (``num_gpu_blocks * block_size // max_num_seqs``), minus a 16-token margin
    and minus the sampled completion length, clamped to at least one token.
    When the sampled ``max_tokens`` exceeds that share, the prompt collapses
    to a single token.

    Args:
        request_id: vLLM request ID for this generation; must be unique among
            in-flight requests.
    """
    # vLLM's SamplingParams rejects max_tokens < 1, so the range starts at 1
    # (the original lower bound of 0 could occasionally fail the request).
    max_tokens = random.randint(1, num_tokens * 2)

    engine = state.vllm.async_engine.engine
    cache_config = engine.cache_config  # type: ignore
    # Split the cache evenly across the scheduler's maximum concurrent sequences.
    tokens_per_seq = (
        cache_config.num_gpu_blocks * cache_config.block_size
    ) // engine.scheduler_config.max_num_seqs
    # Leave a 16-token margin plus room for the completion; never go below 1 token.
    prompt_len = max(tokens_per_seq - 16 - max_tokens, 1)

    prompt_token_ids = torch.randint(
        0,
        int(state.tokenizer.vocab_size),  # type: ignore
        (prompt_len,),
    ).tolist()

    # Drain the output stream; only the load on the engine matters here.
    async for _ in state.vllm.async_engine.generate(
        prompt={"prompt_token_ids": prompt_token_ids},
        sampling_params=vllm.SamplingParams(max_tokens=max_tokens),
        request_id=request_id,
    ):
        pass


# Fire one warmup request per scheduler slot (max_num_seqs of them) and keep the
# combined future so the next cell can cancel them all at once.
warmup_future = asyncio.gather(
    *[
        warmup(str(slot))
        for slot in range(state.vllm.async_engine.engine.scheduler_config.max_num_seqs)
    ]
)

In [6]:
warmup_future.cancel()
await asyncio.sleep(0.01)
try:
    warmup_future.result()
except asyncio.CancelledError:
    pass

In [None]:
import asyncio
from art.local.train import train

# Run the trainer in the background. The benchmark cell reads one item from
# `results_queue` per submitted batch (presumably the step's training metrics).
results_queue: asyncio.Queue = asyncio.Queue()
train_task = asyncio.create_task(train(state.trainer, results_queue))

In [None]:
from art.types import TrainConfig
from art.local.service import TrainInputs
import time
import torch
from typing import cast

seq_len = 32768
batch_size = 1
shape = (batch_size, seq_len)
num_steps = 1
async with state.vllm.train_mode():
    start_time = time.time()
    for _ in range(num_steps):
        state.inputs_queue.put_nowait(
            TrainInputs(
                tokens=torch.randint(0, cast(int, state.tokenizer.vocab_size), shape),
                group_ids=torch.randint(0, 10, shape),
                parent_ids=torch.randint(0, 10, shape),
                input_pos=torch.tensor([list(range(shape[1]))]),
                assistant_mask=torch.ones(shape, dtype=torch.bool),
                logprobs=torch.zeros(shape),
                advantages=torch.zeros(shape),
                weights=torch.ones(shape),
                config=TrainConfig(lr=1e-7, kl_coef=0.01)
            )
        )
        done, _ = await asyncio.wait(
            [asyncio.create_task(results_queue.get()), train_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
        for task in done:
            result = task.result()
            # If `result` is `None`, the training task finished somehow.
            assert result is not None, "The training task should never finish."
            results_queue.task_done()
            display(result)
    total_tokens = num_steps * batch_size * seq_len
    elapsed_time = time.time() - start_time
    tokens_per_second = total_tokens / elapsed_time
    print(f"Tokens per second: {tokens_per_second:.2f} tokens/s")
    print(f"Total time: {elapsed_time:.2f} seconds")