# PhysRobot Phase 1 Ablation Study

**Self-contained Colab notebook** for the core ablation campaign.

- **Agents**: Pure PPO, GNS, PhysRobot-SV (full), No-EdgeFrame, HNN
- **Training**: 500K steps x 5 seeds per agent
- **Evaluation**: 100-episode in-distribution + OOD mass sweep
- **Hardware**: Colab T4 GPU (~7.5h total)

All code is inline. No external file dependencies.

---
**Paper**: PhysRobot -- Physics-Informed Robot Manipulation via SV Message Passing  
**Date**: 2026-02-06  
**Target**: ICRA 2027 / CoRL 2026

## Cell 1: Install Dependencies

In [None]:
# ============================================================
# Cell 1: Install Dependencies
# ============================================================
# Pinned versions for reproducibility
# NOTE(review): `!pip` is IPython shell syntax, so this cell only runs inside
# Jupyter/Colab, not as a plain Python script. If pip replaces a package the
# runtime has already imported, a runtime restart may be needed -- confirm.

!pip install -q \
    torch==2.2.0 \
    stable-baselines3==2.3.0 \
    gymnasium==0.29.1 \
    numpy==1.26.4 \
    matplotlib==3.8.3 \
    scipy==1.12.0 \
    pandas==2.2.0

# Report the torch version and whether a CUDA device is visible; training
# still runs on CPU (device='auto' in create_agent), just slower.
import torch
print(f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU detected. Training will be slow.")

## Cell 2: Mount Google Drive

In [None]:
# ============================================================
# Cell 2: Mount Google Drive for persistent storage
# ============================================================
# Results go to Google Drive when running in Colab (so they survive runtime
# resets); otherwise to a local ./results directory.

import os

try:
    from google.colab import drive
    drive.mount('/content/drive')
    SAVE_ROOT = '/content/drive/MyDrive/PhysRobot/phase1_ablation'
    IN_COLAB = True
    print(f"Google Drive mounted. Results -> {SAVE_ROOT}")
except ImportError:
    SAVE_ROOT = './results/phase1_ablation'
    IN_COLAB = False
    print(f"Not in Colab. Results -> {SAVE_ROOT}")

# Create the root first, then each sub-directory used by later cells.
for _subdir in ('', '/models', '/logs', '/figures'):
    os.makedirs(f'{SAVE_ROOT}{_subdir}', exist_ok=True)

print(f"Output directory: {SAVE_ROOT}")
print(f"Subdirs: models/, logs/, figures/")

## Cell 3: PushBox Environment (Unified)

In [None]:
# ============================================================
# Cell 3: PushBox Environment (Unified)
# ============================================================
# Analytical 2D physics (no MuJoCo dependency).
# 2-DOF planar arm pushes a single box to a goal.
#
# Observation (16-dim):
#   [0:2]   joint_pos (shoulder, elbow)
#   [2:4]   joint_vel
#   [4:7]   ee_pos (3D, z=0)
#   [7:10]  box_pos (3D, z=0)
#   [10:13] box_vel (3D, z=0)
#   [13:16] goal_pos (3D, z=0)
#
# Action (2-dim): shoulder/elbow torques in [-10, 10] Nm

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from typing import Optional, Dict, Any, Tuple


class PushBoxEnv(gym.Env):
    """Unified PushBox environment with 16-dim observations.

    A 2-DOF planar arm (link lengths L1, L2) pushes a box toward a goal. All
    motion is in the z=0 plane. For contact, the end-effector and the box are
    treated as circles of radius EE_RADIUS and BOX_RADIUS.

    Observation (float32, 16): [q(2), qd(2), ee_pos(3), box_pos(3),
    box_vel(3), goal_pos(3)]. Action (2): joint torques, clipped to [-10, 10].

    An episode terminates when the box is within ``success_threshold`` of the
    goal. It is truncated after ``max_episode_steps`` calls to ``step``.
    NOTE(review): ``render_mode`` is stored, but no ``render()`` is defined in
    this class.
    """

    metadata = {'render_modes': ['rgb_array']}

    # Arm parameters
    L1 = 0.3   # link 1 length (m)
    L2 = 0.25  # link 2 length (m)

    # Physics parameters
    DT = 0.02            # timestep (s)
    SUBSTEPS = 5         # physics substeps per env step
    FRICTION = 0.5       # viscous damping rate of box velocity (1/s); not Coulomb friction
    RESTITUTION = 0.4    # coefficient of restitution for EE-box impacts
    EE_RADIUS = 0.03     # end-effector radius (m)
    BOX_RADIUS = 0.05    # box radius (m)

    def __init__(
        self,
        box_mass: float = 0.5,
        success_threshold: float = 0.15,
        max_episode_steps: int = 500,
        render_mode: Optional[str] = None,
    ):
        """Create the environment.

        Args:
            box_mass: box mass used in the contact impulse (can be overridden
                per reset via ``options={'box_mass': m}``).
            success_threshold: box-goal distance below which the episode
                succeeds and terminates.
            max_episode_steps: number of ``step`` calls before truncation.
            render_mode: stored only.
        """
        super().__init__()
        self.box_mass = box_mass
        self.success_threshold = success_threshold
        self.max_episode_steps = max_episode_steps
        self.render_mode = render_mode

        # Observation: 16-dim
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32
        )
        # Action: 2 joint torques
        self.action_space = spaces.Box(
            low=-10.0, high=10.0, shape=(2,), dtype=np.float32
        )

        self._step_count = 0
        self._prev_box_goal_dist = None

    def _forward_kinematics(self, q):
        """Return the end-effector position [x, y, 0] for joint angles q (rad)."""
        x = self.L1 * np.cos(q[0]) + self.L2 * np.cos(q[0] + q[1])
        y = self.L1 * np.sin(q[0]) + self.L2 * np.sin(q[0] + q[1])
        return np.array([x, y, 0.0])  # 3D, z=0

    def _jacobian(self, q):
        """Return the 2x2 Jacobian d(ee_x, ee_y)/d(q0, q1) of the 2-DOF arm."""
        s1 = np.sin(q[0])
        c1 = np.cos(q[0])
        s12 = np.sin(q[0] + q[1])
        c12 = np.cos(q[0] + q[1])
        return np.array([
            [-self.L1 * s1 - self.L2 * s12, -self.L2 * s12],
            [ self.L1 * c1 + self.L2 * c12,  self.L2 * c12],
        ])

    def reset(self, seed=None, options=None):
        """Sample a new episode and return ``(obs, info)``.

        Joint angles are uniform in [-0.5, 0.5] rad with zero velocity. The box
        starts at x in [0.25, 0.45], y in [-0.15, 0.15]. The goal is sampled in
        x in [0.2, 0.5], y in [-0.2, 0.2] and resampled until it is at least
        0.15 from the box. ``options={'box_mass': m}`` sets ``self.box_mass``,
        and the new value persists for later resets.
        """
        super().reset(seed=seed)

        # Override box_mass if requested
        if options and 'box_mass' in options:
            self.box_mass = options['box_mass']

        # Random joint angles
        self.q = self.np_random.uniform(
            low=[-0.5, -0.5], high=[0.5, 0.5]
        ).astype(np.float64)
        self.qd = np.zeros(2, dtype=np.float64)

        # Box initial position (reachable workspace)
        self.box_pos = np.array([
            self.np_random.uniform(0.25, 0.45),
            self.np_random.uniform(-0.15, 0.15),
            0.0,
        ], dtype=np.float64)
        self.box_vel = np.zeros(3, dtype=np.float64)

        # Goal position
        self.goal_pos = np.array([
            self.np_random.uniform(0.2, 0.5),
            self.np_random.uniform(-0.2, 0.2),
            0.0,
        ], dtype=np.float64)

        # Ensure goal is not too close to initial box position
        while np.linalg.norm(self.box_pos - self.goal_pos) < 0.15:
            self.goal_pos[:2] = self.np_random.uniform(
                low=[0.2, -0.2], high=[0.5, 0.2]
            )

        self._step_count = 0
        self._prev_box_goal_dist = np.linalg.norm(
            self.box_pos - self.goal_pos
        )

        return self._get_obs(), self._get_info()

    def step(self, action):
        """Advance DT seconds in SUBSTEPS substeps and return the gymnasium 5-tuple.

        Reward = 10 * (decrease in box-goal distance) - 0.1 * ee-box distance
        - 0.001 * |action|^2, plus 500 on the step that reaches success.
        """
        action = np.clip(action, -10.0, 10.0)

        # Physics substeps
        dt_sub = self.DT / self.SUBSTEPS
        for _ in range(self.SUBSTEPS):
            # Joint dynamics (simplified: direct torque -> acceleration)
            # M(q) * qdd = tau - friction * qd
            # Simplification: M = I (unit inertia)
            # Semi-implicit Euler: velocity first, then position with the new velocity.
            qdd = action - 2.0 * self.qd  # damping
            self.qd += qdd * dt_sub
            self.q += self.qd * dt_sub

            # Compute EE position
            ee_pos = self._forward_kinematics(self.q)
            J = self._jacobian(self.q)
            ee_vel_2d = J @ self.qd
            ee_vel = np.array([ee_vel_2d[0], ee_vel_2d[1], 0.0])

            # Contact: EE -> Box
            delta = self.box_pos - ee_pos
            dist = np.linalg.norm(delta)
            contact_dist = self.EE_RADIUS + self.BOX_RADIUS

            if dist < contact_dist and dist > 1e-8:
                n = delta / dist  # contact normal, pointing from EE to box
                v_rel = ee_vel - self.box_vel
                v_n = np.dot(v_rel, n)
                if v_n > 0:  # approaching
                    # Restitution impulse along n (coefficient RESTITUTION, so
                    # only partially elastic), treating the EE as a unit mass.
                    # Only the box velocity changes; the arm gets no reaction.
                    j_imp = (1 + self.RESTITUTION) * v_n / (1.0 + 1.0 / self.box_mass)
                    self.box_vel += (j_imp / self.box_mass) * n

                # Push the box out of penetration along n with a 1 mm margin
                # (overlap is always > 0 inside this branch).
                overlap = contact_dist - dist
                if overlap > 0:
                    self.box_pos += n * (overlap + 0.001)

            # Box friction: viscous velocity damping, then explicit position update
            self.box_vel *= (1.0 - self.FRICTION * dt_sub)
            self.box_pos += self.box_vel * dt_sub

        # Reward (V2 design from EXPERIMENT_DESIGN.md)
        ee_pos = self._forward_kinematics(self.q)
        box_goal_dist = np.linalg.norm(self.box_pos - self.goal_pos)
        ee_box_dist = np.linalg.norm(ee_pos - self.box_pos)

        # Progress reward: reduction in box-goal distance
        progress = self._prev_box_goal_dist - box_goal_dist
        self._prev_box_goal_dist = box_goal_dist

        # Approach reward: encourage EE to get close to box
        approach = -0.1 * ee_box_dist

        # Action penalty
        action_cost = -0.001 * np.sum(action ** 2)

        reward = 10.0 * progress + approach + action_cost

        # Success check
        success = box_goal_dist < self.success_threshold
        if success:
            reward += 500.0

        self._step_count += 1
        terminated = bool(success)
        truncated = self._step_count >= self.max_episode_steps

        return self._get_obs(), float(reward), terminated, truncated, self._get_info()

    def _get_obs(self):
        """Return the 16-dim float32 observation (layout in the class docstring)."""
        ee_pos = self._forward_kinematics(self.q)
        return np.concatenate([
            self.q,          # [0:2]   joint positions
            self.qd,         # [2:4]   joint velocities
            ee_pos,          # [4:7]   end-effector pos (3D)
            self.box_pos,    # [7:10]  box pos (3D)
            self.box_vel,    # [10:13] box vel (3D)
            self.goal_pos,   # [13:16] goal pos (3D)
        ]).astype(np.float32)

    def _get_info(self):
        """Return the info dict: box-goal distance, success flag, box mass, step count."""
        box_goal_dist = np.linalg.norm(self.box_pos - self.goal_pos)
        return {
            'distance_to_goal': float(box_goal_dist),
            'success': bool(box_goal_dist < self.success_threshold),
            'box_mass': float(self.box_mass),
            'step': self._step_count,
        }


