# GSW Sleep-Time Agent: Standalone Multi-Turn GRPO

Train Qwen3-30B-A3B (MoE) to explore GSW structures and create bridge QA pairs.

**No ART framework.** Uses Unsloth + manual GRPO directly.

- Unsloth loads model with `fast_inference=False` (required for MoE)
- Multi-turn rollouts via `model.generate()` + local GSW tool execution
- Manual GRPO: advantage-weighted policy gradient with tool-response masking
- Reward: bridge-F1 against MuSiQue gold answers (verifiable, no LLM judge)
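
The reward function `compute_reward` is imported from `gsw_memory` and is not shown in this notebook. As a rough illustration of what an F1-style, judge-free reward can look like, here is a minimal sketch of SQuAD-style token-level F1. The helper names (`normalize`, `token_f1`, `best_f1`) and the normalization rules are assumptions for illustration, not the project's actual implementation.

```python
import re
import string
from collections import Counter


def normalize(text: str) -> list[str]:
    """Lowercase, drop punctuation and English articles, split on whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return text.split()


def token_f1(prediction: str, gold: str) -> float:
    """Harmonic mean of token precision and recall between two answers."""
    pred_tokens = normalize(prediction)
    gold_tokens = normalize(gold)
    if not pred_tokens or not gold_tokens:
        # Both empty counts as a match; one empty counts as a miss.
        return float(pred_tokens == gold_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def best_f1(prediction: str, gold: str, aliases: list[str]) -> float:
    """Score against the gold answer and every alias, keep the best match."""
    return max(token_f1(prediction, ref) for ref in [gold, *aliases])
```

Under these assumptions, a bridge answer of `"in October 1190"` scored against the gold answer `"October"` has precision 1/3 and recall 1, so it earns partial credit rather than zero.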

In [1]:
import os
os.environ["HF_HOME"] = "/mnt/SSD3/yigit"
# Expose physical GPUs 1 and 3 only; inside this process they appear as cuda:0 and cuda:1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,3"
# ---- Model ----
BASE_MODEL = "Qwen/Qwen3-30B-A3B"
MAX_SEQ_LENGTH = 32768
LORA_RANK = 32

# ---- GPU ----
EMBEDDING_GPU = 2  # GPU for EntitySearcher's embedding model (GPUs 0+1 used by training model)

# ---- Rollout ----
MAX_TURNS = 30
TEMPERATURE = 0.7
MAX_NEW_TOKENS = 32768  # per generation step

# ---- GRPO ----
ROLLOUTS_PER_PROMPT = 4   # group size
NUM_TRAIN_PROMPTS = 2     # prompts per step
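KL_COEF = 0.001   # reserved: the loss below does not apply a KL penalty yet
CLIP_RATIO = 0.2  # reserved: the loss below does not apply ratio clipping yet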

# ---- Training ----
MAX_STEPS = 200
LEARNING_RATE = 5e-6
MAX_GRAD_NORM = 1.0
SAVE_EVERY = 50
LOG_EVERY = 5

# ---- Data ----
INDEX_PATH = "/home/yigit/codebase/gsw-memory/data/rl_training/index.json"
VAL_SPLIT = 0.05
OUTPUT_DIR = "outputs/gsw_grpo"

print(f"Model: {BASE_MODEL}")
print(f"LoRA rank: {LORA_RANK}")
print(f"GRPO group size: {ROLLOUTS_PER_PROMPT}")
print(f"Training GPUs: 0, 1 | Embedding GPU: {EMBEDDING_GPU}")
print(f"HF_HOME: {os.environ['HF_HOME']}")

Model: Qwen/Qwen3-30B-A3B
LoRA rank: 32
GRPO group size: 4
Training GPUs: 0, 1 | Embedding GPU: 2
HF_HOME: /mnt/SSD3/yigit


## GPU Diagnostics

In [2]:
import gc
import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        with torch.cuda.device(i):
            torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.synchronize()

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    name = torch.cuda.get_device_name(i)
    total = torch.cuda.get_device_properties(i).total_memory / 1024**3
    free, _ = torch.cuda.mem_get_info(i)
    print(f"  GPU {i}: {name} | {total:.1f} GB total | {free/1024**3:.1f} GB free")

PyTorch: 2.7.1+cu126
CUDA: 12.6
GPUs: 2
  GPU 0: NVIDIA RTX A6000 | 47.4 GB total | 47.1 GB free
  GPU 1: NVIDIA RTX A6000 | 47.4 GB total | 47.1 GB free


## Load Model (Unsloth MoE)

In [3]:
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=False,
    fast_inference=False,  # Not supported for MoE
    device_map="balanced",
    max_memory={0: "46GiB", 1: "46GiB"},  # Only use GPU 0 and 1
)

model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj", "gate_up_proj",
    ],
    lora_alpha=LORA_RANK * 2,
    use_gradient_checkpointing=True,
    random_state=3407,
    bias="none",
)

model.print_trainable_parameters()

# Show per-GPU memory usage
for i in range(torch.cuda.device_count()):
    alloc = torch.cuda.memory_allocated(i) / 1024**3
    total = torch.cuda.get_device_properties(i).total_memory / 1024**3
    print(f"  GPU {i}: {alloc:.1f} / {total:.1f} GB allocated")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


Skipping import of cpp extensions due to incompatible torch version 2.7.1+cu126 for torchao version 0.14.1             Please see https://github.com/pytorch/ao/issues/2919 for more info


INFO 02-26 01:01:46 [__init__.py:235] Automatically detected platform cuda.
Unsloth: Your Flash Attention 2 installation seems to be broken?
A possible explanation is you have a new CUDA version which isn't
yet compatible with FA2? Please file a ticket to Unsloth or FA2.
We shall now use Xformers instead, which does not have any performance hits!
We found this negligible impact by benchmarking on 1x A100.
Switching to PyTorch attention since your Xformers is broken.

/home/yigit/codebase/gsw-memory/.venv/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.3: Fast Qwen3_Moe patching. Transformers: 4.56.2. vLLM: 0.10.0.
   \\   /|    NVIDIA RTX A6000. Num GPUs = 2. Max memory: 47.402 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xfor

Loading checkpoint shards:   0%|          | 0/13 [00:00<?, ?it/s]

Unsloth: Making `model.base_model.model.model` require gradients
trainable params: 1,687,683,072 || all params: 32,219,805,696 || trainable%: 5.2380
  GPU 0: 31.6 / 47.4 GB allocated
  GPU 1: 31.6 / 47.4 GB allocated


## Load Training Data

In [4]:
import json
import random

with open(INDEX_PATH) as f:
    all_scenarios = json.load(f)

random.seed(42)
random.shuffle(all_scenarios)

n_val = max(1, int(len(all_scenarios) * VAL_SPLIT))
val_scenarios = all_scenarios[:n_val]
train_scenarios = all_scenarios[n_val:]

print(f"Total: {len(all_scenarios)}, Train: {len(train_scenarios)}, Val: {len(val_scenarios)}")
print(f"Sample: {train_scenarios[0]['question']}")
print(f"Answer: {train_scenarios[0]['answer']}")

Total: 500, Train: 475, Val: 25
Sample: What month did the person who retained trust in Longchamp go away to the Holy Land?
Answer: October


## Multi-Turn Rollout

In [5]:
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple

from gsw_memory.sleep_time.environment import GSWEnvironment
from gsw_memory.sleep_time.entity_search import EntitySearcher
from gsw_memory.sleep_time.prompts import SLEEP_TIME_SYSTEM_PROMPT
from gsw_memory.sleep_time.reward import compute_reward

# Tool call regex (same as gsw_interaction.py)
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)


def build_messages(question: str) -> list[dict]:
    """Build initial chat messages for an episode."""
    return [
        {"role": "system", "content": SLEEP_TIME_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"Question: {question}\n\n"
                f"Explore the GSW corpus to find multi-hop bridge connections "
                f"that help answer this question. Use the available tools systematically."
            ),
        },
    ]


def messages_to_text(messages: list[dict]) -> str:
    """Convert messages to Qwen chat format string."""
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def parse_tool_call(text: str) -> tuple[str, dict] | None:
    """Extract last tool call from generated text."""
    matches = list(_TOOL_CALL_RE.finditer(text))
    if not matches:
        return None
    try:
        call = json.loads(matches[-1].group(1))
        name = call.get("name", "")
        args = call.get("arguments", {})
        if isinstance(args, str):
            args = json.loads(args)
        return name, args
    # AttributeError/TypeError cover payloads that parse as JSON but are not objects
    except (json.JSONDecodeError, AttributeError, TypeError):
        return None


def run_episode(
    scenario: dict,
    model,
    tokenizer,
    max_turns: int = MAX_TURNS,
    temperature: float = TEMPERATURE,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> dict:
    """
    Run one multi-turn episode.

    Returns dict with:
        messages: full conversation (list of dicts)
        full_text: chat-templated trajectory as a single string (not yet tokenized)
        reward: scalar reward
        num_bridges: count of bridges created
        num_turns: turns used
    """
    # Build environment
    gsw_dirs = scenario["gsw_dirs"]
    gsw_root = str(Path(gsw_dirs[0]).parent)
    searcher = EntitySearcher(
        path_to_gsw_files=gsw_root,
        verbose=False,
        gpu_device=EMBEDDING_GPU,
    )
    env = GSWEnvironment(
        entity_searcher=searcher,
        question=scenario["question"],
        gold_answer=scenario["answer"],
        gold_decomposition=scenario.get("decomposition", []),
        max_turns=max_turns,
    )
    env.reset()

    messages = build_messages(scenario["question"])

    for turn in range(max_turns):
        # Generate
        prompt_text = messages_to_text(messages)
        inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                use_cache=True,
            )

        # Decode only the new tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Append assistant message
        messages.append({"role": "assistant", "content": response_text})

        # Check for tool call
        parsed = parse_tool_call(response_text)
        if parsed is None:
            # No tool call: the agent has stopped
            break

        tool_name, tool_args = parsed

        # Execute tool
        try:
            obs, done = env.step(tool_name, tool_args)
        except Exception as e:
            obs = json.dumps({"error": str(e)})
            done = False

        # Append tool result as user message (for next generation)
        messages.append({
            "role": "user",
            "content": f"<tool_response>{obs}</tool_response>",
        })

        if done or env.done:
            break

    # Compute reward
    bridges = env.get_bridges()
    reward = compute_reward(
        bridges=bridges,
        gold_answer=scenario["answer"],
        gold_decomposition=scenario.get("decomposition", []),
        gold_aliases=scenario.get("answer_aliases", []),
    )

    # Build full text for tokenization
    full_text = tokenizer.apply_chat_template(messages, tokenize=False)

    return {
        "messages": messages,
        "full_text": full_text,
        "reward": reward,
        "num_bridges": len(bridges),
        "num_turns": env.turn,
    }


print("Rollout function defined.")

INFO:faiss.loader: Loading faiss with AVX2 support.
INFO:faiss.loader: Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
INFO:faiss.loader: Loading faiss.
INFO:faiss.loader: Successfully loaded faiss.


Rollout function defined.


## Test Rollout

In [None]:
test_scenario = train_scenarios[0]
print(f"Question: {test_scenario['question']}")
print(f"Answer: {test_scenario['answer']}")
print("-" * 60)

FastLanguageModel.for_inference(model)  # Enable fast inference mode

result = run_episode(test_scenario, model, tokenizer)

print(f"\nReward: {result['reward']:.4f}")
print(f"Bridges: {result['num_bridges']}")
print(f"Turns: {result['num_turns']}")
print(f"Messages: {len(result['messages'])}")

# Show conversation
for i, msg in enumerate(result["messages"]):
    role = msg["role"]
    content = msg["content"]
    if role == "system":
        print(f"\n[{i}] SYSTEM: ({len(content)} chars)")
    elif role == "user" and "<tool_response>" in content:
        print(f"\n[{i}] TOOL_RESPONSE: {content}...")
    elif role == "user":
        print(f"\n[{i}] USER: {content}")
    elif role == "assistant":
        print(f"\n[{i}] ASSISTANT: {content}")

Question: What month did the person who retained trust in Longchamp go away to the Holy Land?
Answer: October
------------------------------------------------------------
Loading first 50 GSW files...
Loaded 50 GSW structures from 50 documents
INFO 02-25 18:08:05 [config.py:538] Found sentence-transformers modules configuration.
INFO 02-25 18:08:05 [config.py:558] Found pooling configuration.


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 02-25 18:08:05 [config.py:1604] Using max model len 40960
INFO 02-25 18:08:05 [arg_utils.py:1551] (Enabling) chunked prefill by default
INFO 02-25 18:08:05 [arg_utils.py:1554] (Enabling) prefix caching by default
INFO 02-25 18:08:05 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=8192.


Skipping import of cpp extensions due to incompatible torch version 2.7.1+cu126 for torchao version 0.14.1             Please see https://github.com/pytorch/ao/issues/2919 for more info


INFO 02-25 18:08:11 [__init__.py:235] Automatically detected platform cuda.
INFO 02-25 18:08:11 [core.py:572] Waiting for init message from front-end.
INFO 02-25 18:08:11 [core.py:71] Initializing a V1 LLM engine (v0.10.0) with config: model='Qwen/Qwen3-Embedding-8B', speculative_config=None, tokenizer='Qwen/Qwen3-Embedding-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=40960, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_tr

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.71it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.30it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.18it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.26it/s]



INFO 02-25 18:08:17 [default_loader.py:262] Loading weights took 3.28 seconds
INFO 02-25 18:08:18 [gpu_model_runner.py:1892] Model loading took 14.1355 GiB and 4.022104 seconds
INFO 02-25 18:08:24 [backends.py:530] Using cache directory: /home/yigit/.cache/vllm/torch_compile_cache/4f805632d6/rank_0_0/backbone for vLLM's torch.compile
INFO 02-25 18:08:24 [backends.py:541] Dynamo bytecode transform time: 5.90 s
INFO 02-25 18:08:29 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 4.410 s
INFO 02-25 18:08:30 [monitor.py:34] torch.compile takes 5.90 s in total
INFO 02-25 18:08:31 [gpu_worker.py:255] Available KV cache memory: 27.69 GiB
INFO 02-25 18:08:32 [kv_cache_utils.py:833] GPU KV cache size: 201,616 tokens
INFO 02-25 18:08:32 [kv_cache_utils.py:837] Maximum concurrency for 40,960 tokens per request: 4.92x


Capturing CUDA graph shapes: 100%|██████████| 67/67 [00:03<00:00, 16.79it/s]


INFO 02-25 18:08:36 [gpu_model_runner.py:2485] Graph capturing finished in 4 secs, took 0.58 GiB
INFO 02-25 18:08:36 [core.py:193] init engine (profile, create kv cache, warmup model) took 18.06 seconds





Reward: 0.0000
Bridges: 0
Turns: 0
Messages: 3

[0] SYSTEM: (30464 chars)

[1] USER: Question: What month did the person who retained trust in Longchamp go away to the Holy Land?

Explore the GSW corpus to find multi-hop bridge connections that help answer this question. Use the avail

[2] ASSISTANT: <think>
Okay, let's tackle this question: "What month did the person who retained trust in Longchamp go away to the Holy Land?" 

First, I need to break down the question. The key entities here are "Longchamp" and the person who retained trust in Longchamp. The goal is to find out the month when thi


In [7]:
# Show conversation
for i, msg in enumerate(result["messages"]):
    role = msg["role"]
    content = msg["content"]
    if role == "system":
        print(f"\n[{i}] SYSTEM: ({len(content)} chars)")
    elif role == "user" and "<tool_response>" in content:
        print(f"\n[{i}] TOOL_RESPONSE: {content}...")
    elif role == "user":
        print(f"\n[{i}] USER: {content}")
    elif role == "assistant":
        print(f"\n[{i}] ASSISTANT: {content}")


[0] SYSTEM: (30464 chars)

[1] USER: Question: What month did the person who retained trust in Longchamp go away to the Holy Land?

Explore the GSW corpus to find multi-hop bridge connections that help answer this question. Use the available tools systematically.

[2] ASSISTANT: <think>
Okay, let's tackle this question: "What month did the person who retained trust in Longchamp go away to the Holy Land?" 

First, I need to break down the question. The key entities here are "Longchamp" and the person who retained trust in Longchamp. The goal is to find out the month when this person left for the Holy Land. 

I should start by exploring the entity "Longchamp" to see what information is available. Using the tools provided, I'll call reconcile_entity_across_docs("Longchamp") to get a merged view of all documents mentioning Longchamp. This will help me understand the context and any relationships associated with Longchamp.

Once I have the merged data, I'll look for any entities that are r

## GRPO Training Loop

Manual GRPO: collect rollout groups, compute advantages, do policy gradient updates
with tool-response token masking.
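
`compute_advantages` is imported from `gsw_memory.sleep_time.train` and its body is not shown here. The sketch below illustrates the group-relative normalization commonly used in GRPO: each rollout's reward is standardized against the other rollouts for the same prompt. The function name `group_relative_advantages`, the `eps` value, and the use of the population standard deviation are assumptions, not the project's actual code.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Standardize rewards within one rollout group (all rollouts of one prompt)."""
    if not rewards:
        return []
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; eps keeps the division finite
    return [(r - mu) / (sigma + eps) for r in rewards]


# A group where every rollout earns the same reward carries no learning signal.
flat_group = group_relative_advantages([0.0, 0.0, 0.0, 0.0])

# A mixed group pushes up the rewarded rollouts and pushes down the rest.
mixed_group = group_relative_advantages([0.0, 1.0, 0.0, 1.0])
```

One consequence worth keeping in mind with a group size of 4: if all four rollouts for a prompt score zero, every advantage is zero and that prompt contributes nothing to the update.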

In [None]:
from gsw_memory.sleep_time.train import compute_advantages, mask_tool_response_tokens
from torch.optim import AdamW
from pathlib import Path
import time

# Setup output dir
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

# Optimizer
optimizer = AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE,
    weight_decay=0.01,
)

# KL penalty against a reference model is not implemented yet:
# KL_COEF and CLIP_RATIO are defined in the config but unused by the loss below.
# Reference log-probs can be added later if the policy drifts too far from the base model.


def compute_grpo_loss(
    model,
    tokenizer,
    episodes: list[dict],
    advantages: list[float],
) -> torch.Tensor:
    """
    Compute GRPO policy gradient loss for a group of episodes.

    Loss = -mean_i(advantage_i * masked_mean_t(log_prob_{i,t})), where the inner mean
    runs over the unmasked assistant tokens of episode i.

    Tool response tokens are masked from the loss.
    """
    total_loss = torch.tensor(0.0, device=model.device)
    n_valid = 0

    for ep, adv in zip(episodes, advantages):
        # Tokenize full trajectory
        full_text = ep["full_text"]
        encoding = tokenizer(
            full_text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
        ).to(model.device)

        input_ids = encoding["input_ids"][0]  # (seq_len,)

        # Build loss mask (zeros out tool response tokens)
        loss_mask = mask_tool_response_tokens(input_ids, tokenizer).to(model.device)

        # Find where the first assistant response starts (skip prompt tokens)
        # We only train on assistant tokens, not prompt tokens
        prompt_messages = ep["messages"][:2]  # system + user
        prompt_text = tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True
        )
        prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
        prompt_len = len(prompt_ids)

        # Zero out prompt tokens in loss mask
        loss_mask[:prompt_len] = 0.0

        if loss_mask.sum() == 0:
            continue  # Nothing to train on

        # Forward pass
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        )
        logits = outputs.logits  # (1, seq_len, vocab)

        # Shift for next-token prediction. With device_map="balanced" the logits can
        # come back on a different GPU than the inputs, so align labels and mask.
        shift_logits = logits[0, :-1, :]                      # (seq_len-1, vocab)
        shift_labels = input_ids[1:].to(shift_logits.device)  # (seq_len-1,)
        shift_mask = loss_mask[1:].to(shift_logits.device)    # (seq_len-1,)

        # Log probabilities of actual tokens
        log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
        token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)

        # Masked mean log prob, weighted by advantage
        masked_log_probs = (token_log_probs * shift_mask).sum() / shift_mask.sum().clamp(min=1)
        total_loss = total_loss + (-adv * masked_log_probs).to(total_loss.device)
        n_valid += 1

    if n_valid == 0:
        return total_loss

    return total_loss / n_valid


print("GRPO loss function defined.")
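
The loss above relies on `mask_tool_response_tokens`, which is imported rather than defined in this notebook. To make the masking idea concrete, here is a small pure-Python sketch that works on string tokens instead of token ids. It assumes each of `<tool_response>` and `</tool_response>` appears as a single token; the helper names `tool_response_mask` and `masked_mean` are hypothetical and do not reflect the imported function's real signature.

```python
def tool_response_mask(
    tokens: list[str],
    open_tag: str = "<tool_response>",
    close_tag: str = "</tool_response>",
) -> list[float]:
    """Return 1.0 for tokens to train on, 0.0 for tool-response spans (tags included)."""
    mask = []
    inside = False
    for tok in tokens:
        if tok == open_tag:
            inside = True
        mask.append(0.0 if inside else 1.0)
        if tok == close_tag:
            inside = False
    return mask


def masked_mean(values: list[float], mask: list[float]) -> float:
    """Average values over positions where mask is 1.0; 0.0 if nothing is kept."""
    kept = sum(mask)
    if kept == 0:
        return 0.0
    return sum(v * m for v, m in zip(values, mask)) / kept


tokens = ["call", "<tool_response>", "obs", "</tool_response>", "answer"]
mask = tool_response_mask(tokens)

# Only the two model-generated tokens contribute to the averaged log-prob.
avg_log_prob = masked_mean([-1.0, -9.0, -9.0, -9.0, -3.0], mask)
```

The point of the mask is that environment-produced text (tool observations) should not be reinforced or penalized, since the policy did not generate it.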

In [None]:
# Training loop
all_rewards = []
scenario_idx = 0

print(f"Starting GRPO training for {MAX_STEPS} steps...")
print(f"  {NUM_TRAIN_PROMPTS} prompts x {ROLLOUTS_PER_PROMPT} rollouts = {NUM_TRAIN_PROMPTS * ROLLOUTS_PER_PROMPT} episodes/step")
print(f"  Output: {OUTPUT_DIR}")
print("=" * 60)

for step in range(1, MAX_STEPS + 1):
    step_start = time.time()

    # ---- Collect rollouts ----
    FastLanguageModel.for_inference(model)

    step_episodes = []   # list of lists: [prompt_group_1, prompt_group_2, ...]
    step_rewards = []

    for _ in range(NUM_TRAIN_PROMPTS):
        scenario = train_scenarios[scenario_idx % len(train_scenarios)]
        scenario_idx += 1

        group_episodes = []
        group_rewards = []

        for _ in range(ROLLOUTS_PER_PROMPT):
            try:
                ep = run_episode(scenario, model, tokenizer)
                group_episodes.append(ep)
                group_rewards.append(ep["reward"])
            except Exception as e:
                print(f"  Episode error: {e}")
                continue

        if group_episodes:
            # Compute GRPO advantages within this group
            advantages = compute_advantages(group_rewards)
            step_episodes.append((group_episodes, advantages))
            step_rewards.extend(group_rewards)

    if not step_episodes:
        print(f"Step {step}: No valid episodes, skipping.")
        continue

    all_rewards.extend(step_rewards)

    # ---- GRPO update ----
    FastLanguageModel.for_training(model)
    optimizer.zero_grad()

    total_loss = torch.tensor(0.0, device=model.device)
    n_groups = 0

    for group_episodes, advantages in step_episodes:
        loss = compute_grpo_loss(model, tokenizer, group_episodes, advantages)
        total_loss = total_loss + loss
        n_groups += 1

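    if n_groups > 0:
        avg_loss = total_loss / n_groups
        # If no episode had trainable tokens, the loss is a constant with no graph
        # and backward() would raise, so skip the update in that case.
        if avg_loss.requires_grad:
            avg_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()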

    step_time = time.time() - step_start

    # ---- Logging ----
    if step % LOG_EVERY == 0 or step == 1:
        recent = all_rewards[-50:] if all_rewards else [0]
        mean_r = sum(recent) / len(recent)
        max_r = max(recent)
        nonzero = sum(1 for r in recent if r > 0) / len(recent)
        loss_val = avg_loss.item() if n_groups > 0 else 0.0
        print(
            f"Step {step}/{MAX_STEPS} | "
            f"loss={loss_val:.4f} | "
            f"reward={mean_r:.3f} (max={max_r:.3f}, nonzero={nonzero:.0%}) | "
            f"{step_time:.1f}s"
        )

    # ---- Checkpoint ----
    if step % SAVE_EVERY == 0:
        ckpt_dir = output_path / f"checkpoint_step_{step}"
        ckpt_dir.mkdir(exist_ok=True)
        model.save_pretrained(str(ckpt_dir))
        tokenizer.save_pretrained(str(ckpt_dir))
        print(f"  Saved checkpoint: {ckpt_dir}")

print("\nTraining complete!")

# Save final
final_dir = output_path / "final"
final_dir.mkdir(exist_ok=True)
model.save_pretrained(str(final_dir))
tokenizer.save_pretrained(str(final_dir))
print(f"Final model saved to: {final_dir}")

## Validation

In [None]:
FastLanguageModel.for_inference(model)

val_rewards = []
val_bridges = []

print(f"Running validation on {len(val_scenarios)} scenarios...")
for i, scenario in enumerate(val_scenarios):
    try:
        result = run_episode(scenario, model, tokenizer)
        val_rewards.append(result["reward"])
        val_bridges.append(result["num_bridges"])
        print(f"  [{i+1}/{len(val_scenarios)}] reward={result['reward']:.3f} bridges={result['num_bridges']} turns={result['num_turns']}")
    except Exception as e:
        print(f"  [{i+1}/{len(val_scenarios)}] ERROR: {e}")

if val_rewards:
    print(f"\nValidation results:")
    print(f"  Mean reward: {sum(val_rewards)/len(val_rewards):.4f}")
    print(f"  Max reward: {max(val_rewards):.4f}")
    print(f"  Mean bridges: {sum(val_bridges)/len(val_bridges):.1f}")
    print(f"  Nonzero reward: {sum(1 for r in val_rewards if r > 0)/len(val_rewards):.0%}")

## Test Trained Model

In [None]:
FastLanguageModel.for_inference(model)

test_scenario = val_scenarios[0]
print(f"Question: {test_scenario['question']}")
print(f"Expected: {test_scenario['answer']}")
print("-" * 60)

result = run_episode(test_scenario, model, tokenizer)

print(f"\nReward: {result['reward']:.4f}")
print(f"Bridges: {result['num_bridges']}")
print(f"Turns: {result['num_turns']}")

# Show full trajectory
print(f"\n--- Trajectory ({len(result['messages'])} messages) ---")
for i, msg in enumerate(result["messages"]):
    role = msg["role"]
    content = msg["content"]
    if role == "system":
        print(f"\n[SYSTEM]: ({len(content)} chars)")
    elif role == "user" and "<tool_response>" in content:
        display = content[:300] + "..." if len(content) > 300 else content
        print(f"\n[TOOL_RESPONSE]: {display}")
    elif role == "user":
        print(f"\n[USER]: {content}")
    elif role == "assistant":
        display = content[:400] + "..." if len(content) > 400 else content
        print(f"\n[ASSISTANT]: {display}")

## Save Model

Save as LoRA adapter or merged weights.

In [None]:
# Save LoRA adapter
model.save_pretrained(str(output_path / "final"))
tokenizer.save_pretrained(str(output_path / "final"))
print(f"LoRA adapter saved to: {output_path / 'final'}")

# Optionally merge and save as full model (16-bit)
# model.save_pretrained_merged(str(output_path / "merged_16bit"), tokenizer, save_method="merged_16bit")