# GSW Sleep-Time Agent — Multi-Turn GRPO Training with ART

Train a Qwen3-30B-A3B (MoE) model to explore GSW structures and create bridge QA pairs
using [OpenPipe ART](https://github.com/openpipe/ART) for multi-turn reinforcement learning.

**Architecture:**
- ART handles inference (vLLM), LoRA training (Unsloth/torchtune), and GRPO updates
- GSW tools execute locally against `GSWEnvironment.step()`
- Reward: bridge-F1 against MuSiQue gold answers (verifiable, no LLM judge needed)

**Backends:**
- `ServerlessBackend`: Qwen3-30B-A3B on ART cloud (requires WANDB_API_KEY)
- `LocalBackend`: Qwen3-30B-A3B on local GPUs (4x A6000 48GB)

## Prerequisites / Troubleshooting

Before running this notebook:

1. **Fresh kernel**: If you encounter CUDA errors, restart the kernel completely
   (Kernel > Restart) before re-running.

2. **flash-attn version**: Must be >= 2.8.3. Check with:
   ```python
   import flash_attn; print(flash_attn.__version__)
   ```
   If outdated: `pip install "flash-attn>=2.8.3" --no-build-isolation` (keep the quotes; unquoted, the shell treats `>` as an output redirect)

3. **Stale compiled cache**: Delete `./unsloth_compiled_cache/` if you upgraded
   PyTorch, CUDA, or unsloth since the last run.

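4. **GPU memory**: Qwen3-30B-A3B has ~30B total parameters, so its unquantized bf16
   weights alone take roughly 60 GB. With `tensor_parallel_size=2`, that is ~30 GB per
   GPU before KV cache. Ensure no other processes hold GPU memory. Check with `nvidia-smi`.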

## Configuration

In [None]:
import os

# ---- Model download directory ----
# Set in the first cell so it is in place before any Hugging Face library
# is imported by later cells.
os.environ["HF_HOME"] = "/mnt/SSD3/yigit"

# ---- Backend Selection ----
USE_SERVERLESS = False   # ART cloud (needs WANDB_API_KEY)
USE_LOCAL = True         # Local GPUs (4x A6000 48GB)

# ---- Model ----
BASE_MODEL = "Qwen/Qwen3-30B-A3B"  # MoE: 30B total, ~3B active

# ---- Training ----
MAX_TURNS = 30           # Max tool calls per episode
ROLLOUTS_PER_GROUP = 4   # GRPO group size
GROUPS_PER_STEP = 2      # Scenarios per training step
NUM_EPOCHS = 5
LEARNING_RATE = 1e-5
MAX_STEPS = 200          # Stop after this many steps (None for full run)
VALIDATION_INTERVAL = 10 # Run validation every N steps

# ---- Data ----
INDEX_PATH = "/home/yigit/codebase/gsw-memory/data/rl_training/index.json"
VAL_SPLIT = 0.05         # Fraction of scenarios held out for validation

# ---- API Keys ----
# os.environ["WANDB_API_KEY"] = ""  # Required for both backends

# ---- Sanity checks ----
# The backend cell branches only on USE_SERVERLESS, so a contradictory pair of
# flags would otherwise be silently ignored.
if USE_SERVERLESS == USE_LOCAL:
    raise ValueError(
        "Select exactly one backend: set either USE_SERVERLESS or USE_LOCAL to True."
    )
# The data cell needs a non-empty train split as well as a validation split.
if not 0.0 < VAL_SPLIT < 1.0:
    raise ValueError(f"VAL_SPLIT must be strictly between 0 and 1, got {VAL_SPLIT}")

print(f"Backend: {'Serverless' if USE_SERVERLESS else 'Local'}")
print(f"Model: {BASE_MODEL}")
print(f"GRPO group size: {ROLLOUTS_PER_GROUP}")
print(f"HF_HOME: {os.environ['HF_HOME']}")

## Load Training Data

In [None]:
import json
import random

# The index is JSON (UTF-8); don't depend on the platform's locale encoding,
# since question/answer text may contain non-ASCII characters.
with open(INDEX_PATH, encoding="utf-8") as f:
    all_scenarios = json.load(f)

# A private RNG seeded with 42 yields the same order as
# `random.seed(42); random.shuffle(...)` without resetting the global RNG.
random.Random(42).shuffle(all_scenarios)

# Hold out at least one scenario for validation; the rest is training data.
n_val = max(1, int(len(all_scenarios) * VAL_SPLIT))
val_scenarios = all_scenarios[:n_val]
train_scenarios = all_scenarios[n_val:]

# Later cells read train_scenarios[0] and val_scenarios[0]; fail here with a
# clear message rather than an IndexError further down.
if not train_scenarios or not val_scenarios:
    raise ValueError(
        f"Need at least 2 scenarios in {INDEX_PATH} to form non-empty train/val "
        f"splits (got {len(all_scenarios)} with VAL_SPLIT={VAL_SPLIT})"
    )

print(f"Total scenarios: {len(all_scenarios)}")
print(f"Train: {len(train_scenarios)}, Val: {len(val_scenarios)}")
print("\nSample scenario:")
print(f"  Question: {train_scenarios[0]['question']}")
print(f"  Answer: {train_scenarios[0]['answer']}")
print(f"  GSW dirs: {train_scenarios[0]['gsw_dirs']}")
print(f"  Num hops: {train_scenarios[0].get('num_hops', '?')}")

## Model + Backend Setup

In [None]:
import os
import gc
import torch
from pathlib import Path

# ---- 1. Clear any stale CUDA state ----
# Collect unreachable Python objects first so any tensors they still hold go
# back to PyTorch's caching allocator before the caches are emptied.
gc.collect()
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        with torch.cuda.device(i):
            torch.cuda.empty_cache()
            # synchronize() acts on the current device, so do it per device.
            torch.cuda.synchronize()
    print("CUDA caches cleared and synchronized.")
else:
    print("WARNING: No CUDA GPUs detected!")

# ---- 2. Report GPU status ----
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    name = torch.cuda.get_device_name(i)
    total = torch.cuda.get_device_properties(i).total_memory / 1024**3
    free, _ = torch.cuda.mem_get_info(i)
    free_gb = free / 1024**3
    print(f"  GPU {i}: {name} | {total:.1f} GB total | {free_gb:.1f} GB free")

# ---- 3. Validate flash-attn version ----
# Import failures of flash_attn and of `packaging` are reported separately so a
# missing `packaging` is not misreported as a missing flash-attn.
try:
    import flash_attn
except ImportError:
    print("\nflash-attn: not installed")
else:
    fa_version = flash_attn.__version__
    print(f"\nflash-attn version: {fa_version}")
    try:
        from packaging.version import Version
    except ImportError:
        print("  (install `packaging` to validate the flash-attn version)")
    else:
        if Version(fa_version) < Version("2.8.3"):
            print(f"  WARNING: flash-attn {fa_version} < 2.8.3 (required)")
            print("  This may cause CUDA illegal memory access errors.")
            # Quoted so the shell does not treat `>=` as a redirect.
            print('  Fix: pip install "flash-attn>=2.8.3" --no-build-isolation')

# ---- 4. Check for stale unsloth compiled cache ----
cache_dir = Path("./unsloth_compiled_cache")
if cache_dir.exists():
    num_files = sum(1 for p in cache_dir.rglob("*") if p.is_file())
    print(f"\nunsloth_compiled_cache: {num_files} files found")
    print("  If you encounter CUDA errors, try deleting this directory:")
    print(f"  rm -rf {cache_dir.resolve()}")
else:
    print("\nunsloth_compiled_cache: not found (clean state)")

# ---- 5. Set vLLM memory limit ----
# setdefault keeps any value already exported in the environment.
# NOTE(review): the registration cell passes gpu_memory_utilization=0.75 in
# engine_args explicitly; confirm whether this env var is read at all.
os.environ.setdefault("VLLM_GPU_MEMORY_UTILIZATION", "0.85")

# Uncomment for synchronous CUDA errors (slower but pinpoints location):
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

print("\nGPU diagnostics complete. Proceed to model registration.")

In [None]:
import art
from art import dev
import shutil

# Clear stale ART state (needed when changing _internal_config)
art_dir = Path("./.art/gsw-sleep-time")
if art_dir.exists():
    shutil.rmtree(art_dir)
    print("Cleared stale .art state")

# Qwen3-30B-A3B is MoE — Unsloth doesn't support fast_inference for MoE.
# Use _decouple_vllm_and_unsloth=True to run Unsloth and vLLM separately.
art_model = art.TrainableModel(
    name="gsw-sleep-agent-001",
    project="gsw-sleep-time",
    base_model=BASE_MODEL,
    _internal_config=dev.InternalModelConfig(
        _decouple_vllm_and_unsloth=True,
        init_args={
            "load_in_4bit": False,
            "max_seq_length": 8192,
        },
        engine_args={
            "tensor_parallel_size": 2,
            "enforce_eager": True,
            "max_model_len": 8192,
            "dtype": "bfloat16",
            "gpu_memory_utilization": 0.75,
            "enable_sleep_mode": True,
        },
    ),
)

if USE_SERVERLESS:
    from art.serverless.backend import ServerlessBackend
    backend = ServerlessBackend()
    print("Using ServerlessBackend (ART cloud)")
else:
    from art.local import LocalBackend
    backend = LocalBackend(path="./.art")
    print("Using LocalBackend (local GPUs)")

try:
    await art_model.register(backend)
except RuntimeError as e:
    if "CUDA" in str(e) or "illegal memory access" in str(e):
        print(f"\nCUDA error during registration: {e}")
        print("\nAttempting recovery:")
        print("  1. Clearing unsloth compiled cache...")
        cache_dir = Path("./unsloth_compiled_cache")
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
            print("     Deleted unsloth_compiled_cache/")
        print("  2. Clearing CUDA state...")
        torch.cuda.empty_cache()
        gc.collect()
        print("\n  RECOMMENDED: Restart the Jupyter kernel and re-run all cells.")
        raise
    else:
        raise

print(f"Model registered. Current step: {await art_model.get_step()}")

## GSW Tool Definitions

Define the 10 GSW tools in OpenAI function-calling format.
These schemas are passed to the model via `tools=` so it knows what tools are available.

In [None]:
def _prop(json_type, description):
    """Build a scalar/object JSON-schema property with a description."""
    return {"type": json_type, "description": description}


def _string_list(description):
    """Build a JSON-schema property for an array of strings."""
    return {"type": "array", "items": {"type": "string"}, "description": description}


def _function_tool(name, description, properties, required=None):
    """Wrap one tool in the OpenAI function-calling schema layout.

    Keys are inserted in the same order as a hand-written schema
    ("type", "function"; "name", "description", "parameters"; "type",
    "properties", "required") so the JSON sent to the inference server is
    unchanged. ``required`` is omitted entirely when it is None.
    """
    parameters = {"type": "object", "properties": properties}
    if required is not None:
        parameters["required"] = required
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# Tool schemas handed to the model via `tools=` on each chat completion.
GSW_TOOL_SCHEMAS = [
    # ---- Discovery ----
    _function_tool(
        "get_entity_documents",
        "Get list of document IDs that mention this entity.",
        {"entity_name": _prop("string", "Entity to search for")},
        ["entity_name"],
    ),
    _function_tool(
        "get_document_entities",
        "Get list of entities mentioned in this document.",
        {"doc_id": _prop("string", "Document ID (e.g., 'doc_3')")},
        ["doc_id"],
    ),
    # ---- Context ----
    _function_tool(
        "get_entity_context",
        "Get all QA pairs, roles, states, and relationships for an entity. Pass doc_id for single doc, or omit for merged context from all docs.",
        {
            "entity_name": _prop("string", "Entity to get context for"),
            "doc_id": _prop("string", "Optional document ID (e.g., 'doc_4'). Omit for merged context."),
        },
        ["entity_name"],
    ),
    _function_tool(
        "reconcile_entity_across_docs",
        "Merge all information about an entity from all documents into unified view. Use this to see complete picture of an entity.",
        {"entity_name": _prop("string", "Entity to reconcile")},
        ["entity_name"],
    ),
    # ---- Bridges ----
    _function_tool(
        "create_bridge_qa",
        "Create a bridge QA pair connecting information across documents.",
        {
            "question": _prop("string", "Bridge question"),
            "answer": _prop("string", "Bridge answer"),
            "source_docs": _string_list("List of source document IDs"),
            "reasoning": _prop("string", "How this bridge was derived"),
            "confidence": _prop("number", "Confidence score (0-1, default 0.9)"),
            "entities_involved": _string_list("Entities mentioned in bridge"),
        },
        ["question", "answer", "source_docs", "reasoning"],
    ),
    _function_tool(
        "get_bridge_statistics",
        "Get statistics on bridges created so far (total, coverage, quality).",
        {},
    ),
    # ---- Strategy ----
    _function_tool(
        "mark_entity_explored",
        "Mark an entity as explored. Call this when done exploring an entity.",
        {
            "entity_name": _prop("string", "Entity that was explored"),
            "num_bridges_created": _prop("integer", "Number of bridges created for this entity"),
        },
        ["entity_name"],
    ),
    # ---- Exploration Tracking ----
    _function_tool(
        "plan_entity_exploration",
        "Create exploration plan for entity showing all relationships to check. Call ONCE after reconcile_entity_across_docs.",
        {
            "entity_name": _prop("string", "Entity being explored"),
            "relationships": _prop("object", "merged_relationships dict from reconcile_entity_across_docs output"),
        },
        ["entity_name", "relationships"],
    ),
    _function_tool(
        "mark_relationship_explored",
        "Mark a relationship as explored after checking all its documents. Returns updated checklist.",
        {
            "entity_name": _prop("string", "Main entity being explored"),
            "relationship_name": _prop("string", "Name of the related entity just explored"),
            "bridges_created": _prop("integer", "Number of bridges created for this relationship"),
        },
        ["entity_name", "relationship_name"],
    ),
    _function_tool(
        "get_exploration_status",
        "Check which relationships explored vs pending. Use before mark_entity_explored to verify completeness.",
        {"entity_name": _prop("string", "Entity to check status for")},
        ["entity_name"],
    ),
]

print(f"Defined {len(GSW_TOOL_SCHEMAS)} tool schemas:")
for schema in GSW_TOOL_SCHEMAS:
    print(f"  - {schema['function']['name']}")

## Rollout Function

Runs one episode: the agent explores GSW structures via tool calls,
creating bridge QA pairs. The trajectory captures all messages + tool calls.
Reward is computed from bridge quality (F1 against gold answers).

In [None]:
import json
import traceback
from pathlib import Path

from openai import AsyncOpenAI

import art

from gsw_memory.sleep_time.environment import GSWEnvironment
from gsw_memory.sleep_time.reward import compute_reward
from gsw_memory.sleep_time.prompts import SLEEP_TIME_SYSTEM_PROMPT
from gsw_memory.sleep_time.entity_search import EntitySearcher


async def rollout(model: art.Model, scenario: dict, step: int = 0) -> art.Trajectory:
    """
    Run one multi-turn exploration episode and return the scored trajectory.

    The agent calls GSW tools to explore entity relationships and create
    bridge QA pairs. Tools execute locally against a fresh GSWEnvironment;
    the reward is computed with ``compute_reward`` from the bridges the
    environment recorded during the episode.

    Args:
        model: ART model whose OpenAI-compatible inference endpoint is queried.
        scenario: One index entry. Reads ``gsw_dirs``, ``question`` and
            ``answer``; optionally ``id``, ``decomposition`` and
            ``answer_aliases``.
        step: Training step stored in the trajectory metadata.

    Returns:
        The trajectory with all messages/choices, tool outputs, reward and
        metrics (``num_bridges``, ``num_turns``, ``env_done``).
    """
    # Build GSW environment for this episode; the searcher is pointed at the
    # parent directory of the first listed GSW dir.
    gsw_dirs = scenario["gsw_dirs"]
    gsw_root = str(Path(gsw_dirs[0]).parent)
    searcher = EntitySearcher(path_to_gsw_files=gsw_root, verbose=False)
    env = GSWEnvironment(
        entity_searcher=searcher,
        question=scenario["question"],
        gold_answer=scenario["answer"],
        gold_decomposition=scenario.get("decomposition", []),
        max_turns=MAX_TURNS,
    )
    env.reset()

    user_content = (
        f"Question: {scenario['question']}\n\n"
        f"Explore the GSW corpus to find multi-hop bridge connections "
        f"that help answer this question. Use the available tools systematically."
    )

    traj = art.Trajectory(
        reward=0.0,
        messages_and_choices=[
            {"role": "system", "content": SLEEP_TIME_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        tools=GSW_TOOL_SCHEMAS,
        metadata={
            "scenario_id": scenario.get("id", ""),
            "step": step,
            "question": scenario["question"],
        },
    )

    client = AsyncOpenAI(
        base_url=model.inference_base_url,
        api_key=model.inference_api_key,
    )
    try:
        for turn in range(MAX_TURNS):
            try:
                response = await client.chat.completions.create(
                    model=model.get_inference_name(),
                    temperature=0.7,
                    messages=traj.messages(),
                    tools=traj.tools,
                )
            except Exception as e:
                traj.log(f"Inference error at turn {turn}: {e}")
                break

            choice = response.choices[0]
            traj.messages_and_choices.append(choice)

            # No tool calls = agent stopped (freeform response)
            if not choice.message.tool_calls:
                break

            # Execute each tool call independently so one failing call does not
            # leave the remaining calls of this message without a tool response.
            for tool_call in choice.message.tool_calls:
                tool_name = tool_call.function.name
                try:
                    # Treat an empty arguments string as "no arguments".
                    tool_args = json.loads(tool_call.function.arguments or "{}")
                    obs, done = env.step(tool_name, tool_args)
                except Exception as e:
                    # Report the failure back to the agent instead of aborting.
                    traj.log(f"Tool execution error at turn {turn}: {e}")
                    obs, done = json.dumps({"error": str(e)}), False

                traj.messages_and_choices.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_name,
                    "content": obs,
                })

                # The environment ended the episode; skip any remaining calls.
                if done:
                    break

            if env.done:
                break
    finally:
        # One client (and its HTTP connection pool) is created per rollout;
        # release it even if the episode raised.
        await client.close()

    # Compute reward from bridges created during the episode
    bridges = env.get_bridges()
    traj.reward = compute_reward(
        bridges=bridges,
        gold_answer=scenario["answer"],
        gold_decomposition=scenario.get("decomposition", []),
        gold_aliases=scenario.get("answer_aliases", []),
    )
    traj.metrics["num_bridges"] = len(bridges)
    traj.metrics["num_turns"] = env.turn
    traj.metrics["env_done"] = int(env.done)

    return traj


print("Rollout function defined.")

## Test Rollout

Run a single rollout to verify the pipeline works before training.

In [None]:
test_scenario = train_scenarios[0]
print(f"Test scenario: {test_scenario['question']}")
print(f"Expected answer: {test_scenario['answer']}")
print(f"GSW dirs: {test_scenario['gsw_dirs']}")
print("-" * 60)

test_traj = await rollout(art_model, test_scenario, step=0)

print(f"\nTrajectory summary:")
print(f"  Reward: {test_traj.reward:.4f}")
print(f"  Bridges created: {test_traj.metrics.get('num_bridges', 0)}")
print(f"  Turns used: {test_traj.metrics.get('num_turns', 0)}")
print(f"  Episode done: {test_traj.metrics.get('env_done', 0)}")

# Show tool calls
messages = test_traj.messages()
print(f"\nMessages ({len(messages)} total):")
for i, msg in enumerate(messages):
    role = msg.get("role", "unknown")
    content = msg.get("content", "")
    tool_calls = msg.get("tool_calls", [])
    if role == "system":
        print(f"  [{i}] SYSTEM: (system prompt, {len(content)} chars)")
    elif role == "user":
        print(f"  [{i}] USER: {content[:100]}...")
    elif role == "assistant":
        if tool_calls:
            for tc in tool_calls:
                print(f"  [{i}] ASSISTANT tool_call: {tc['function']['name']}({tc['function']['arguments'][:80]})")
        if content:
            print(f"  [{i}] ASSISTANT: {content[:100]}...")
    elif role == "tool":
        name = msg.get("name", "?")
        print(f"  [{i}] TOOL ({name}): {content[:100]}...")

## Training Loop

In [None]:
from art.utils import iterate_dataset

training_iterator = iterate_dataset(
    train_scenarios,
    groups_per_step=GROUPS_PER_STEP,
    num_epochs=NUM_EPOCHS,
    initial_step=await art_model.get_step(),
)

for batch in training_iterator:
    print(f"\n{'='*60}")
    print(f"Step {batch.step} | Epoch {batch.epoch} | Epoch step {batch.epoch_step}")
    print(f"Batch: {len(batch.items)} scenarios x {ROLLOUTS_PER_GROUP} rollouts")

    # ---- Generate rollouts ----
    train_groups = []
    for scenario in batch.items:
        train_groups.append(
            art.TrajectoryGroup(
                rollout(art_model, scenario, step=batch.step)
                for _ in range(ROLLOUTS_PER_GROUP)
            )
        )

    finished_groups = await art.gather_trajectory_groups(
        train_groups,
        pbar_desc=f"step {batch.step}",
        max_exceptions=ROLLOUTS_PER_GROUP * len(batch.items),
    )

    # ---- Log metrics ----
    all_rewards = []
    all_bridges = []
    all_turns = []
    for group in finished_groups:
        for traj in group.trajectories:
            all_rewards.append(traj.reward)
            all_bridges.append(traj.metrics.get("num_bridges", 0))
            all_turns.append(traj.metrics.get("num_turns", 0))

    if all_rewards:
        print(f"  Rewards: mean={sum(all_rewards)/len(all_rewards):.3f}, "
              f"max={max(all_rewards):.3f}, min={min(all_rewards):.3f}")
        print(f"  Bridges: mean={sum(all_bridges)/len(all_bridges):.1f}, "
              f"max={max(all_bridges)}")
        print(f"  Turns: mean={sum(all_turns)/len(all_turns):.1f}")

    # ---- Validation ----
    if batch.step % VALIDATION_INTERVAL == 0:
        print(f"  Running validation ({len(val_scenarios)} scenarios)...")
        val_groups = []
        for scenario in val_scenarios:
            val_groups.append(
                art.TrajectoryGroup(
                    [rollout(art_model, scenario, step=batch.step)]
                )
            )
        finished_val = await art.gather_trajectory_groups(
            val_groups,
            pbar_desc="val",
            max_exceptions=len(val_scenarios),
        )
        val_rewards = [
            t.reward
            for g in finished_val
            for t in g.trajectories
        ]
        if val_rewards:
            print(f"  Val rewards: mean={sum(val_rewards)/len(val_rewards):.3f}, "
                  f"max={max(val_rewards):.3f}")
        await art_model.log(finished_val, split="val")

    # ---- Train (GRPO update) ----
    await art_model.delete_checkpoints()
    await art_model.train(
        finished_groups,
        config=art.TrainConfig(learning_rate=LEARNING_RATE),
    )
    print(f"  Training step {batch.step} complete.")

    if MAX_STEPS and batch.step >= MAX_STEPS:
        print(f"\nReached max_steps={MAX_STEPS}. Stopping.")
        break

print("\nTraining complete!")

## Test Trained Model

In [None]:
print("Testing the trained model...\n")

test_scenario = val_scenarios[0]
print(f"Question: {test_scenario['question']}")
print(f"Expected answer: {test_scenario['answer']}")
print("-" * 60)

result_traj = await rollout(art_model, test_scenario, step=0)

print(f"\nReward: {result_traj.reward:.4f}")
print(f"Bridges: {result_traj.metrics.get('num_bridges', 0)}")
print(f"Turns: {result_traj.metrics.get('num_turns', 0)}")

# Show full trajectory
messages = result_traj.messages()
print(f"\n--- Full Trajectory ({len(messages)} messages) ---")
for i, msg in enumerate(messages):
    role = msg.get("role", "unknown")
    content = msg.get("content", "")
    tool_calls = msg.get("tool_calls", [])

    if role == "system":
        print(f"\n[SYSTEM]: (prompt, {len(content)} chars)")
    elif role == "user":
        print(f"\n[USER]: {content}")
    elif role == "assistant":
        if tool_calls:
            for tc in tool_calls:
                print(f"\n[ASSISTANT → {tc['function']['name']}]: {tc['function']['arguments'][:200]}")
        if content:
            print(f"\n[ASSISTANT]: {content[:300]}")
    elif role == "tool":
        name = msg.get("name", "?")
        display = content[:300] + "..." if len(content) > 300 else content
        print(f"\n[TOOL {name}]: {display}")

## Save Model

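ART saves checkpoints automatically. The cell below reports the final step and where the checkpoints are stored. Note that the model-registration cell clears `./.art/gsw-sleep-time` when it runs, which discards earlier local checkpoints.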

In [None]:
# The model checkpoints are saved by ART in the backend's path (.art/ for local)
# For ServerlessBackend, checkpoints are W&B Artifacts

current_step = await art_model.get_step()
print(f"Final model step: {current_step}")
print(f"Checkpoints saved at: .art/ (LocalBackend) or W&B Artifacts (ServerlessBackend)")