def make_push_box_env(box_mass=0.5, success_threshold=0.15, seed=None):
    """Return a zero-argument callable that creates a PushBoxEnv.

    The callable is suitable as an element of the env-function list passed to
    ``DummyVecEnv`` / ``SubprocVecEnv``.

    Args:
        box_mass: box mass passed to PushBoxEnv.
        success_threshold: box-goal distance that counts as success.
        seed: if not None, the created env is reset with ``seed`` and its action
            space is seeded with it. A VecEnv wrapper may still reseed the env on
            its own reset. None leaves the env unseeded.
    """
    def _init():
        env = PushBoxEnv(
            box_mass=box_mass,
            success_threshold=success_threshold,
        )
        if seed is not None:
            # Seed both the env's internal RNG and the action-space sampler.
            env.reset(seed=seed)
            env.action_space.seed(seed)
        return env
    return _init


# ---------- Quick validation ----------
# Smoke test run when the cell executes: one seeded reset, an observation-shape
# check, then 10 random-action steps. Note that `env`, `obs`, `info`, `action`,
# `reward`, `terminated` and `truncated` remain as notebook globals.
env = PushBoxEnv()
obs, info = env.reset(seed=42)
assert obs.shape == (16,), f"Expected (16,), got {obs.shape}"
for _ in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)

print(f"PushBoxEnv OK: obs={obs.shape}, reward={reward:.3f}, "
      f"dist={info['distance_to_goal']:.3f}")

## Cell 4: Agent Definitions (All Variants)

In [None]:
# ============================================================
# Cell 4: Agent Definitions
# ============================================================
# All feature extractors + agent creation for:
#   V1: Pure PPO          (MLP, ~10K params)
#   V2: GNS               (Graph without physics, ~5K params)
#   V3: PhysRobot-SV      (SV message passing, ~7.5K params)
#   V5: No-EdgeFrame      (global frame physics MLP, ~5K params)
#   B5: HNN               (Hamiltonian NN, ~10K params)

import torch
import torch.nn as nn
from typing import Tuple
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import gymnasium as gym

EPS = 1e-7      # added to vector norms before dividing, to avoid division by zero
DEG_EPS = 1e-4  # norm below which a candidate frame axis is treated as degenerate


# ============================
# Helper: 2-layer MLP
# ============================
def _make_mlp(in_dim, hidden_dim, out_dim):
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.LayerNorm(hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_dim),
    )


# ============================
# SV Message Passing (from sv_message_passing.py)
# ============================

def build_edge_frames(pos, vel, src, dst):
    """Construct edge-local (approximately orthonormal) frames for edges src -> dst.

    Args:
        pos: node positions, presumably (N, 3). The last dim must be 3 for the
            cross products below.
        vel: node velocities, same shape as ``pos``.
        src, dst: index tensors of equal length E selecting edge endpoints.

    Returns:
        (e1, e2, e3, r_ij, d_ij):
        r_ij = pos[dst] - pos[src]; d_ij = |r_ij| with a trailing dim of 1.
        e1 = r_ij / (d_ij + EPS) points from src to dst.
        e2 is the unit part of the relative velocity perpendicular to e1 when
            that part's norm exceeds DEG_EPS. Otherwise it is e1 x z_hat, or
            e1 x y_hat where e1 is (nearly) parallel to z_hat.
        e3 = e1 x e2.

    Swapping src and dst flips the sign of e1 and e2, but e3 = e1 x e2 is
    unchanged. The EPS terms make the frame only approximately unit-length.
    NOTE(review): coincident endpoints (d_ij == 0) give e1 = 0 and a degenerate
    frame.
    """
    r_ij = pos[dst] - pos[src]
    d_ij = torch.norm(r_ij, dim=-1, keepdim=True)
    e1 = r_ij / (d_ij + EPS)

    # Split the relative velocity into parts parallel/perpendicular to e1.
    v_rel = vel[dst] - vel[src]
    v_par = (v_rel * e1).sum(dim=-1, keepdim=True) * e1
    v_perp = v_rel - v_par
    v_perp_norm = torch.norm(v_perp, dim=-1, keepdim=True)

    # 1.0 where the perpendicular velocity is large enough to define e2.
    non_degenerate = (v_perp_norm > DEG_EPS).float()
    e2_vel = v_perp / (v_perp_norm + EPS)

    # Fallback e2 = e1 x z_hat ...
    z_hat = torch.tensor([0.0, 0.0, 1.0], device=pos.device).expand_as(e1)
    e2_fall_raw = torch.cross(e1, z_hat, dim=-1)
    e2_fall_norm = torch.norm(e2_fall_raw, dim=-1, keepdim=True)

    # ... switching to e1 x y_hat where e1 is (nearly) parallel to z_hat.
    use_y = (e2_fall_norm < DEG_EPS).float()
    y_hat = torch.tensor([0.0, 1.0, 0.0], device=pos.device).expand_as(e1)
    e2_fall_raw2 = torch.cross(e1, y_hat, dim=-1)
    e2_fall_raw = (1 - use_y) * e2_fall_raw + use_y * e2_fall_raw2
    e2_fall_norm = torch.norm(e2_fall_raw, dim=-1, keepdim=True)
    e2_fall = e2_fall_raw / (e2_fall_norm + EPS)

    # Blend with 0/1 masks (no data-dependent branching).
    e2 = non_degenerate * e2_vel + (1 - non_degenerate) * e2_fall
    e3 = torch.cross(e1, e2, dim=-1)

    return e1, e2, e3, r_ij, d_ij


