# DIPG Safety Gym: Training & Benchmarking Pipeline

This notebook walks through five stages:
1. **Base Model Evaluation** - Benchmark the base model before any fine-tuning
2. **Supervised Fine-Tuning (SFT)** - Train the model on the DIPG dataset
3. **Post-SFT Evaluation** - Benchmark after SFT
4. **GRPO Training** - Reinforce safety behaviors
5. **Post-GRPO Evaluation** - Final benchmark

We'll use `scripts/generate_benchmark_report.py` to measure the change in safety metrics at each stage.

## Setup & Installation

In [1]:
%%capture
import os, importlib.util
!pip install --upgrade -qqq uv
if importlib.util.find_spec("torch") is None or "COLAB_" in "".join(os.environ.keys()):
    try: import numpy; get_numpy = f"numpy=={numpy.__version__}"
    except: get_numpy = "numpy"
    !uv pip install -qqq \
        "torch>=2.8.0" "triton>=3.4.0" {get_numpy} torchvision bitsandbytes "transformers==4.56.2" trackio \
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
        "unsloth[base] @ git+https://github.com/unslothai/unsloth" \
        git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
elif importlib.util.find_spec("unsloth") is None:
    !uv pip install -qqq unsloth trackio
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo wandb mcp nest_asyncio matplotlib seaborn


## Load Base Model

In [2]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 4096
lora_rank = 64

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b-BF16",
    load_in_4bit = False,
    max_seq_length = max_seq_length,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 64,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
#### Unsloth: `hf_xet==1.1.10` and `ipykernel>6.30.1` breaks progress bars. Disabling for now in XET.
#### Unsloth: To re-enable progress bars, please downgrade to `ipykernel==6.30.1` or wait for a fix to
https://github.com/huggingface/xet-core/issues/526
INFO 11-29 19:33:33 [__init__.py:225] Automatically detected platform rocm.
🦥 Unsloth Zoo will now patch everything to make training faster!
Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.9: Fast Gpt_Oss patching. Transformers: 4.56.2. vLLM: 0.11.1rc3.dev39+gf417746ad.rocm700.
   \\   /|    . Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0a0+git1c57644. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free licens

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

Unsloth: Making `model.base_model.model.model` require gradients


## 📊 Benchmark 1: Base Model Evaluation

Before any training, let's establish a baseline by benchmarking the base model as released, with no fine-tuning applied.

In [3]:
# This cell only prints the benchmark command; it does not push or run anything itself.
# It assumes Ollama is configured and the base model is already available there.
# A model is typically registered via: ollama create gpt-oss-20b-base:latest -f Modelfile

print("📊 Running Base Model Benchmark...")
print("Model: ollama/gpt-oss:20b-cloud")
print("Samples: 100")
print("\nRun this command in your terminal:")
print("python scripts/generate_benchmark_report.py --model 'ollama/gpt-oss:20b-cloud' --samples 100")

📊 Running Base Model Benchmark...
Model: ollama/gpt-oss:20b-cloud
Samples: 100

Run this command in your terminal:
python scripts/generate_benchmark_report.py --model 'ollama/gpt-oss:20b-cloud' --samples 100
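
The same command can be launched from a notebook cell instead of a separate terminal. The following is a minimal sketch, assuming the script needs only the `--model` and `--samples` flags shown above; the helper name `build_benchmark_command` is hypothetical and not part of the repository.

```python
import shlex
import sys

def build_benchmark_command(model: str, samples: int = 100) -> list[str]:
    """Build the argv list for scripts/generate_benchmark_report.py (hypothetical helper)."""
    if samples <= 0:
        raise ValueError("samples must be a positive integer")
    return [
        sys.executable,
        "scripts/generate_benchmark_report.py",
        "--model", model,
        "--samples", str(samples),
    ]

cmd = build_benchmark_command("ollama/gpt-oss:20b-cloud", samples=100)
print(shlex.join(cmd))
# To execute it from the notebook: subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting problems with model names that contain `:` or `/`.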


## Start DIPG Safety Gym Server

In [4]:
# Run the server setup script
!python scripts/setup_env.py

[2025-11-29 19:33:56] INFO 67128008.py:42: --- Ensuring port 8012 is free ---
[2025-11-29 19:34:00] INFO 67128008.py:62: ✅ Port is clear.

[2025-11-29 19:34:00] INFO 67128008.py:67: --- Resetting working directory and cloning repo ---
[2025-11-29 19:34:00] INFO 67128008.py:74: ✅ Setup complete. Current directory: /AIAC/med-safety-gym

[2025-11-29 19:34:00] INFO 67128008.py:78: ✅ Dataset path: surfiniaburger/dipg-sft-dataset
[2025-11-29 19:34:00] INFO 67128008.py:81: --- Installing project dependencies ---
[2025-11-29 19:34:02] INFO 67128008.py:85: ✅ Project dependencies installed (including openenv-core).

[2025-11-29 19:34:02] INFO 67128008.py:88: --- Installing Gunicorn ---
[2025-11-29 19:34:03] INFO 67128008.py:90: ✅ Gunicorn installed.

[2025-11-29 19:34:03] INFO 67128008.py:93: --- Starting DIPGSafetyEnv server on port 8012 ---
[2025-11-29 19:34:03] INFO 67128008.py:175: 
--- Waiting for server to become healthy... ---
[2025-11-29 19:34:13] INFO 67128008.py:182: ✅ Server is runnin
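
The log above shows the setup script waiting for the server to become healthy. If you start the server yourself, a polling loop along these lines can play the same role. This is a sketch only: the helper names are hypothetical, and the `/health` path in the usage comment is an assumption, since the notebook does not show which endpoint `scripts/setup_env.py` polls.

```python
import time
import urllib.error
import urllib.request
from typing import Callable

def http_ok(url: str, timeout: float = 2.0) -> bool:
    """Return True if a GET request to `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False

def wait_until_healthy(check: Callable[[], bool],
                       timeout: float = 60.0,
                       interval: float = 1.0) -> bool:
    """Call `check` repeatedly until it returns True or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while True:
        if check():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)

# Usage (endpoint path is an assumption; port 8012 comes from the log above):
# wait_until_healthy(lambda: http_ok("http://localhost:8012/health"))
```

Injecting the check as a callable keeps the retry logic testable without a running server.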

## Load and Prepare SFT Dataset

In [5]:
from datasets import load_dataset

def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }

dataset = load_dataset("surfiniaburger/dipg-sft-dataset", split="train")
dataset

Dataset({
    features: ['id', 'messages'],
    num_rows: 1000
})

## Normalize Messages for Training

In [6]:
from unsloth.chat_templates import standardize_data_formats
dataset = standardize_data_formats(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)
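
With `batched = True`, `Dataset.map` hands the function a dict of columns, each a list with one entry per row, and expects a dict of equal-length lists back. The sketch below illustrates that contract on a plain Python dict. `toy_template` is a hypothetical stand-in for `tokenizer.apply_chat_template(..., tokenize=False)` so the example runs without loading a model; its output format is invented for illustration and is not the gpt-oss chat template.

```python
def toy_template(convo):
    """Hypothetical stand-in for the tokenizer's chat template."""
    return "".join(f"<{m['role']}>{m['content']}</{m['role']}>" for m in convo)

def format_batch(examples):
    # `examples` maps column name -> list of values, one per row in the batch.
    return {"text": [toy_template(convo) for convo in examples["messages"]]}

batch = {
    "id": [0, 1],
    "messages": [
        [{"role": "user", "content": "Q1"}, {"role": "assistant", "content": "A1"}],
        [{"role": "user", "content": "Q2"}, {"role": "assistant", "content": "A2"}],
    ],
}

out = format_batch(batch)
print(out["text"][0])  # <user>Q1</user><assistant>A1</assistant>
```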

## Supervised Fine-Tuning (SFT)

In [7]:
from trl import SFTTrainer, SFTConfig
#from unsloth.chat_templates import train_on_responses_only

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        max_seq_length = max_seq_length,  # reuse the value set when loading the model (4096)
        max_steps = 21,  # Adjust based on your dataset
        learning_rate = 2e-4,
        logging_steps = 10,
        optim = "adamw_8bit",
        weight_decay = 0.05,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "sft_outputs",
        report_to = "none",
    ),
)


In [8]:
from unsloth.chat_templates import train_on_responses_only

gpt_oss_kwargs = dict(instruction_part = "<|start|>user<|message|>", response_part="<|start|>assistant")

trainer = train_on_responses_only(
    trainer,
    **gpt_oss_kwargs,
)
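
Conceptually, training on responses only means that every token outside an assistant turn gets the label `-100`, which the cross-entropy loss ignores. The sketch below shows that idea on plain token-ID lists. It is a simplified illustration, not Unsloth's implementation, and the function names and toy token IDs are hypothetical.

```python
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

def find_sublist(seq, pattern, start=0):
    """Return the first index >= start where `pattern` occurs in `seq`, or -1."""
    n = len(pattern)
    for i in range(start, len(seq) - n + 1):
        if seq[i:i + n] == pattern:
            return i
    return -1

def mask_non_response(input_ids, instruction_ids, response_ids):
    """Keep labels only for tokens between a response marker and the next instruction marker."""
    if not instruction_ids or not response_ids:
        raise ValueError("marker token sequences must be non-empty")
    labels = [IGNORE_INDEX] * len(input_ids)
    pos = 0
    while True:
        r = find_sublist(input_ids, response_ids, pos)
        if r == -1:
            break
        start = r + len(response_ids)
        nxt = find_sublist(input_ids, instruction_ids, start)
        end = len(input_ids) if nxt == -1 else nxt
        labels[start:end] = input_ids[start:end]
        pos = end
    return labels

# Toy IDs: [9, 9] marks a user turn, [7, 7] marks an assistant turn.
ids = [9, 9, 1, 2, 7, 7, 3, 4, 9, 9, 5, 7, 7, 6]
labels = mask_non_response(ids, instruction_ids=[9, 9], response_ids=[7, 7])
print(labels)
# [-100, -100, -100, -100, -100, -100, 3, 4, -100, -100, -100, -100, -100, 6]
```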

In [9]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-11-29\n\nReasoning: medium\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.\nCalls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>You are a specialized AI research assistant from the 'Neuro-Oncology Research Collective'. Your purpose is to provide clinicians and researchers with the most accurate and up-to-date information regarding Diffuse Intrinsic Pontine Glioma (DIPG).\nAll responses must be strictly derived from the provided context.\nCurrent date: 2025-11-17\n\n# Instructions\nYou are an expert AI assistant specializing in medical information. You must reason about the user's request step-by-step and then provide a final, conclusive answer. Your response must be grounded in the provided text only. Do not use outside knowledge.\n\n**CONTEXT:**\nThe experimental

In [10]:
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

'                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

In [11]:
print("--- Starting SFT Training ---")
trainer.train()
print("--- SFT Training Complete ---")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 199998}.


--- Starting SFT Training ---


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,000 | Num Epochs = 1 | Total steps = 21
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 31,850,496 of 20,946,607,680 (0.15% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|------|---------------|
| 10   | 6.88          |
| 20   | 0.2646        |


--- SFT Training Complete ---
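
The numbers in the training banner can be checked with a little arithmetic. The sketch below uses only values printed above (batch size, accumulation steps, step count, dataset size, and parameter counts) and shows that 21 steps cover only part of the 1,000-example dataset.

```python
per_device_batch = 2
grad_accum_steps = 4
num_gpus = 1
max_steps = 21
num_examples = 1_000

# Examples consumed per optimizer step
effective_batch = per_device_batch * grad_accum_steps * num_gpus
# Examples consumed over the whole run
examples_seen = effective_batch * max_steps
fraction_of_epoch = examples_seen / num_examples

trainable, total = 31_850_496, 20_946_607_680
trainable_pct = 100 * trainable / total

print(f"Effective batch size: {effective_batch}")                    # 8
print(f"Examples seen: {examples_seen} ({fraction_of_epoch:.1%})")   # 168 (16.8%)
print(f"Trainable parameters: {trainable_pct:.2f}%")                 # 0.15%
```

Raising `max_steps` to 125 would correspond to one full pass at this batch size (125 x 8 = 1,000).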


## 📊 Benchmark 2: Post-SFT Evaluation

After SFT, benchmark the model to measure improvement.

In [None]:
import os
import sys
import json
import asyncio
import nest_asyncio
import matplotlib.pyplot as plt
from IPython.display import display, Image
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
# Ensure we can import scripts from root
sys.path.append(os.getcwd())
sys.path.append(os.path.abspath('..'))
from scripts.visualizations import save_all_visualizations

# Apply nest_asyncio to allow async in notebook
nest_asyncio.apply()

print("📊 Running Post-SFT Benchmark via MCP Server...")

# 1. Setup MCP Server Parameters (Pointing to EVAL dataset)
eval_env = os.environ.copy()
eval_env["DIPG_DATASET_PATH"] = "surfiniaburger/dipg-eval-dataset"

server_params = StdioServerParameters(
    command=sys.executable,
    args=["-m", "server.mcp_server"],
    env=eval_env
)

async def run_evaluation(num_samples=100):
    print(f"Starting MCP server with dataset: {eval_env['DIPG_DATASET_PATH']}")
    
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            
            # 2. Fetch Tasks
            print(f"Fetching {num_samples} tasks...")
            result = await session.call_tool("get_eval_tasks", arguments={"max_samples": num_samples})
            tasks_data = json.loads(result.content[0].text)
            tasks = tasks_data["tasks"]
            print(f"✅ Retrieved {len(tasks)} tasks.")
            
            # 3. Generate Responses
            print("Generating responses...")
            evaluations = []
            
            # Enable inference mode for unsloth model
            try:
                FastLanguageModel.for_inference(model)
            except NameError:
                print("⚠️ 'model' not defined. Assuming testing mode or model loaded elsewhere.")
            
            for i, task in enumerate(tasks):
                if i % 10 == 0: print(f"  Processing {i}/{len(tasks)}...")
                
                # Create prompt
                messages = [{"role": "user", "content": task["context"] + "\n\n" + task["question"]}]
                inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
                
                # Generate
                outputs = model.generate(input_ids=inputs, max_new_tokens=512, use_cache=True, pad_token_id=tokenizer.eos_token_id)
                decoded = tokenizer.batch_decode(outputs)
                
                # Extract response, handling cases where the template might be missing.
                assistant_token = "<|start|>assistant<|message|>"
                response_parts = decoded[0].split(assistant_token)
                if len(response_parts) > 1:
                    response_text = response_parts[-1].replace("<|end|>", "").strip()
                else:
                    response_text = "" # Or handle as an error
                    print(f"Warning: Assistant start token not found in output for task {i}.")
                
                evaluations.append({
                    "response": response_text,
                    "ground_truth": {
                        "context": task["context"],
                        "question": task["question"],
                        "expected_answer": task["expected_answer"]
                    }
                })
            
            # 4. Evaluate
            print("Evaluating responses...")
            eval_result = await session.call_tool("evaluate_batch", arguments={"evaluations": evaluations})
            metrics = json.loads(eval_result.content[0].text)
            return metrics

# Run the evaluation
metrics = asyncio.run(run_evaluation(num_samples=100))

# 5. Display Results
print("\n" + "="*40)
print("BENCHMARK RESULTS")
print("="*40)
print(f"Mean Reward: {metrics['mean_reward']:.2f}")
print(f"Safe Response Rate: {metrics['safe_response_rate']:.1%}")
print(f"Hallucination Rate: {metrics['medical_hallucination_rate']:.1%}")

# 6. Generate Visualizations
output_dir = "benchmark_results_sft"
saved_files = save_all_visualizations(metrics, output_dir, "SFT_Model")

print(f"\nVisualizations saved to {output_dir}/")
for file in saved_files:
    display(Image(filename=file))
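
The extraction step above splits the decoded output on `<|start|>assistant<|message|>`. The system prompt decoded earlier states that a channel must be included for every message, so generated assistant turns may instead look like `<|start|>assistant<|channel|>final<|message|>...`, in which case the split finds nothing and the response is recorded as empty. The sketch below pulls the text of the last `final` channel message instead. The helper name is hypothetical, and whether the evaluator expects only the final channel or the full channel-tagged text is not stated in this notebook, so check that before adopting it.

```python
import re

# Matches the body of a `final` channel message up to the next end marker.
FINAL_RE = re.compile(
    r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|$)",
    re.DOTALL,
)

def extract_final_channel(decoded: str) -> str:
    """Return the text of the last `final` channel message, or "" if none is present."""
    matches = FINAL_RE.findall(decoded)
    return matches[-1].strip() if matches else ""

sample = (
    "<|start|>user<|message|>Question<|end|>"
    "<|start|>assistant<|channel|>analysis<|message|>Reasoning steps<|end|>"
    "<|start|>assistant<|channel|>final<|message|>The answer.<|return|>"
)
print(extract_final_channel(sample))  # The answer.
```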


In [None]:
# Save the SFT LoRA adapter and tokenizer locally; pushing to Ollama is a separate manual step
model.save_pretrained("sft_model")
tokenizer.save_pretrained("sft_model")

print("📊 Running Post-SFT Benchmark...")
print("Model: ollama/gpt-oss-20b-sft:latest")
print("Samples: 100")
print("\nAfter pushing to Ollama, run:")
print("python scripts/generate_benchmark_report.py --model 'ollama/gpt-oss-20b-sft:latest' --samples 100")

## GRPO Training (Reinforcement Learning)

In [None]:
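from envs.dipg_safety_env.models import DIPGAction
from requests.exceptions import ConnectionError

# NOTE: the original notebook never defines `env`. The two lines below are an ASSUMPTION
# about the client API (module path, class name, and constructor argument); adjust them
# to match your checkout. Port 8012 is the one reported by scripts/setup_env.py above.
from envs.dipg_safety_env.client import DIPGSafetyEnv
env = DIPGSafetyEnv(base_url="http://localhost:8012")

def create_reward_fn(environment):
    """
    Create a reward function that scores completions with the DIPG Safety Gym.

    The returned function yields one float per completion.
    """
    def get_reward_from_environment(completions, prompts, **kwargs):
        scores = []
        for i, response in enumerate(completions):
            try:
                result = environment.step(DIPGAction(llm_response=response))
                scores.append(result.reward)
            except ConnectionError:
                # Keep training alive, but assign a large penalty so the failure shows up in the logs.
                print(f"\n{'!'*80}")
                print(f"ERROR: Connection lost while processing completion #{i}.")
                print("The environment server may have crashed. Check its logs.")
                print(f"{'!'*80}\n")
                scores.append(-50.0)
        return scores
    return get_reward_from_environment

reward_fn = create_reward_fn(env)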
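
Before starting a long RL run, it can help to smoke-test the reward function's contract without a live server: one score per completion, and a fixed penalty when the connection drops. The sketch below does this with a hypothetical `FakeEnv`. Note that it catches Python's built-in `ConnectionError` and passes raw strings to `step`, whereas the cell above catches `requests.exceptions.ConnectionError` and wraps responses in `DIPGAction`.

```python
from dataclasses import dataclass

@dataclass
class FakeResult:
    reward: float

class FakeEnv:
    """Hypothetical stand-in for the environment client; fails on the `fail_on`-th call."""
    def __init__(self, fail_on=None):
        self.fail_on = fail_on
        self.calls = 0

    def step(self, action):
        self.calls += 1
        if self.fail_on is not None and self.calls == self.fail_on:
            raise ConnectionError("simulated server crash")
        return FakeResult(reward=float(len(action)))  # toy reward: response length

def make_reward_fn(environment, penalty=-50.0):
    def reward_fn(completions, prompts=None, **kwargs):
        scores = []
        for response in completions:
            try:
                scores.append(environment.step(response).reward)
            except ConnectionError:
                scores.append(penalty)
        return scores
    return reward_fn

rewards = make_reward_fn(FakeEnv(fail_on=2))(["ab", "cde", "f"])
print(rewards)  # [2.0, -50.0, 1.0]
```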

In [None]:
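from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Switch back to training mode in case the evaluation cell enabled inference mode
FastLanguageModel.for_training(model)

# Prepare prompts for RL. `dataset` was loaded with split="train", so iterate over it directly.
prompts = [
    tokenizer.apply_chat_template(
        example["messages"][:-1],  # Exclude assistant response
        tokenize=False,
        add_generation_prompt=True
    ) for example in dataset
]

# GRPOTrainer reads prompts from a dataset column named "prompt"
prompt_dataset = Dataset.from_dict({"prompt": prompts})

grpo_config = GRPOConfig(
    output_dir="grpo_outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-7,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    report_to="wandb",
)

grpo_trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_fn],
    args=grpo_config,
    train_dataset=prompt_dataset,
    processing_class=tokenizer,
)

print("--- Starting GRPO Training ---")
grpo_trainer.train()
print("--- GRPO Training Complete ---")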
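
GRPO scores each completion relative to the other completions sampled for the same prompt, rather than against a learned value function. The sketch below shows that group-relative normalization on plain Python lists. It is an illustration of the idea only: the function name is hypothetical, it uses the population standard deviation, and library implementations may differ in the estimator and epsilon they use.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, group_size, eps=1e-4):
    """Normalize rewards within consecutive groups of `group_size` completions."""
    if group_size <= 0 or len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a positive multiple of group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)
        sigma = pstdev(group)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages

# Two prompts, two completions each
adv = group_relative_advantages([1.0, 3.0, 2.0, 2.0], group_size=2)
print([round(a, 3) for a in adv])
```

The second group illustrates a practical point: when every completion for a prompt receives the same reward (for example, the `-50.0` connection penalty), all advantages in that group are zero and the prompt contributes no learning signal.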

## 📊 Benchmark 3: Post-GRPO Evaluation

Final benchmark to measure the complete training pipeline.

In [None]:
# Save the final LoRA adapter and tokenizer locally; pushing to Ollama is a separate manual step
model.save_pretrained("grpo_model")
tokenizer.save_pretrained("grpo_model")

print("📊 Running Post-GRPO Benchmark...")
print("Model: ollama/gpt-oss-20b-grpo:latest")
print("Samples: 100")
print("\nAfter pushing to Ollama, run:")
print("python scripts/generate_benchmark_report.py --model 'ollama/gpt-oss-20b-grpo:latest' --samples 100")

## 📈 Compare Results

After running all three benchmarks, compare the results:

```bash
# View all benchmark results
ls -lh benchmark_results/

# Compare metrics across stages
grep -E '(mean_reward|safe_response_rate|medical_hallucination_rate)' "benchmark_results/ollama_gpt-oss:20b-cloud_results.json"
grep -E '(mean_reward|safe_response_rate|medical_hallucination_rate)' "benchmark_results/ollama_gpt-oss-20b-sft:latest_results.json"
grep -E '(mean_reward|safe_response_rate|medical_hallucination_rate)' "benchmark_results/ollama_gpt-oss-20b-grpo:latest_results.json"
```

Expected progression:
- **Base Model**: Low safe response rate, high hallucination rate
- **Post-SFT**: Improved format adherence, better grounding
- **Post-GRPO**: Highest safe response rate, lowest hallucination rate
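
To view the three stages side by side, the result files can be loaded into one table. This is a minimal sketch: it assumes the three metrics are top-level keys in each JSON file, which the notebook does not confirm, and the helper names are hypothetical. The file paths are the ones used in the shell commands above, and missing files are skipped.

```python
import json
from pathlib import Path

METRICS = ("mean_reward", "safe_response_rate", "medical_hallucination_rate")

def load_metrics(path):
    """Read one results file; assumes the metrics are top-level JSON keys."""
    with open(path) as f:
        data = json.load(f)
    return {name: float(data[name]) for name in METRICS}

def format_comparison(results):
    """Render {stage: {metric: value}} as a fixed-width text table."""
    rows = [f"{'stage':<10}" + "".join(f"{m:>30}" for m in METRICS)]
    for stage, metrics in results.items():
        rows.append(f"{stage:<10}" + "".join(f"{metrics[m]:>30.3f}" for m in METRICS))
    return "\n".join(rows)

paths = {
    "base": "benchmark_results/ollama_gpt-oss:20b-cloud_results.json",
    "sft": "benchmark_results/ollama_gpt-oss-20b-sft:latest_results.json",
    "grpo": "benchmark_results/ollama_gpt-oss-20b-grpo:latest_results.json",
}
results = {stage: load_metrics(p) for stage, p in paths.items() if Path(p).exists()}
print(format_comparison(results))
```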