class SVMessagePassing(nn.Module):
    """One round of SV message passing with Newton's 3rd law hard-coded.

    For each undirected pair (i, j) with i < j, ``force_mlp`` maps scalars
    (5 edge-frame geometry values plus the pair-symmetric features
    h_i + h_j and |h_i - h_j|) to coefficients alpha_1..3. The pair force is
    assembled in the edge-local frame from ``build_edge_frames``:
        F_ij = alpha_1 * e1 + alpha_2 * e2 + alpha_3 * e3.
    +F_ij is added to node j and -F_ij to node i, so the aggregated forces sum
    to zero over all nodes (up to floating-point rounding). Node embeddings are
    then updated residually from [h, F_agg].
    """

    def __init__(self, node_dim, hidden_dim=32):
        super().__init__()
        n_scalar = 5  # ||r||, v_r, v_t, v_b, ||v_rel||
        self.force_mlp = _make_mlp(
            in_dim=n_scalar + 2 * node_dim,
            hidden_dim=hidden_dim,
            out_dim=3,
        )
        self.node_update = _make_mlp(
            in_dim=node_dim + 3,
            hidden_dim=hidden_dim,
            out_dim=node_dim,
        )

    def _extract_undirected_pairs(self, edge_index):
        """Keep only edges with src < dst, so each undirected pair is processed once.

        NOTE(review): assumes each pair appears with src < dst at least once
        (e.g. both directions listed). A pair given only as (j, i) with j > i is
        dropped, and a duplicated (i, j) would be counted twice.
        """
        src, dst = edge_index[0], edge_index[1]
        mask = src < dst
        return torch.stack([src[mask], dst[mask]], dim=0)

    def forward_with_forces(self, h, edge_index, pos, vel):
        """Run one message-passing round.

        Args:
            h: node embeddings (N, node_dim).
            edge_index: (2, E) long tensor of directed edges.
            pos, vel: node positions/velocities (N, 3).

        Returns:
            (h_new, F_agg): updated embeddings (N, node_dim) and the aggregated
            per-node force (N, 3).
        """
        N = h.size(0)
        pairs = self._extract_undirected_pairs(edge_index)
        pi, pj = pairs[0], pairs[1]

        e1, e2, e3, r_ij, d_ij = build_edge_frames(pos, vel, pi, pj)
        v_rel = vel[pj] - vel[pi]

        # Relative velocity expressed in the edge-local frame.
        v_r = (v_rel * e1).sum(dim=-1, keepdim=True)
        v_t = (v_rel * e2).sum(dim=-1, keepdim=True)
        v_b = (v_rel * e3).sum(dim=-1, keepdim=True)
        v_norm = torch.norm(v_rel, dim=-1, keepdim=True)

        scalars_geom = torch.cat([d_ij, v_r, v_t, v_b, v_norm], dim=-1)
        # Node features enter only through symmetric combinations of i and j.
        h_sum = h[pi] + h[pj]
        h_diff_abs = (h[pi] - h[pj]).abs()
        scalars = torch.cat([scalars_geom, h_sum, h_diff_abs], dim=-1)

        alphas = self.force_mlp(scalars)
        alpha1, alpha2, alpha3 = alphas[:, 0:1], alphas[:, 1:2], alphas[:, 2:3]

        force_ij = alpha1 * e1 + alpha2 * e2 + alpha3 * e3

        # Newton's 3rd law: equal and opposite contributions to j and i.
        F_agg = torch.zeros(N, 3, device=h.device, dtype=h.dtype)
        F_agg.scatter_add_(0, pj.unsqueeze(-1).expand_as(force_ij), force_ij)
        F_agg.scatter_add_(0, pi.unsqueeze(-1).expand_as(force_ij), -force_ij)

        h_input = torch.cat([h, F_agg], dim=-1)
        h_new = h + self.node_update(h_input)
        return h_new, F_agg


class SVPhysicsCore(nn.Module):
    """Complete physics stream using SV pipeline. Conservation guaranteed.

    Encodes per-node [position, velocity] with an MLP, then applies
    ``n_layers`` SVMessagePassing rounds. Each round's aggregated forces sum to
    zero over the nodes.
    """

    def __init__(self, node_input_dim=6, hidden_dim=32, n_layers=1):
        super().__init__()
        self.encoder = _make_mlp(node_input_dim, hidden_dim, hidden_dim)
        rounds = [
            SVMessagePassing(node_dim=hidden_dim, hidden_dim=hidden_dim)
            for _ in range(n_layers)
        ]
        self.sv_layers = nn.ModuleList(rounds)

    def forward(self, positions, velocities, edge_index):
        """Return the per-node aggregated force (N, 3) from the last SV layer.

        Returns None when the core was built with ``n_layers == 0``.
        """
        h = self.encoder(torch.cat((positions, velocities), dim=-1))
        last_forces = None
        for sv_round in self.sv_layers:
            h, last_forces = sv_round.forward_with_forces(
                h, edge_index, positions, velocities
            )
        return last_forces


# ============================
# V3: PhysRobot-SV Features Extractor
# ============================

class PhysRobotSVExtractor(BaseFeaturesExtractor):
    """Dual-stream: SV-physics (stop-gradient) + policy MLP.

    The physics stream treats each observation as a 2-node graph (node 0 =
    end-effector, node 1 = box) and runs SVPhysicsCore on it. The aggregated
    force on the box node (3-dim, detached) is concatenated with the policy-MLP
    embedding and fused to ``features_dim``.

    The whole batch is processed as one disjoint graph (nodes 2b / 2b+1 belong
    to sample b), with no per-sample Python loop. The EE velocity is not in the
    observation and is fed as zeros.
    NOTE(review): because of ``.detach()`` and the absence of any auxiliary
    loss in this notebook, ``physics_core`` weights appear to stay at their
    initialization -- confirm this is intended for the ablation.
    """

    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        self.physics_core = SVPhysicsCore(
            node_input_dim=6, hidden_dim=32, n_layers=1,
        )
        self.policy_stream = nn.Sequential(
            nn.Linear(16, 64), nn.ReLU(),
            nn.Linear(64, features_dim), nn.ReLU(),
        )
        self.fusion = nn.Sequential(
            nn.Linear(features_dim + 3, features_dim), nn.ReLU(),
        )
        # Single-graph edges (0 -> 1, 1 -> 0); rows are (src, dst).
        self._edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long).t()

    def forward(self, observations):
        B = observations.shape[0]
        device = observations.device

        z_policy = self.policy_stream(observations)

        ee_pos = observations[:, 4:7]
        box_pos = observations[:, 7:10]
        box_vel = observations[:, 10:13]
        ee_vel = torch.zeros_like(ee_pos)

        # Interleave nodes: row 2b is sample b's EE, row 2b+1 is its box.
        pos = torch.stack([ee_pos, box_pos], dim=1).reshape(2 * B, 3)
        vel = torch.stack([ee_vel, box_vel], dim=1).reshape(2 * B, 3)

        # Replicate the 2-node edge list per sample, offset by 2b -> (2, 2B).
        offsets = 2 * torch.arange(B, device=device, dtype=torch.long)
        edge_index = (
            self._edge_index.to(device).unsqueeze(1) + offsets.view(1, B, 1)
        ).reshape(2, 2 * B)

        forces = self.physics_core(pos, vel, edge_index)  # (2B, 3)
        z_physics = forces[1::2].detach()  # box node of each sample; stop-gradient
        combined = torch.cat([z_policy, z_physics], dim=-1)
        return self.fusion(combined)


# ============================
# V2: GNS Features Extractor
# ============================

class GNSExtractor(BaseFeaturesExtractor):
    """Graph network features extractor (no physics constraints).

    Treats each observation as a 2-node graph (node 0 = end-effector,
    node 1 = box) with node features [pos(3), vel(3)] and edge features
    [rel_pos(3), dist(1)]. After one message-passing round, the box node is
    decoded to a 3-dim vector. That vector is concatenated with the raw 16-dim
    observation and projected to ``features_dim``.

    The whole batch is processed with batched tensor ops (no per-sample loop).
    The EE velocity is not part of the observation and is fed as zeros.
    NOTE(review): the 3-dim decoder output has no supervised target in this
    notebook; it is trained only through the PPO loss.
    """

    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        # Node encoder: 6-dim -> 32-dim
        self.node_encoder = nn.Sequential(
            nn.Linear(6, 32), nn.ReLU(),
        )
        # Edge encoder: 4-dim -> 32-dim
        self.edge_encoder = nn.Sequential(
            nn.Linear(4, 32), nn.ReLU(),
        )
        # Message passing: aggregate edge messages, update node
        self.message_mlp = nn.Sequential(
            nn.Linear(32 + 32, 32), nn.ReLU(),
            nn.Linear(32, 32),
        )
        self.node_update_mlp = nn.Sequential(
            nn.Linear(32 + 32, 32), nn.ReLU(),
            nn.Linear(32, 32),
        )
        # Decoder: box node embedding -> 3-dim prediction
        self.decoder = nn.Linear(32, 3)
        # Feature projection: raw obs -> features
        self.feature_proj = nn.Sequential(
            nn.Linear(16 + 3, features_dim), nn.ReLU(),
        )

    def forward(self, observations):
        ee_pos = observations[:, 4:7]
        box_pos = observations[:, 7:10]
        box_vel = observations[:, 10:13]
        ee_vel = torch.zeros_like(ee_pos)

        # Node embeddings, [B, 32] each.
        h_ee = self.node_encoder(torch.cat([ee_pos, ee_vel], dim=-1))
        h_box = self.node_encoder(torch.cat([box_pos, box_vel], dim=-1))

        # Edge EE -> box: relative position and distance, [B, 4] -> [B, 32].
        rel = box_pos - ee_pos
        dist = torch.norm(rel, dim=-1, keepdim=True)
        edge_feat = self.edge_encoder(torch.cat([rel, dist], dim=-1))

        # Only the box node's update feeds the decoder, so the reverse
        # (box -> EE) message and the EE node update are not computed.
        msg_to_box = self.message_mlp(torch.cat([h_ee, edge_feat], dim=-1))
        h_box_new = h_box + self.node_update_mlp(
            torch.cat([h_box, msg_to_box], dim=-1)
        )

        preds = self.decoder(h_box_new)  # [B, 3]
        combined = torch.cat([observations, preds], dim=-1)  # [B, 19]
        return self.feature_proj(combined)


# ============================
# V5: No-EdgeFrame (Global Frame MLP)
# ============================

class NoEdgeFrameExtractor(BaseFeaturesExtractor):
    """Physics MLP in global frame (no relative-geometry features).

    Ablation of the edge-frame construction: a plain MLP maps the global-frame
    state [ee_pos, box_pos, box_vel, goal_pos] (12-dim) to a 3-dim physics
    feature. That feature is detached and fused with a policy-MLP embedding of
    the full observation.
    NOTE(review): because of ``.detach()`` and no auxiliary loss in this
    notebook, ``physics_net`` appears to stay at its initialization -- confirm.
    """

    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        # 12-dim global-frame state -> 3-dim physics feature.
        self.physics_net = nn.Sequential(
            nn.Linear(12, 32),
            nn.ReLU(),
            nn.Linear(32, 3),
        )
        self.policy_stream = nn.Sequential(
            nn.Linear(16, 64),
            nn.ReLU(),
            nn.Linear(64, features_dim),
            nn.ReLU(),
        )
        self.fusion = nn.Sequential(
            nn.Linear(features_dim + 3, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations):
        # Columns 4:16 are [ee_pos, box_pos, box_vel, goal_pos], which is already
        # the order the physics head expects.
        global_state = observations[:, 4:16]
        physics_feat = self.physics_net(global_state).detach()
        policy_feat = self.policy_stream(observations)
        return self.fusion(torch.cat((policy_feat, physics_feat), dim=-1))


# ============================
# B5: HNN Features Extractor
# ============================

class HNNExtractor(BaseFeaturesExtractor):
    """Hamiltonian Neural Network as PPO feature extractor.

    ``H_net`` maps the relative phase-space state (q, p) = (box_pos - ee_pos,
    box_vel) to a scalar H. The 3-dim feature -dH/dq (unit mass assumed) is
    concatenated with the raw 16-dim observation and projected to
    ``features_dim``.

    Works both inside and outside ``torch.no_grad()``. Stable-Baselines3
    evaluates the policy under ``no_grad`` during rollouts and ``predict()``,
    so grad tracking is re-enabled locally to differentiate H with respect to
    (q, p). H_net is kept in the autograd graph (and so trained through dH)
    only when the caller has grad enabled, i.e. during the PPO update.
    """

    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        self.H_net = nn.Sequential(
            nn.Linear(6, 64), nn.Softplus(),
            nn.Linear(64, 64), nn.Softplus(),
            nn.Linear(64, 1),
        )
        self.feature_proj = nn.Sequential(
            nn.Linear(3 + 16, features_dim), nn.ReLU(),
        )

    def forward(self, observations):
        ee_pos = observations[:, 4:7]
        box_pos = observations[:, 7:10]
        box_vel = observations[:, 10:13]
        q = box_pos - ee_pos
        p = box_vel

        # Build a differentiable graph for dH/dq only if the caller needs gradients.
        build_graph = torch.is_grad_enabled()
        # autograd.grad needs a graph from qp to H even under an outer no_grad().
        with torch.enable_grad():
            qp = torch.cat([q, p], dim=-1).detach().requires_grad_(True)
            H = self.H_net(qp)
            dH = torch.autograd.grad(H.sum(), qp, create_graph=build_graph)[0]
        acc = -dH[:, :3]  # dp/dt = -dH/dq -> acceleration (unit mass)

        combined = torch.cat([acc, observations], dim=-1)
        return self.feature_proj(combined)


# ============================
# Agent Factory
# ============================

# Registry of ablation variants: run name -> description and SB3 features
# extractor class. 'extractor': None means create_agent() sets no
# features_extractor_class, so SB3's default for 'MlpPolicy' is used.
# create_agent() indexes this dict by name, so every entry of AGENTS_TO_TRAIN
# must be one of these keys.
AGENT_CONFIGS = {
    'pure_ppo': {
        'description': 'Pure PPO (MLP, no physics)',
        'extractor': None,
    },
    'gns': {
        'description': 'GNS (graph without physics constraints)',
        'extractor': GNSExtractor,
    },
    'physrobot_sv': {
        'description': 'PhysRobot-SV (momentum-conserving, our method)',
        'extractor': PhysRobotSVExtractor,
    },
    'no_edgeframe': {
        'description': 'No-EdgeFrame (global frame physics MLP)',
        'extractor': NoEdgeFrameExtractor,
    },
    'hnn': {
        'description': 'HNN (Hamiltonian energy conservation)',
        'extractor': HNNExtractor,
    },
}


def create_agent(agent_name, env, seed=42):
    """Build an untrained PPO agent for one ablation variant.

    Args:
        agent_name: key into AGENT_CONFIGS; selects the features extractor.
        env: (vectorized) training environment passed to PPO.
        seed: seed forwarded to PPO.

    Returns:
        The PPO model. Its parameter count is printed as a side effect.
    """
    cfg = AGENT_CONFIGS[agent_name]
    extractor_cls = cfg['extractor']

    # Separate 64-64 actor (pi) and critic (vf) heads on top of the extractor.
    policy_kwargs = {'net_arch': {'pi': [64, 64], 'vf': [64, 64]}}
    if extractor_cls is not None:
        policy_kwargs['features_extractor_class'] = extractor_cls
        policy_kwargs['features_extractor_kwargs'] = {'features_dim': 64}

    # PPO hyperparameters are shared by every variant, so the ablation only
    # varies the features extractor.
    ppo_hparams = {
        'learning_rate': 3e-4,
        'n_steps': 2048,
        'batch_size': 64,
        'n_epochs': 10,
        'gamma': 0.99,
        'gae_lambda': 0.95,
        'clip_range': 0.2,
        'ent_coef': 0.01,
        'vf_coef': 0.5,
        'max_grad_norm': 0.5,
    }
    model = PPO(
        'MlpPolicy',
        env,
        seed=seed,
        verbose=0,
        policy_kwargs=policy_kwargs,
        device='auto',
        **ppo_hparams,
    )

    n_params = sum(param.numel() for param in model.policy.parameters())
    print(f"  Created {agent_name}: {n_params:,} params | {cfg['description']}")
    return model


# ---------- Verify all agents can be created ----------
# Construction-only smoke test: builds each variant once on a 1-env DummyVecEnv
# and prints its parameter count. No forward pass or training step is run, so
# errors that only appear in forward() are not caught here.
from stable_baselines3.common.vec_env import DummyVecEnv
_test_env = DummyVecEnv([make_push_box_env()])
print("Agent creation test:")
for name in AGENT_CONFIGS:
    _m = create_agent(name, _test_env, seed=0)
    del _m
_test_env.close()
print("All agents OK.")

## Cell 5: Training Loop (500K steps x 5 seeds)

In [None]:
# ============================================================
# Cell 5: Training Loop
# ============================================================
# 500K steps x 5 seeds for each agent variant.
# Checkpoints saved every 100K steps.
# Training logs saved per-run.

import time
import json
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback


# ---- Config ----
SEEDS = [42, 123, 256, 789, 1024]  # one training run per (agent, seed)
TOTAL_TIMESTEPS = 500_000  # passed to model.learn(total_timesteps=...)
# NOTE(review): SB3 PPO collects whole rollouts (n_steps * N_ENVS per
# iteration), so the realized step count may slightly exceed this -- confirm.
N_ENVS = 4  # env copies in each run's training DummyVecEnv
CHECKPOINT_FREQ = 100_000  # TrainingLogger checkpoint interval (timesteps)
EVAL_FREQ = 50_000     # evaluate every 50K steps
EVAL_EPISODES = 20     # quick eval during training (Cell 6 rebinds this name to 100)

# Keys of AGENT_CONFIGS to train, in order.
AGENTS_TO_TRAIN = ['pure_ppo', 'gns', 'physrobot_sv', 'no_edgeframe', 'hnn']


# ---- Callbacks ----
class TrainingLogger(BaseCallback):
    """Log rewards, successes, and save checkpoints.

    Per run, it records per-episode return, length and success flag (read from
    ``info['success']``), saves a checkpoint every ``checkpoint_freq`` timesteps,
    and runs a quick deterministic evaluation every ``eval_freq`` timesteps.

    Checkpoints and evals fire on the first step whose ``num_timesteps``
    reaches the next multiple of the frequency. This works even when the number
    of envs does not divide the frequency. A frequency <= 0 (or None) disables
    that action.
    """

    def __init__(self, save_dir, agent_name, seed,
                 checkpoint_freq=100_000, eval_freq=50_000,
                 eval_episodes=20):
        super().__init__()
        self.save_dir = save_dir
        self.agent_name = agent_name
        self.seed = seed
        self.checkpoint_freq = checkpoint_freq
        self.eval_freq = eval_freq
        self.eval_episodes = eval_episodes

        self.episode_rewards = []
        self.episode_successes = []
        self.episode_lengths = []
        self.eval_log = []  # dicts: timestep, success_rate, mean_reward
        # Per-env running return/length of the current episode (sized at training start).
        self._current_rewards = None
        self._current_lengths = None
        # Timestep thresholds for the next checkpoint / quick eval.
        self._next_checkpoint = self._next_multiple(0, checkpoint_freq)
        self._next_eval = self._next_multiple(0, eval_freq)

    @staticmethod
    def _next_multiple(t, freq):
        """Return the smallest multiple of ``freq`` strictly greater than ``t`` (inf if disabled)."""
        if not freq or freq <= 0:
            return float('inf')
        return (t // freq + 1) * freq

    def _on_training_start(self):
        n = self.training_env.num_envs
        self._current_rewards = np.zeros(n)
        self._current_lengths = np.zeros(n, dtype=int)
        # Align thresholds with the model's current timestep (non-zero when
        # learn() is resumed with reset_num_timesteps=False).
        start = self.model.num_timesteps
        self._next_checkpoint = self._next_multiple(start, self.checkpoint_freq)
        self._next_eval = self._next_multiple(start, self.eval_freq)

    def _on_step(self):
        # Accumulate per-env rewards
        rewards = self.locals.get('rewards', np.zeros(1))
        dones = self.locals.get('dones', np.zeros(1, dtype=bool))
        infos = self.locals.get('infos', [{}])

        self._current_rewards += rewards
        self._current_lengths += 1

        for i, done in enumerate(dones):
            if done:
                self.episode_rewards.append(float(self._current_rewards[i]))
                self.episode_lengths.append(int(self._current_lengths[i]))
                success = infos[i].get('success', False)
                self.episode_successes.append(1 if success else 0)
                self._current_rewards[i] = 0.0
                self._current_lengths[i] = 0

        t = self.num_timesteps

        # Checkpoint
        if t >= self._next_checkpoint:
            path = f"{self.save_dir}/models/{self.agent_name}_s{self.seed}_{t // 1000}k"
            self.model.save(path)
            self._next_checkpoint = self._next_multiple(t, self.checkpoint_freq)

        # Quick eval
        if t >= self._next_eval:
            sr, mr = self._quick_eval()
            self.eval_log.append({
                'timestep': t,
                'success_rate': sr,
                'mean_reward': mr,
            })
            self._next_eval = self._next_multiple(t, self.eval_freq)

        return True

    def _quick_eval(self):
        """Run quick evaluation (deterministic policy) on a fresh 1-env DummyVecEnv.

        Returns:
            (success_rate, mean_episode_return) over ``eval_episodes`` episodes.
        """
        eval_env = DummyVecEnv([make_push_box_env()])
        successes = []
        rewards_list = []
        try:
            for _ in range(self.eval_episodes):
                obs = eval_env.reset()
                done = False
                ep_reward = 0.0
                while not done:
                    action, _ = self.model.predict(obs, deterministic=True)
                    obs, reward, dones, infos = eval_env.step(action)
                    ep_reward += reward[0]
                    done = dones[0]
                successes.append(1 if infos[0].get('success', False) else 0)
                rewards_list.append(ep_reward)
        finally:
            # Always release the eval env, even if predict/step raises.
            eval_env.close()
        return float(np.mean(successes)), float(np.mean(rewards_list))

    def get_log(self):
        """Return a JSON-serializable dict of everything recorded for this run."""
        return {
            'agent': self.agent_name,
            'seed': self.seed,
            'episode_rewards': self.episode_rewards,
            'episode_successes': self.episode_successes,
            'episode_lengths': self.episode_lengths,
            'eval_log': self.eval_log,
        }


# ---- Main Training ----
# Trains every (agent, seed) pair, skipping runs whose final model and log are
# already on disk. A run whose model.learn() raises is reported with its full
# traceback and listed at the end instead of disappearing silently.
import traceback

all_training_logs = {}
all_models = {}  # (agent_name, seed) -> model path
failed_runs = []  # run keys whose training raised

total_runs = len(AGENTS_TO_TRAIN) * len(SEEDS)
run_idx = 0
overall_start = time.time()

for agent_name in AGENTS_TO_TRAIN:
    for seed in SEEDS:
        run_idx += 1
        run_key = f"{agent_name}_s{seed}"

        # Check if already done (resume support): need both final model and log.
        final_model_path = f"{SAVE_ROOT}/models/{agent_name}_s{seed}_final"
        log_path = f"{SAVE_ROOT}/logs/{run_key}.json"
        if os.path.exists(f"{final_model_path}.zip") and os.path.exists(log_path):
            print(f"[{run_idx}/{total_runs}] SKIP {run_key} (already done)")
            with open(log_path) as f:
                all_training_logs[run_key] = json.load(f)
            all_models[(agent_name, seed)] = final_model_path
            continue

        print(f"\n{'='*60}")
        print(f"[{run_idx}/{total_runs}] Training: {agent_name} | seed={seed}")
        print(f"{'='*60}")

        t0 = time.time()

        # Create vectorized training env (closed in `finally` on every path)
        train_env = DummyVecEnv([make_push_box_env() for _ in range(N_ENVS)])
        model = None
        try:
            # Create agent
            model = create_agent(agent_name, train_env, seed=seed)

            # Create logger callback
            logger = TrainingLogger(
                save_dir=SAVE_ROOT,
                agent_name=agent_name,
                seed=seed,
                checkpoint_freq=CHECKPOINT_FREQ,
                eval_freq=EVAL_FREQ,
                eval_episodes=EVAL_EPISODES,
            )

            # Train; on failure, log the traceback and move on to the next run.
            try:
                model.learn(
                    total_timesteps=TOTAL_TIMESTEPS,
                    callback=logger,
                    progress_bar=True,
                )
            except Exception as e:
                print(f"  ERROR during training: {e}")
                traceback.print_exc()
                failed_runs.append(run_key)
                continue

            elapsed = time.time() - t0

            # Save final model
            model.save(final_model_path)
            all_models[(agent_name, seed)] = final_model_path

            # Save training log
            train_log = logger.get_log()
            train_log['training_time_s'] = elapsed
            train_log['total_timesteps'] = TOTAL_TIMESTEPS
            with open(log_path, 'w') as f:
                json.dump(train_log, f)
            all_training_logs[run_key] = train_log

            # Print summary over the last 100 training episodes
            recent_sr = np.mean(train_log['episode_successes'][-100:]) if train_log['episode_successes'] else 0
            recent_rew = np.mean(train_log['episode_rewards'][-100:]) if train_log['episode_rewards'] else 0
            print(f"  Done in {elapsed:.0f}s | Recent SR: {recent_sr:.1%} | "
                  f"Recent reward: {recent_rew:.1f} | Episodes: {len(train_log['episode_rewards'])}")
        finally:
            train_env.close()
            model = None  # drop the reference so GPU memory can be released
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

total_elapsed = time.time() - overall_start
print(f"\n{'='*60}")
print(f"ALL TRAINING COMPLETE: {total_runs} runs in {total_elapsed/3600:.1f}h")
if failed_runs:
    print(f"FAILED RUNS ({len(failed_runs)}/{total_runs}): {', '.join(failed_runs)}")
print(f"{'='*60}")

## Cell 6: Evaluation + OOD Testing

In [None]:
# ============================================================
# Cell 6: Evaluation + OOD Testing
# ============================================================
# In-distribution: 100 episodes, deterministic policy
# OOD: 7 mass values x 100 episodes each

import json
import time
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

# ---- Config ----
EVAL_EPISODES = 100  # rebinds Cell 5's EVAL_EPISODES (20) for the final evaluation
EVAL_SEEDS_START = 10000  # Fixed eval seeds for reproducibility
OOD_MASSES = [0.1, 0.25, 0.5, 0.75, 1.0, 2.0, 5.0]  # box masses for the OOD sweep
# NOTE(review): unused here. PushBoxEnv.FRICTION is a class attribute with no
# constructor/reset option, so a friction sweep would need an env change.
OOD_FRICTIONS = [0.1, 0.3, 0.5, 0.7, 1.0]  # future use
TRAIN_MASS = 0.5  # in-distribution value (PushBoxEnv's default box_mass)


def evaluate_model(model, env, n_episodes=100, eval_seed_start=10000):
    """Roll out a policy deterministically and summarise per-episode outcomes.

    Episode ``ep`` is seeded with ``eval_seed_start + ep`` (``env.seed`` is
    called right before ``env.reset``), so every model evaluated with the same
    ``eval_seed_start`` is scored on the same sequence of episodes.

    Args:
        model: Object exposing ``predict(obs, deterministic=True)`` that returns
            ``(action, state)`` (e.g. a stable-baselines3 ``PPO``).
        env: A VecEnv (e.g. ``DummyVecEnv``); only index 0 of the batched
            rewards / dones / infos is read, so a single-env VecEnv is assumed.
        n_episodes: Number of episodes to run; must be >= 1.
        eval_seed_start: Seed used for the first episode.

    Returns:
        dict with ``success_rate``/``success_std``, ``mean_reward``/``std_reward``,
        ``mean_distance``/``std_distance`` (final ``distance_to_goal``; 999 when
        the info dict lacks it), ``first_success_ep`` (index of the first
        successful episode, or None) and ``n_episodes``.

    Raises:
        ValueError: if ``n_episodes`` < 1 (the statistics would be NaN).
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    successes = []
    rewards = []
    distances = []
    first_success_ep = None

    for ep in range(n_episodes):
        # Seed before reset so the seed applies to this episode's initial state.
        env.seed(eval_seed_start + ep)
        obs = env.reset()
        done = False
        ep_reward = 0.0
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, dones, infos = env.step(action)
            ep_reward += float(reward[0])
            done = bool(dones[0])

        # Outcome is read from the info dict returned with the final step.
        info = infos[0]
        success = info.get('success', False)
        successes.append(1 if success else 0)
        rewards.append(ep_reward)
        distances.append(info.get('distance_to_goal', 999))

        if success and first_success_ep is None:
            first_success_ep = ep

    return {
        'success_rate': float(np.mean(successes)),
        'success_std': float(np.std(successes)),
        'mean_reward': float(np.mean(rewards)),
        'std_reward': float(np.std(rewards)),
        'mean_distance': float(np.mean(distances)),
        'std_distance': float(np.std(distances)),
        'first_success_ep': first_success_ep,
        'n_episodes': n_episodes,
    }


# ---- Run Evaluation ----
def _ood_aggregate(ood_results, train_mass=TRAIN_MASS):
    """Return (mean, std) success rate over the mass sweep, excluding the training mass.

    ``ood_results`` maps ``'mass_<m>'`` keys to evaluate_model() result dicts.
    TRAIN_MASS is part of OOD_MASSES so the sweep curve has an in-distribution
    reference point, but averaging it into the "OOD" score would mix
    in-distribution performance into an out-of-distribution metric.
    Returns (nan, nan) when the sweep has no mass other than ``train_mass``.
    """
    srs = [
        res['success_rate']
        for key, res in ood_results.items()
        if float(key.split('_', 1)[1]) != float(train_mass)
    ]
    if not srs:
        return float('nan'), float('nan')
    return float(np.mean(srs)), float(np.std(srs))


all_eval_results = {}
eval_start = time.time()
total_evals = len(AGENTS_TO_TRAIN) * len(SEEDS)
eval_idx = 0

for agent_name in AGENTS_TO_TRAIN:
    for seed in SEEDS:
        eval_idx += 1
        run_key = f"{agent_name}_s{seed}"
        model_path = f"{SAVE_ROOT}/models/{agent_name}_s{seed}_final"

        # Check result cache
        eval_cache_path = f"{SAVE_ROOT}/logs/{run_key}_eval.json"
        if os.path.exists(eval_cache_path):
            print(f"[{eval_idx}/{total_evals}] SKIP {run_key} eval (cached)")
            with open(eval_cache_path) as f:
                cached = json.load(f)
            # Re-derive the aggregate so older caches use the same OOD definition
            # (training mass excluded) as freshly computed results.
            if 'ood' in cached:
                cached['ood_mean_sr'], cached['ood_std_sr'] = _ood_aggregate(cached['ood'])
            all_eval_results[run_key] = cached
            continue

        print(f"[{eval_idx}/{total_evals}] Evaluating: {run_key}", end=" ")

        if not os.path.exists(f"{model_path}.zip"):
            print("-- MODEL NOT FOUND, skipping")
            continue

        eval_env_id = DummyVecEnv([make_push_box_env(box_mass=TRAIN_MASS)])
        try:
            model = PPO.load(model_path, env=eval_env_id)

            # In-distribution evaluation; every run uses the same seed sequence.
            id_results = evaluate_model(model, eval_env_id, n_episodes=EVAL_EPISODES,
                                        eval_seed_start=EVAL_SEEDS_START)
            print(f"ID SR={id_results['success_rate']:.1%}", end=" | ")

            # OOD evaluation (mass sweep); each env is closed even if a rollout fails.
            ood_results = {}
            for mass in OOD_MASSES:
                ood_env = DummyVecEnv([make_push_box_env(box_mass=mass)])
                try:
                    ood_results[f'mass_{mass}'] = evaluate_model(
                        model, ood_env, n_episodes=EVAL_EPISODES,
                        eval_seed_start=EVAL_SEEDS_START)
                finally:
                    ood_env.close()
        except Exception as e:
            # Same policy as the training loop: report the failure and move on,
            # so one broken run does not abort the whole evaluation campaign.
            print(f"\n  ERROR during evaluation: {e}")
            continue
        finally:
            eval_env_id.close()

        # Aggregate OOD (training mass excluded; see _ood_aggregate)
        ood_mean, ood_std = _ood_aggregate(ood_results)

        result = {
            'agent': agent_name,
            'seed': seed,
            'in_distribution': id_results,
            'ood': ood_results,
            'ood_mean_sr': ood_mean,
            'ood_std_sr': ood_std,
        }

        # Cache
        with open(eval_cache_path, 'w') as f:
            json.dump(result, f, indent=2)

        all_eval_results[run_key] = result
        print(f"OOD SR={ood_mean:.1%}")

        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

eval_elapsed = time.time() - eval_start
print(f"\nEvaluation complete in {eval_elapsed/60:.1f} min")

# ---- Print Summary Table ----
# One row per agent: mean+/-std over the seeds that have evaluation results.
print(f"\n{'='*80}")
print(f"{'Agent':<18} {'ID SR':>10} {'OOD SR':>10} {'ID Reward':>12} {'ID Dist':>10}")
print(f"{'-'*80}")

for agent_name in AGENTS_TO_TRAIN:
    runs = [
        all_eval_results[key]
        for key in (f"{agent_name}_s{s}" for s in SEEDS)
        if key in all_eval_results
    ]
    id_srs = [run['in_distribution']['success_rate'] for run in runs]
    ood_srs = [run['ood_mean_sr'] for run in runs]
    id_rews = [run['in_distribution']['mean_reward'] for run in runs]
    id_dists = [run['in_distribution']['mean_distance'] for run in runs]
    if not id_srs:
        continue  # agent has no evaluated seeds
    print(
        f"{agent_name:<18} "
        f"{np.mean(id_srs):.1%}+/-{np.std(id_srs):.1%}  "
        f"{np.mean(ood_srs):.1%}+/-{np.std(ood_srs):.1%}  "
        f"{np.mean(id_rews):>8.1f}+/-{np.std(id_rews):.1f}  "
        f"{np.mean(id_dists):>6.3f}"
    )

print(f"{'='*80}")

## Cell 7: Results + Figure Generation

In [None]:
# ============================================================
# Cell 7: Results Visualization + Figure Generation
# ============================================================

import json
import numpy as np
import matplotlib
# Global font sizes applied to every figure produced below.
matplotlib.rcParams.update({
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'legend.fontsize': 10,
})
import matplotlib.pyplot as plt
from scipy import stats

# Plot colour per agent key (matplotlib default cycle, in agent order).
COLORS = dict(
    pure_ppo='#1f77b4',
    gns='#ff7f0e',
    physrobot_sv='#2ca02c',
    no_edgeframe='#d62728',
    hnn='#9467bd',
)
# Human-readable legend / axis label per agent key.
LABELS = dict(
    pure_ppo='Pure PPO',
    gns='GNS',
    physrobot_sv='PhysRobot-SV (ours)',
    no_edgeframe='No-EdgeFrame',
    hnn='HNN',
)


# ========================================
# Figure 1: Bar chart -- Success rate comparison
# ========================================
def plot_success_bars():
    """Bar charts of per-agent success rate (mean +/- std across seeds).

    Left panel: in-distribution success rate; right panel: the aggregated OOD
    success rate (``ood_mean_sr``) over the OOD_MASSES sweep. Seeds with no
    evaluation result are skipped instead of being counted as 0% success (so
    missing runs do not drag the means down), and agents without any result
    are left out. Titles are derived from TRAIN_MASS / OOD_MASSES. Saves PNG
    and PDF to ``{SAVE_ROOT}/figures/success_rates.*``.
    """

    def _per_agent(metric):
        # (labels, means, stds, colors) for agents with >= 1 finite metric value.
        names, means, stds, colors = [], [], [], []
        for agent_name in AGENTS_TO_TRAIN:
            vals = []
            for s in SEEDS:
                run = all_eval_results.get(f"{agent_name}_s{s}")
                if run is None:
                    continue
                v = metric(run)
                if v is not None and np.isfinite(v):
                    vals.append(v)
            if vals:
                names.append(LABELS.get(agent_name, agent_name))
                means.append(np.mean(vals))
                stds.append(np.std(vals))
                colors.append(COLORS.get(agent_name, '#888888'))
        return names, means, stds, colors

    panels = [
        (lambda r: r.get('in_distribution', {}).get('success_rate'),
         f'In-Distribution (mass={TRAIN_MASS}kg)'),
        (lambda r: r.get('ood_mean_sr'),
         f'OOD (mass sweep: {min(OOD_MASSES)}-{max(OOD_MASSES)} kg)'),
    ]

    _, axes = plt.subplots(1, 2, figsize=(12, 5))
    for panel_idx, (ax, (metric, title)) in enumerate(zip(axes, panels)):
        names, means, stds, colors = _per_agent(metric)
        x = np.arange(len(names))
        ax.bar(x, means, yerr=stds, capsize=5, color=colors, alpha=0.85, edgecolor='black')
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=25, ha='right')
        ax.set_ylabel('Success Rate')
        ax.set_title(title)
        ax.set_ylim(0, 1.05)
        if panel_idx == 0:
            # 50% reference line (and its legend) on the in-distribution panel only.
            ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='50%')
            ax.legend()

    plt.tight_layout()
    plt.savefig(f"{SAVE_ROOT}/figures/success_rates.png", dpi=150, bbox_inches='tight')
    plt.savefig(f"{SAVE_ROOT}/figures/success_rates.pdf", bbox_inches='tight')
    plt.show()
    print("Saved: success_rates.png / .pdf")


# ========================================
# Figure 2: OOD generalization curves
# ========================================
def plot_ood_curves():
    """Plot success rate vs. box mass for every agent (log-scaled mass axis).

    For each mass in OOD_MASSES the curve shows the mean success rate over the
    seeds that have a result for that mass. The shaded band is a 95% confidence
    interval of that mean: Student-t quantile (df = n-1) times the standard
    error from the sample standard deviation (ddof=1). It is clipped to [0, 1]
    and collapses onto the mean when only one seed is available. A dashed line
    marks TRAIN_MASS. Saves PNG and PDF to
    ``{SAVE_ROOT}/figures/ood_generalization.*``.
    """
    _, ax = plt.subplots(figsize=(8, 5))

    for agent_name in AGENTS_TO_TRAIN:
        mass_srs = {m: [] for m in OOD_MASSES}
        for seed in SEEDS:
            rk = f"{agent_name}_s{seed}"
            if rk in all_eval_results and 'ood' in all_eval_results[rk]:
                ood = all_eval_results[rk]['ood']
                for m in OOD_MASSES:
                    key = f'mass_{m}'
                    if key in ood:
                        mass_srs[m].append(ood[key]['success_rate'])

        if not any(mass_srs.values()):
            continue

        masses_plot, means_plot, lo_plot, hi_plot = [], [], [], []
        for m in OOD_MASSES:
            vals = mass_srs[m]
            if not vals:
                continue
            n = len(vals)
            mu = float(np.mean(vals))
            if n > 1:
                # Few seeds: sample std + t quantile, not population std + 1.96.
                sem = np.std(vals, ddof=1) / np.sqrt(n)
                half = float(stats.t.ppf(0.975, df=n - 1) * sem)
            else:
                half = 0.0  # a single seed gives no spread estimate
            masses_plot.append(m)
            means_plot.append(mu)
            # Success rates live in [0, 1]; keep the band inside that range.
            lo_plot.append(max(mu - half, 0.0))
            hi_plot.append(min(mu + half, 1.0))

        c = COLORS.get(agent_name, '#888888')
        ax.plot(masses_plot, means_plot, 'o-', color=c,
                label=LABELS.get(agent_name, agent_name), linewidth=2)
        ax.fill_between(masses_plot, lo_plot, hi_plot, color=c, alpha=0.15)

    ax.axvline(x=TRAIN_MASS, color='gray', linestyle='--', alpha=0.6, label=f'Train mass ({TRAIN_MASS} kg)')
    ax.set_xlabel('Box Mass (kg)')
    ax.set_ylabel('Success Rate')
    ax.set_title('OOD Generalization: Success Rate vs Box Mass')
    ax.set_xscale('log')
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{SAVE_ROOT}/figures/ood_generalization.png", dpi=150, bbox_inches='tight')
    plt.savefig(f"{SAVE_ROOT}/figures/ood_generalization.pdf", bbox_inches='tight')
    plt.show()
    print("Saved: ood_generalization.png / .pdf")


# ========================================
# Figure 3: Learning curves
# ========================================
def plot_learning_curves():
    """Plot seed-averaged learning curves (episode reward and success) per agent.

    Each seed's per-episode series from ``all_training_logs`` is smoothed with
    a trailing moving average of up to ``window`` episodes. The window shrinks
    to the agent's shortest series so short logs still give a valid rolling
    mean. Curves are truncated to the shortest seed and drawn as mean +/- std
    across seeds. The x-axis is the index of the last episode in each
    averaging window. Saves PNG and PDF to
    ``{SAVE_ROOT}/figures/learning_curves.*``.
    """
    _, axes = plt.subplots(1, 2, figsize=(14, 5))

    window = 50  # maximum smoothing window (episodes)

    for agent_name in AGENTS_TO_TRAIN:
        series = []
        for seed in SEEDS:
            log = all_training_logs.get(f"{agent_name}_s{seed}")
            if not log:
                continue
            rewards = np.asarray(log.get('episode_rewards', []), dtype=float)
            successes = np.asarray(log.get('episode_successes', []), dtype=float)
            n = min(len(rewards), len(successes))
            if n:
                series.append((rewards[:n], successes[:n]))

        if not series:
            continue

        # 'valid' convolution needs the window to fit inside every series;
        # otherwise numpy swaps the operands and returns whole-series averages
        # instead of a rolling mean.
        w = min(window, min(len(r) for r, _ in series))
        kernel = np.ones(w) / w
        smoothed_r = [np.convolve(r, kernel, mode='valid') for r, _ in series]
        smoothed_s = [np.convolve(s, kernel, mode='valid') for _, s in series]

        # Align lengths (truncate to shortest)
        min_len = min(len(r) for r in smoothed_r)
        rewards_arr = np.array([r[:min_len] for r in smoothed_r])
        success_arr = np.array([s[:min_len] for s in smoothed_s])

        # Point i averages episodes i .. i + w - 1; label it by its last episode.
        x = np.arange(min_len) + (w - 1)
        c = COLORS.get(agent_name, '#888888')
        lbl = LABELS.get(agent_name, agent_name)

        # Reward curve
        mu_r = rewards_arr.mean(axis=0)
        std_r = rewards_arr.std(axis=0)
        axes[0].plot(x, mu_r, color=c, label=lbl, linewidth=1.5)
        axes[0].fill_between(x, mu_r - std_r, mu_r + std_r, color=c, alpha=0.1)

        # Success curve
        mu_s = success_arr.mean(axis=0)
        std_s = success_arr.std(axis=0)
        axes[1].plot(x, mu_s, color=c, label=lbl, linewidth=1.5)
        axes[1].fill_between(x, mu_s - std_s, mu_s + std_s, color=c, alpha=0.1)

    axes[0].set_xlabel('Episode')
    axes[0].set_ylabel('Episode Reward (smoothed)')
    axes[0].set_title('Learning Curves: Reward')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    axes[1].set_xlabel('Episode')
    axes[1].set_ylabel('Success Rate (rolling)')
    axes[1].set_title('Learning Curves: Success Rate')
    axes[1].set_ylim(-0.05, 1.05)
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{SAVE_ROOT}/figures/learning_curves.png", dpi=150, bbox_inches='tight')
    plt.savefig(f"{SAVE_ROOT}/figures/learning_curves.pdf", bbox_inches='tight')
    plt.show()
    print("Saved: learning_curves.png / .pdf")


# ========================================
# Figure 4: Sample efficiency (box plot)
# ========================================
def plot_sample_efficiency():
    """Box plot of episodes-to-first-success for each agent across seeds.

    For every seed with a training log, the value is the index of the first
    episode whose success flag equals 1. Seeds that never succeed contribute
    the total number of logged episodes. Agents without any training log are
    omitted. Saves PNG and PDF to ``{SAVE_ROOT}/figures/sample_efficiency.*``.
    """
    _, ax = plt.subplots(figsize=(8, 5))

    box_data, box_labels, box_colors = [], [], []
    for agent_name in AGENTS_TO_TRAIN:
        firsts = []
        for seed in SEEDS:
            key = f"{agent_name}_s{seed}"
            if key not in all_training_logs:
                continue
            flags = all_training_logs[key].get('episode_successes', [])
            # First successful episode, or len(flags) when there is none.
            firsts.append(next((i for i, flag in enumerate(flags) if flag == 1), len(flags)))
        if not firsts:
            continue
        box_data.append(firsts)
        box_labels.append(LABELS.get(agent_name, agent_name))
        box_colors.append(COLORS.get(agent_name, '#888888'))

    if box_data:
        bp = ax.boxplot(box_data, labels=box_labels, patch_artist=True)
        for box, color in zip(bp['boxes'], box_colors):
            box.set_facecolor(color)
            box.set_alpha(0.6)

    ax.set_ylabel('Episodes to First Success')
    ax.set_title('Sample Efficiency')
    ax.tick_params(axis='x', rotation=25)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f"{SAVE_ROOT}/figures/sample_efficiency.png", dpi=150, bbox_inches='tight')
    plt.savefig(f"{SAVE_ROOT}/figures/sample_efficiency.pdf", bbox_inches='tight')
    plt.show()
    print("Saved: sample_efficiency.png / .pdf")


# ========================================
# Table: LaTeX ablation table
# ========================================
def generate_latex_table():
    r"""Write the Phase 1 ablation results as a LaTeX (booktabs) table.

    Each agent with results gets one row:
    - mean +/- std across evaluated seeds of in-distribution and OOD success
      rate, in percent;
    - mean in-distribution reward;
    - policy parameter count in thousands, counted on a throwaway agent built
      with ``create_agent``.

    The caption's step and seed counts come from TOTAL_TIMESTEPS and SEEDS, so
    it cannot drift from the actual configuration. The table is written to
    ``{SAVE_ROOT}/figures/ablation_table.tex`` and printed. It uses
    \toprule/\midrule/\bottomrule, so the paper preamble needs
    ``\usepackage{booktabs}``.
    """

    def _tex_escape(text):
        # Escape LaTeX specials that may appear in fallback (non-LABELS) names.
        for ch in ('&', '%', '$', '#', '_'):
            text = text.replace(ch, '\\' + ch)
        return text

    steps_k = int(TOTAL_TIMESTEPS) // 1000
    lines = [
        r'\begin{table}[t]',
        r'\centering',
        rf'\caption{{Phase 1 Ablation Results (PushBox, {steps_k}K steps, {len(SEEDS)} seeds)}}',
        r'\label{tab:phase1}',
        r'\begin{tabular}{l c c c c}',
        r'\toprule',
        r'Method & SR (ID) & SR (OOD) & Reward & Params \\',
        r'\midrule',
    ]

    for agent_name in AGENTS_TO_TRAIN:
        id_srs, ood_srs, rews = [], [], []
        for seed in SEEDS:
            rk = f"{agent_name}_s{seed}"
            if rk in all_eval_results:
                r = all_eval_results[rk]
                id_srs.append(r['in_distribution']['success_rate'] * 100)
                ood_srs.append(r['ood_mean_sr'] * 100)
                rews.append(r['in_distribution']['mean_reward'])

        if not id_srs:
            continue

        # Parameter count: build a throwaway agent and always release its env.
        _env = DummyVecEnv([make_push_box_env()])
        try:
            _m = create_agent(agent_name, _env, seed=0)
            n_params = sum(p.numel() for p in _m.policy.parameters())
            del _m
        finally:
            _env.close()

        display = _tex_escape(LABELS.get(agent_name, agent_name))
        if agent_name == 'physrobot_sv':
            display = r'\textbf{' + display + r'}'

        lines.append(
            f"{display} & "
            f"${np.mean(id_srs):.1f} \\pm {np.std(id_srs):.1f}$ & "
            f"${np.mean(ood_srs):.1f} \\pm {np.std(ood_srs):.1f}$ & "
            f"${np.mean(rews):.1f}$ & "
            f"{n_params // 1000}K \\\\"
        )

    lines.extend([
        r'\bottomrule',
        r'\end{tabular}',
        r'\end{table}',
    ])

    latex_str = '\n'.join(lines)
    with open(f"{SAVE_ROOT}/figures/ablation_table.tex", 'w') as fh:
        fh.write(latex_str)

    print("LaTeX table:")
    print(latex_str)
    print("\nSaved: ablation_table.tex")


# ========================================
# Statistical tests
# ========================================
def run_statistical_tests():
    """Welch's t-tests comparing PhysRobot-SV with every other agent.

    For each other agent, two metrics are tested with
    ``scipy.stats.ttest_ind(..., equal_var=False)``: per-seed in-distribution
    success rate and per-seed aggregated OOD success rate (``ood_mean_sr``).

    Only seeds that have a (finite) evaluation result contribute. Missing runs
    are skipped rather than counted as 0% success, which would fabricate data
    points. A comparison is printed as ``n/a`` when either side has fewer than
    two seeds. ``*`` marks p < 0.05, with no multiple-comparison correction.
    """
    print("\n" + "="*60)
    print("Statistical Tests (Welch's t-test, p < 0.05)")
    print("="*60)

    our_method = 'physrobot_sv'

    def _seed_values(agent, metric):
        # Per-seed values of metric(run) for seeds that were actually evaluated.
        values = []
        for s in SEEDS:
            run = all_eval_results.get(f"{agent}_s{s}")
            if run is None:
                continue
            v = metric(run)
            if v is not None and np.isfinite(v):
                values.append(v)
        return values

    def _id_sr(run):
        return run.get('in_distribution', {}).get('success_rate')

    def _ood_sr(run):
        return run.get('ood_mean_sr')

    def _welch(a, b):
        # "t=.. p=.." summary, or n/a when the test is undefined (n < 2).
        if len(a) < 2 or len(b) < 2:
            return f"n/a (n={len(a)} vs {len(b)})"
        t, p = stats.ttest_ind(a, b, equal_var=False)
        sig = '*' if p < 0.05 else ''
        return f"t={t:+.2f} p={p:.3f}{sig}"

    our_id = _seed_values(our_method, _id_sr)
    our_ood = _seed_values(our_method, _ood_sr)

    for agent_name in AGENTS_TO_TRAIN:
        if agent_name == our_method:
            continue
        other_id = _seed_values(agent_name, _id_sr)
        other_ood = _seed_values(agent_name, _ood_sr)

        print(f"PhysRobot-SV vs {LABELS.get(agent_name, agent_name):20s}  "
              f"ID: {_welch(our_id, other_id)}  "
              f"OOD: {_welch(our_ood, other_ood)}")


# ---- Generate all figures ----
# Run every figure/table/statistics step in order; each saves its own output.
print("Generating figures...\n")
for _artifact_fn in (
    plot_success_bars,
    plot_ood_curves,
    plot_learning_curves,
    plot_sample_efficiency,
    generate_latex_table,
    run_statistical_tests,
):
    _artifact_fn()
del _artifact_fn

print(f"\nAll figures saved to: {SAVE_ROOT}/figures/")

## Cell 8: Save Summary + Backup to Drive

In [None]:
# ============================================================
# Cell 8: Save Summary + Backup to Drive
# ============================================================

import json
import os
import shutil
from datetime import datetime

# ---- Compile master summary ----
# Experiment config plus per-agent, per-seed eval results and train times.
summary = {
    'experiment': 'PhysRobot Phase 1 Ablation',
    'timestamp': datetime.now().isoformat(),
    'config': dict(
        seeds=SEEDS,
        total_timesteps=TOTAL_TIMESTEPS,
        n_envs=N_ENVS,
        eval_episodes=EVAL_EPISODES,
        ood_masses=OOD_MASSES,
        agents=AGENTS_TO_TRAIN,
        # NOTE(review): hard-coded; confirm it matches the env's success threshold.
        success_threshold=0.15,
        train_box_mass=TRAIN_MASS,
    ),
    'results': {},
}

for agent_name in AGENTS_TO_TRAIN:
    per_seed = {}
    id_srs, ood_srs, train_times = [], [], []

    for seed in SEEDS:
        rk = f"{agent_name}_s{seed}"
        entry = {}

        eval_res = all_eval_results.get(rk)
        if eval_res is not None:
            entry['eval'] = eval_res
            id_srs.append(eval_res['in_distribution']['success_rate'])
            ood_srs.append(eval_res['ood_mean_sr'])

        train_log = all_training_logs.get(rk)
        if train_log is not None:
            entry['train_time_s'] = train_log.get('training_time_s', 0)
            train_times.append(entry['train_time_s'])

        per_seed[str(seed)] = entry

    # Aggregates exist only for agents with at least one evaluated seed.
    aggregate = {}
    if id_srs:
        aggregate = {
            'id_sr_mean': float(np.mean(id_srs)),
            'id_sr_std': float(np.std(id_srs)),
            'ood_sr_mean': float(np.mean(ood_srs)),
            'ood_sr_std': float(np.std(ood_srs)),
            'mean_train_time_s': float(np.mean(train_times)) if train_times else 0,
        }

    summary['results'][agent_name] = {'seeds': per_seed, 'aggregate': aggregate}

# Save summary
summary_path = f"{SAVE_ROOT}/summary.json"
with open(summary_path, 'w') as fh:
    json.dump(summary, fh, indent=2)
print(f"Summary saved: {summary_path}")

# ---- Print final results ----
# Console table built from the per-agent aggregates stored in `summary`.
print("\n" + "=" * 70)
print("PHASE 1 ABLATION -- FINAL RESULTS")
print("=" * 70)
print(f"{'Agent':<20} {'ID SR':>12} {'OOD SR':>12} {'Train Time':>12}")
print("-" * 70)
for agent_name in AGENTS_TO_TRAIN:
    agg = summary['results'].get(agent_name, {}).get('aggregate', {})
    if not agg:
        continue  # no evaluated seeds for this agent
    id_str = f"{agg['id_sr_mean']:.1%} +/- {agg['id_sr_std']:.1%}"
    ood_str = f"{agg['ood_sr_mean']:.1%} +/- {agg['ood_sr_std']:.1%}"
    tt_str = f"{agg.get('mean_train_time_s', 0) / 60:.1f} min"
    print(f"{agent_name:<20} {id_str:>12} {ood_str:>12} {tt_str:>12}")
print("=" * 70)

# ---- Backup to Drive (dated) ----
# Colab only: snapshot summary.json and the figures/ folder into a
# minute-stamped directory under MyDrive/PhysRobot/backups.
if IN_COLAB:
    backup_dir = os.path.join(
        "/content/drive/MyDrive/PhysRobot/backups",
        datetime.now().strftime('%Y%m%d_%H%M'),
    )
    os.makedirs(backup_dir, exist_ok=True)

    for fname in ['summary.json']:
        src = os.path.join(SAVE_ROOT, fname)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(backup_dir, fname))

    fig_dir = os.path.join(SAVE_ROOT, "figures")
    if os.path.isdir(fig_dir):
        shutil.copytree(fig_dir, os.path.join(backup_dir, "figures"), dirs_exist_ok=True)

    print(f"\nBackup saved to: {backup_dir}")

# ---- List all output files ----
# Print every file under SAVE_ROOT with its size, in a stable order.
print(f"\nAll output files in {SAVE_ROOT}:")
for root, dirs, files in os.walk(SAVE_ROOT):
    # os.walk (top-down) visits subdirectories in the order of `dirs`; sorting
    # it in place makes the traversal deterministic, matching the sorted files.
    dirs.sort()
    for f in sorted(files):
        full = os.path.join(root, f)
        size = os.path.getsize(full)
        rel = os.path.relpath(full, SAVE_ROOT)
        print(f"  {rel:<60s} {size/1024:.1f} KB")

print("\n" + "="*70)
print("DONE. Ready for Phase 2 (SAC, TD3, multi-object).")
print("="*70)