In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

class missile_interception_3d(gym.Env):
    """Gymnasium environment: steer an interceptor onto a ballistic missile in 3-D.

    The enemy missile flies a gravity-only arc from its launch point to a
    random aim point inside the target box. The defense interceptor launches
    from inside the same box. The agent commands the interceptor's desired NET
    lateral acceleration; the fin command that also cancels lateral gravity is
    derived inside :meth:`step`. Each :meth:`step` integrates ``n_substeps``
    physics sub-steps of ``dt_sim`` seconds.

    Action, shape (2,), in [-1, 1]:
        Fractions of ``a_max`` along the body ``right`` and ``up`` axes from
        :meth:`_compute_lateral_basis`; the vector is scaled into the unit disc.
    Observation, shape (16,):
        See :meth:`_get_obs` for the layout and scaling.
    """

    def __init__(self):
        # 1. Action space (the "joystick"): lateral / vertical command
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2, ), dtype=np.float32)

        # 2. Observation space (12 original + 4 geometry features = 16)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32)

        # Legacy RandomState (not a Generator): generate_enemy_missile calls .rand().
        self.np_random = np.random.RandomState()

        # 3. Time settings (seconds)
        self.dt_act = 0.1             # control interval per step()
        self.n_substeps = 10          # physics sub-steps per control interval
        self.dt_sim = self.dt_act / self.n_substeps
        self.t_max = 650.0            # episode time limit

        # 4. Physical limits
        self.a_max = 350.0   # Max commanded net lateral accel (m/s^2) ~35G
        self.da_max = 2500.0 # Jerk limit on the fin command (m/s^3)
        self.tau = 0.05      # First-order airframe lag time constant (s)
        self.g = 9.81
        self.collision_radius = 150.0    # hit when closest approach <= this (m)
        self.max_distance = 4_000_000.0  # separation that truncates as "diverged" (m)

        # Scenario difficulty: probability of drawing the short-range ("easy") cap
        self.p_easy = 1.0
        self.range_min = 70_000.0
        self.range_easy_max = 200_000.0
        self.range_hard_max = 1_000_000.0

        # Box (m) from which enemy aim points and defense launch sites are drawn
        self.targetbox_x_min = -15000
        self.targetbox_x_max = 15000
        self.targetbox_y_min = -15000
        self.targetbox_y_max = 15000

    def generate_enemy_missile(self):
        """Sample a ballistic enemy shot and set its launch state.

        Picks the easy or hard range cap (probability ``p_easy``), an aim point
        in the target box, a launch azimuth, an elevation in [30 deg, 60 deg],
        and a launch speed such that the flat-earth, drag-free range
        ``v**2 * sin(2*theta) / g`` lies in ``[range_min, range_max_used]``.
        The launch point is that ground range away from the aim point, and
        ``enemy_azimuth`` points back toward the aim point.
        """
        if self.np_random.rand() < self.p_easy:
            self.range_max_used = self.range_easy_max
        else:
            self.range_max_used = self.range_hard_max

        range_min = self.range_min
        self.attack_target_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.attack_target_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)
        self.enemy_launch_angle = self.np_random.uniform(0, 2 * np.pi)
        self.enemy_theta = self.np_random.uniform(0.523599, 1.0472)  # 30..60 deg elevation

        # Guarantee a non-empty speed interval.
        self.range_max_used = max(self.range_max_used, range_min + 1.0)
        lower_limit = np.sqrt((range_min * self.g) / np.sin(2 * self.enemy_theta))
        upper_limit = np.sqrt((self.range_max_used * self.g) / np.sin(2 * self.enemy_theta))
        self.enemy_initial_velocity = self.np_random.uniform(lower_limit, upper_limit)

        # Drag-free ballistic ground range: v*cos(theta) * time_of_flight.
        ground_range = (
            self.enemy_initial_velocity * np.cos(self.enemy_theta)
            * (2 * self.enemy_initial_velocity * np.sin(self.enemy_theta) / self.g)
        )

        self.enemy_launch_x = self.attack_target_x + ground_range * np.cos(self.enemy_launch_angle)
        self.enemy_launch_y = self.attack_target_y + ground_range * np.sin(self.enemy_launch_angle)
        self.enemy_z = 0
        self.enemy_x = self.enemy_launch_x
        self.enemy_y = self.enemy_launch_y
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        # Fly back toward the aim point (opposite of the launch-offset direction).
        self.enemy_azimuth = (self.enemy_launch_angle + np.pi) % (2 * np.pi)

    def generate_defense_missile(self):
        """Sample the interceptor launch site, heading and initial speed.

        The launch point is uniform in the target box. The heading points at
        the enemy launch point with a fixed 45 deg elevation. The speed is
        3000 m/s scaled by ``range_max_used / range_easy_max``, capped at 1.5x.
        """
        self.defense_launch_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.defense_launch_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)
        dx = self.enemy_launch_x - self.defense_launch_x
        dy = self.enemy_launch_y - self.defense_launch_y
        self.defense_azimuth = np.arctan2(dy, dx)
        self.defense_theta = 0.785398  # 45 deg
        base_velocity = 3000.0
        if hasattr(self, 'range_max_used'):
            velocity_scale = min(self.range_max_used / self.range_easy_max, 1.5)
            self.defense_initial_velocity = base_velocity * velocity_scale
        else:
            self.defense_initial_velocity = base_velocity
        self.defense_x = self.defense_launch_x
        self.defense_y = self.defense_launch_y
        self.defense_z = 0
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)
        self.defense_ax = 0
        self.defense_ay = 0
        self.defense_az = 0

    def calculate_pronav(self):
        """Return a true proportional-navigation (N = 3) action for the current state.

        The output uses the same semantics as an agent action: desired NET
        lateral acceleration along body ``right``/``up`` as a fraction of
        ``a_max``, clipped per component to [-1, 1].
        """
        r = self.enemy_pos - self.defense_pos
        v = self.enemy_vel - self.defense_vel
        r_mag = np.linalg.norm(r)

        # 1. Standard ProNav: LOS rate and closing speed
        omega = np.cross(r, v) / (np.dot(r, r) + 1e-9)
        vc = -np.dot(r, v) / (r_mag + 1e-9)

        N = 3.0
        unit_los = r / (r_mag + 1e-9)
        a_ideal = N * vc * np.cross(omega, unit_los)

        # 2. Project onto body frame.
        # The environment handles gravity compensation internally, so ProNav
        # outputs the desired NET lateral accel (same semantics as the agent).
        forward, right, up = self._compute_lateral_basis(self.defense_vel)

        a_right = np.dot(a_ideal, right)
        a_up    = np.dot(a_ideal, up)

        # 3. Normalize to action units
        action = np.array([
            a_right / self.a_max,
            a_up / self.a_max
        ], dtype=np.float32)

        return np.clip(action, -1.0, 1.0)

    def _rate_limit_norm(self, a_cmd, a_prev, da_max, dt):
        """Move ``a_prev`` toward ``a_cmd`` by at most ``da_max * dt`` (Euclidean norm)."""
        delta = a_cmd - a_prev
        max_delta = da_max * dt
        dnorm = float(np.linalg.norm(delta))
        if dnorm <= max_delta or dnorm < 1e-9:
            return a_cmd
        return a_prev + delta * (max_delta / dnorm)

    def _segment_min_dist_sq(self, r0, r1):
        """Squared minimum of ``|r0 + s*(r1 - r0)|`` over ``s`` in [0, 1].

        This is the closest approach of the relative position, assuming it
        moves linearly between the two sub-step endpoints.
        """
        dr = r1 - r0
        dr_norm_sq = float(np.dot(dr, dr))
        if dr_norm_sq < 1e-12:
            return float(np.dot(r0, r0))
        s_star = -float(np.dot(r0, dr)) / dr_norm_sq
        s_star = max(0.0, min(1.0, s_star))
        r_closest = r0 + s_star * dr
        return float(np.dot(r_closest, r_closest))

    def _segment_sphere_intersect(self, r0, r1, r_hit):
        """True if the segment r0 -> r1 passes within ``r_hit`` of the origin."""
        return self._segment_min_dist_sq(r0, r1) <= r_hit * r_hit

    def _get_obs(self):
        """Build the 16-D float32 observation.

        Layout:
          [0:3]   relative position (enemy - defense) / 1e6 m
          [3:6]   relative velocity / 4000 m/s
          [6:8]   current fin accel along body right/up, / a_max
          [8]     distance / 1e6 m, clipped to [0, 4]
          [9]     closing speed / 3000 m/s, clipped to [-2, 2]
          [10]    defense altitude / 1e5 m, clipped to [-1, 2]
          [11]    defense vertical speed / 3000 m/s, clipped to [-2, 2]
          [12:14] LOS unit vector projected on body right/up
          [14:16] LOS rate projected on body right/up, / 2 rad/s, clipped to [-2, 2]
        """
        # Relative state
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)   # rel_pos
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64) # rel_vel

        dist = float(np.linalg.norm(r)) + 1e-6
        rhat = r / dist

        # Closing speed (positive when closing)
        v_close = -float(np.dot(r, vrel)) / dist

        # Local basis from defense velocity
        forward, right, up = self._compute_lateral_basis(self.defense_vel)

        # Current (lagged) fin accel in that basis
        a_lat = np.array([
            float(np.dot(self.a_actual, right)) / (self.a_max + 1e-9),
            float(np.dot(self.a_actual, up)) / (self.a_max + 1e-9),
        ], dtype=np.float32)

        # Normalize rel_pos / rel_vel
        pos_scale = self.range_hard_max  # 1_000_000
        vel_scale = 4000.0
        rel_pos_n = (r / pos_scale).astype(np.float32)
        rel_vel_n = (vrel / vel_scale).astype(np.float32)

        # Distance + closing speed
        dist_n = np.clip(dist / 1_000_000.0, 0.0, 4.0).astype(np.float32)
        vclose_n = np.clip(v_close / 3000.0, -2.0, 2.0).astype(np.float32)
        dist_vclose_feat = np.array([dist_n, vclose_n], dtype=np.float32)

        # Defense own vertical state
        def_z_n = np.clip(self.defense_pos[2] / 100_000.0, -1.0, 2.0).astype(np.float32)
        def_vz_n = np.clip(self.defense_vel[2] / 3000.0, -2.0, 2.0).astype(np.float32)
        def_state_feat = np.array([def_z_n, def_vz_n], dtype=np.float32)

        # ===============================
        # 4 geometry features
        # ===============================

        # 1) LOS direction projected into body lateral axes
        los_right = float(np.dot(rhat, right))  # [-1,1]
        los_up    = float(np.dot(rhat, up))     # [-1,1]

        # 2) LOS rate omega = (r x vrel) / ||r||^2   (rad/s)
        dist2 = float(np.dot(r, r)) + 1e-9
        omega = np.cross(r, vrel) / dist2  # world frame, float64

        # Project omega into body lateral axes
        omega_right = float(np.dot(omega, right))
        omega_up    = float(np.dot(omega, up))

        # Normalize omega so values sit in a sane range.
        # Typical omega ~ v/R. With v~3000:
        # R=100km -> 0.03 rad/s; R=1km -> 3 rad/s
        omega_scale = 2.0
        omega_right_n = np.clip(omega_right / omega_scale, -2.0, 2.0)
        omega_up_n    = np.clip(omega_up / omega_scale, -2.0, 2.0)

        geom_feat = np.array([los_right, los_up, omega_right_n, omega_up_n], dtype=np.float32)

        # Final obs (16D)
        obs = np.concatenate(
            [rel_pos_n, rel_vel_n, a_lat, dist_vclose_feat, def_state_feat, geom_feat],
            axis=0
        ).astype(np.float32)

        return obs

    def _compute_lateral_basis(self, velocity):
        """
        Horizon-stable body basis:
          forward = along velocity
          right   = world_up x forward (horizontal)
          up      = forward x right    (completes the orthonormal frame)
        This keeps 'up' as close to world-up as possible and avoids twisting.

        NOTE: in a right-handed, z-up world frame, world_up x forward points to
        the LEFT of the direction of travel (e.g. forward=+x gives +y). The axis
        keeps the name 'right' because action/observation conventions and
        trained policies are defined relative to it; 'up' is correct.
        """
        speed = float(np.linalg.norm(velocity))
        if speed < 1.0:
            forward = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        else:
            forward = (velocity / speed).astype(np.float32)

        world_up = np.array([0.0, 0.0, 1.0], dtype=np.float32)

        # right = world_up x forward
        right_raw = np.cross(world_up, forward)
        rnorm = float(np.linalg.norm(right_raw))

        # If forward is near world_up, right_raw ~ 0. Pick a consistent fallback.
        if rnorm < 1e-6:
            # Use a fixed world-x axis so the basis does not spin when vertical.
            north = np.array([1.0, 0.0, 0.0], dtype=np.float32)
            right_raw = np.cross(north, forward)
            rnorm = float(np.linalg.norm(right_raw))
            if rnorm < 1e-6:
                north = np.array([0.0, 1.0, 0.0], dtype=np.float32)
                right_raw = np.cross(north, forward)
                rnorm = float(np.linalg.norm(right_raw))

        right = (right_raw / (rnorm + 1e-9)).astype(np.float32)

        # up = forward x right (not right x forward)
        up_raw = np.cross(forward, right)
        up = (up_raw / (float(np.linalg.norm(up_raw)) + 1e-9)).astype(np.float32)

        return forward, right, up

    def step(self, action):
        """Advance the simulation by one control interval of ``dt_act`` seconds.

        Each sub-step does the following:
          1. Build the body basis.
          2. Turn the commanded NET lateral acceleration into a fin command
             that also cancels the lateral component of gravity.
          3. Apply the jerk (rate) limit and the first-order lag.
          4. Integrate both missiles with semi-implicit Euler (velocity first,
             then position).
          5. Check termination, in this order: hit (closest approach on the
             sub-step segment within ``collision_radius``), diverged,
             defense_ground, enemy_ground, timeout.

        ``min_dist`` is tracked at sub-step resolution from the closest approach
        along each sub-step segment, so fast fly-bys are not under-sampled.
        The shaping sums ``sum_r_progress`` / ``sum_r_close`` are accumulated
        per call.

        Returns:
            The Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``.
        """
        if getattr(self, "done", False):
            return self._get_obs(), 0.0, True, False, {"event": "called_step_after_done", "dist": self.relative_distances[-1]}

        # Clip per component, then scale into the unit disc.
        action = np.clip(action, -1.0, 1.0).astype(np.float32)
        mag = np.linalg.norm(action)
        if mag > 1.0:
            action = action / mag

        dist_before = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        terminated = False
        truncated = False
        event = "running"

        # Gravity (world frame) is constant; build it once per step.
        g_vec = np.array([0.0, 0.0, -self.g], dtype=np.float32)
        r_hit_sq = self.collision_radius * self.collision_radius

        for _ in range(self.n_substeps):
            dt = self.dt_sim
            enemy_pos_old = self.enemy_pos.copy()
            defense_pos_old = self.defense_pos.copy()

            forward, right, up = self._compute_lateral_basis(self.defense_vel)

            # Agent command = desired NET lateral accel (world frame)
            a_net_lat_cmd = (action[0] * self.a_max * right) + (action[1] * self.a_max * up)

            # Lateral component of gravity in the right/up plane
            g_lat = (np.dot(g_vec, right) * right) + (np.dot(g_vec, up) * up)

            # Fins must cancel lateral gravity to achieve commanded NET lateral accel
            a_fins_cmd = a_net_lat_cmd - g_lat

            # Apply rate limit + first-order lag to the fin acceleration
            self.a_cmd_prev = self._rate_limit_norm(a_fins_cmd, self.a_cmd_prev, self.da_max, dt)
            self.a_actual += (self.a_cmd_prev - self.a_actual) * (dt / self.tau)

            # Integrate translational dynamics
            self.defense_vel += (self.a_actual + g_vec) * dt
            self.defense_pos += self.defense_vel * dt
            self.defense_x, self.defense_y, self.defense_z = self.defense_pos

            # Enemy missile: pure ballistic (gravity only)
            self.enemy_vel += g_vec * dt
            self.enemy_pos += self.enemy_vel * dt
            self.enemy_x, self.enemy_y, self.enemy_z = self.enemy_pos
            self.t += dt

            # Closest approach along this sub-step's relative-motion segment.
            r0 = enemy_pos_old - defense_pos_old
            r1 = self.enemy_pos - self.defense_pos
            seg_min_sq = self._segment_min_dist_sq(r0, r1)
            self.min_dist = min(getattr(self, "min_dist", float("inf")), float(np.sqrt(seg_min_sq)))
            if seg_min_sq <= r_hit_sq:
                self.success = True
                terminated = True
                self.done = True
                event = "hit"
                break

            dist = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
            if dist > self.max_distance:
                truncated = True
                self.done = True
                event = "diverged"
                break
            if self.defense_pos[2] < 0:
                terminated = True
                self.done = True
                event = "defense_ground"
                break
            if self.enemy_pos[2] < 0:
                terminated = True
                self.done = True
                event = "enemy_ground"
                break
            if self.t >= self.t_max:
                truncated = True
                self.done = True
                event = "timeout"
                break

        self.enemy_path.append(self.enemy_pos.copy())
        self.defense_path.append(self.defense_pos.copy())
        self.relative_distances.append(float(np.linalg.norm(self.enemy_pos - self.defense_pos)))
        self.times.append(self.t)

        obs = self._get_obs()
        dist_after = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        self.min_dist = min(getattr(self, "min_dist", float("inf")), dist_after)

        # Dense shaping: distance progress + bounded closing-speed bonus - time cost
        r_progress = (dist_before - dist_after) / 100.0
        v_scale = 1500.0
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)
        d = float(np.linalg.norm(r)) + 1e-9
        rhat = r / d
        d_dot = float(np.dot(rhat, vrel))
        r_close = np.tanh((-d_dot) / v_scale)
        reward = 1.0 * r_progress + 0.1 * r_close - 0.001

        # Episode-level shaping totals for diagnostics (zeroed in reset()).
        self.sum_r_progress = getattr(self, "sum_r_progress", 0.0) + r_progress
        self.sum_r_close = getattr(self, "sum_r_close", 0.0) + float(r_close)

        # Terminal bonus / penalties
        if self.success:
            reward += 10000.0
        elif terminated or truncated:
            if event == "defense_ground": reward -= 5000.0
            reward -= min(2000.0, self.min_dist / 50.0)

        info = {
            "dist": dist_after, "event": event, "t": self.t,
            "min_dist": self.min_dist,
            "action_mag": mag, # Pre-normalization command magnitude
            "reward": reward
        }
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Start a new engagement and return ``(obs, info)``.

        When ``seed`` is given, the legacy ``RandomState`` used by the
        scenario generators is re-created from it, so episodes are
        reproducible per seed.
        """
        super().reset(seed=seed)
        if seed is not None: self.np_random = np.random.RandomState(seed)
        self.done = False
        self.success = False
        self.t = 0.0
        self.generate_enemy_missile()
        self.generate_defense_missile()

        self.defense_vel = np.array([
            self.defense_initial_velocity * np.cos(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_theta)
        ], dtype=np.float32)

        self.enemy_vel = np.array([
            self.enemy_initial_velocity * np.cos(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_theta)
        ], dtype=np.float32)

        self.a_actual = np.zeros(3, dtype=np.float32)
        self.a_cmd_prev = np.zeros(3, dtype=np.float32)
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        self.enemy_path = [self.enemy_pos.copy()]
        self.defense_path = [self.defense_pos.copy()]
        self.relative_distances = [float(np.linalg.norm(self.enemy_pos - self.defense_pos))]
        self.times = [self.t]
        self.min_dist = float(self.relative_distances[-1])
        self.sum_r_progress = 0.0
        self.sum_r_close = 0.0
        return self._get_obs(), {}

# ==========================================
# TEST PRONAV BASELINE (No Animation)
# ==========================================

def run_baseline(n_episodes=50):
    """Roll out ProNav on seeded episodes and print per-episode and summary stats.

    Episode ``i`` is reset with ``seed=i``, so runs are reproducible. "G-Load"
    is the mean norm of the normalized ProNav command (1.0 = full ``a_max``).

    Args:
        n_episodes: Number of episodes to roll out (default 50).

    Raises:
        ValueError: If ``n_episodes`` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError("n_episodes must be positive")

    env = missile_interception_3d()
    outcomes = []
    min_distances = []
    action_loads = []  # Mean command magnitude per episode (1.0 = saturated fins)

    print(f"Running {n_episodes} episodes of Augmented ProNav...")

    for i in range(n_episodes):
        env.reset(seed=i)
        done = False
        ep_actions = []

        while not done:
            # 1. Ask ProNav for the move
            action = env.calculate_pronav()

            # 2. Track how hard it's pushing (0.0 to 1.0)
            ep_actions.append(np.linalg.norm(action))

            _, _, terminated, truncated, info = env.step(action)
            done = terminated or truncated

        outcomes.append(info['event'])
        min_distances.append(info['min_dist'])
        avg_load = np.mean(ep_actions)
        action_loads.append(avg_load)

        print(f"Ep {i+1:02d} | Res: {info['event']:<14} | Min Dist: {info['min_dist']:.1f} m | Avg G-Load: {avg_load*100:.1f}%")

    # Final stats
    hits = outcomes.count("hit")
    miss_distances = [d for d, e in zip(min_distances, outcomes) if e != 'hit']
    print("\n--- SUMMARY ---")
    print(f"Hit Rate: {hits}/{n_episodes} ({hits/n_episodes*100:.1f}%)")
    if miss_distances:
        print(f"Average Miss Distance (Non-hits): {np.mean(miss_distances):.2f} m")
    else:
        # np.mean([]) would print "nan" and emit a RuntimeWarning.
        print("Average Miss Distance (Non-hits): n/a (no misses)")
    print(f"Average G-Loading: {np.mean(action_loads)*100:.1f}% (If >90%, missile is physically too weak)")

# Script entry point. Inside a notebook cell __name__ is also "__main__",
# so executing this cell runs the baseline immediately.
if __name__ == "__main__":
    run_baseline()

Running 50 episodes of Augmented ProNav...


NameError: name 'accel_gravity' is not defined

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from collections import Counter
from IPython.display import HTML, display

# ==========================================
# FIXED-SEED EVALUATION HARNESS
# ==========================================
EVAL_SEEDS = [*range(200)]  # Seeds 0..199, shared by every evaluation so results are comparable

def evaluate_policy(env_class, policy_fn, seeds=None, verbose=False):
    """
    Evaluate a policy on fixed seeds for consistent comparison.

    A fresh environment is built for every seed and rolled out until it
    reports ``terminated`` or ``truncated``.

    Args:
        env_class: The environment class (e.g., missile_interception_3d)
        policy_fn: Function that takes (env, obs) and returns action
        seeds: Iterable of seeds to evaluate on. ``None`` (default) means the
            module-level ``EVAL_SEEDS``, looked up at call time.
        verbose: If True, print the first 5 episodes and every non-hit episode

    Returns:
        hit_rate: Fraction of episodes that resulted in "hit"
        med_min_dist: Median minimum distance across all episodes
        counts: Counter of event types

    Raises:
        ValueError: If ``seeds`` is empty.
    """
    if seeds is None:
        seeds = EVAL_SEEDS
    seeds = list(seeds)
    if not seeds:
        raise ValueError("evaluate_policy needs at least one seed")

    events = []
    min_dists = []

    for i, s in enumerate(seeds):
        env = env_class()
        obs, _ = env.reset(seed=s)
        done = False
        info = {}

        while not done:
            action = policy_fn(env, obs)
            obs, _, terminated, truncated, info = env.step(action)
            done = terminated or truncated

        event = info.get("event", "unknown")
        min_dist = info.get("min_dist", float("inf"))

        events.append(event)
        min_dists.append(min_dist)

        # Print by episode index (not seed value), so "first 5" also holds
        # for seed lists that do not start at 0.
        if verbose and (i < 5 or event != "hit"):
            print(f"  Seed {s:3d}: {event:15s} | min_dist={min_dist:.1f}m")

    counts = Counter(events)
    hit_rate = counts.get("hit", 0) / len(events)
    med_min_dist = float(np.median(min_dists))

    return hit_rate, med_min_dist, counts

# ==========================================
# ANIMATION FUNCTIONS (copy from training script)
# ==========================================
def update_paths(num, xe, ye, ze, xd, yd, zd, lines, ax):
    """FuncAnimation callback: draw both trajectories up to frame ``num``.

    Frame ``num`` shows samples ``0..num`` inclusive, so the last frame also
    draws the final sample (the end of the engagement). The camera azimuth
    advances by 0.2 per frame.

    Returns:
        ``lines``, the artists updated for this frame.
    """
    end = num + 1  # inclusive of the current sample
    lines[0].set_data_3d(xe[:end], ye[:end], ze[:end])
    lines[1].set_data_3d(xd[:end], yd[:end], zd[:end])
    ax.view_init(elev=20, azim=-60 + (num * 0.2))
    return lines

def animate_trajectories(enemy_path, defense_path):
    """Render an episode's enemy and defense trajectories as an HTML5 video.

    Args:
        enemy_path: Sequence of enemy positions, each indexable as (x, y, z).
        defense_path: Sequence of defense positions, each indexable as (x, y, z).

    Returns:
        An IPython ``HTML`` object wrapping the encoded video. Encoding needs a
        matplotlib movie writer (ffmpeg by default).

    Long paths are down-sampled to roughly 200 frames. The final sample of
    each path is always kept, so the end of the engagement is shown.
    """
    enemy_path_array = np.array(enemy_path)
    defense_path_array = np.array(defense_path)

    xe, ye, ze = enemy_path_array[:, 0], enemy_path_array[:, 1], enemy_path_array[:, 2]
    xd, yd, zd = defense_path_array[:, 0], defense_path_array[:, 1], defense_path_array[:, 2]

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    all_x = np.concatenate((xe, xd))
    all_y = np.concatenate((ye, yd))
    all_z = np.concatenate((ze, zd))

    # Small padding keeps the limits valid when a coordinate range is zero.
    eps = 1e-6
    ax.set_xlim([np.min(all_x) - eps, np.max(all_x) + eps])
    ax.set_ylim([np.min(all_y) - eps, np.max(all_y) + eps])
    ax.set_zlim([0, max(eps, np.max(all_z) + eps)])

    line_enemy, = ax.plot([], [], [], 'b-', linewidth=2, label="Enemy (Ballistic)")
    line_defense, = ax.plot([], [], [], 'r-', linewidth=2, label="Defense (Interceptor)")

    # Launch points
    ax.scatter(xe[0], ye[0], ze[0], color='blue', s=50, marker='o')
    ax.scatter(xd[0], yd[0], zd[0], color='red', s=50, marker='^')

    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')
    ax.set_zlabel('Altitude (m)')
    ax.set_title('Interception Simulation')
    ax.legend()

    total_steps = len(xe)
    step_size = max(1, total_steps // 200)

    def _sample_indices(n):
        """Every ``step_size``-th index in [0, n), plus the final index n - 1."""
        idx = np.arange(0, n, step_size)
        if idx.size and idx[-1] != n - 1:
            idx = np.append(idx, n - 1)
        return idx

    # Each path is sampled with its own length, so unequal lengths stay in range.
    e_idx = _sample_indices(len(xe))
    d_idx = _sample_indices(len(xd))

    ani = animation.FuncAnimation(
        fig,
        update_paths,
        frames=len(e_idx),
        fargs=(xe[e_idx], ye[e_idx], ze[e_idx],
               xd[d_idx], yd[d_idx], zd[d_idx],
               [line_enemy, line_defense], ax),
        interval=30,
        blit=False
    )

    # Close so the notebook does not also render a static copy of the figure.
    plt.close(fig)
    return HTML(ani.to_html5_video())

# ==========================================
# Test ProNav Baseline (NO TRAINING)
# ==========================================
# Per-episode sweep settings: episode k is reset with seed SEED + k,
# for k in range(NUM_EVAL_EPISODES).
SEED, NUM_EVAL_EPISODES = 0, 100

def _run_pronav_episode(seed, max_steps=10_000):
    """Roll out one ProNav-guided episode; return the finished env and its stats dict.

    The rollout ends when the env terminates/truncates, or after ``max_steps``
    control steps; in the latter case the event is forced to ``"timeout"``.
    """
    env = missile_interception_3d()
    env.reset(seed=seed)

    done = False
    ep_reward = 0.0
    steps = 0
    action_mags = []  # per-step |action|, used to diagnose saturation
    info = {}

    while not done:
        # Use ProNav instead of a trained model
        action = env.calculate_pronav()
        action_mags.append(float(np.linalg.norm(action)))

        _, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        ep_reward += reward
        steps += 1

        if steps > max_steps:
            info["event"] = "timeout"
            break

    final_dist = float(np.linalg.norm(env.enemy_pos - env.defense_pos))
    stats = {
        "event": info.get("event", "unknown"),
        "reward": ep_reward,
        "steps": steps,
        "final_dist": final_dist,
        "min_dist": getattr(env, "min_dist", final_dist),
        "avg_action_mag": float(np.mean(action_mags)) if action_mags else 0.0,
    }
    return env, stats


def _print_pronav_episode(ep, env, stats):
    """Print the per-episode result line plus failure and shaping diagnostics."""
    print(
        f"[ProNav EP {ep+1}/{NUM_EVAL_EPISODES}] "
        f"reward={stats['reward']:.2f} "
        f"len={stats['steps']} "
        f"event={stats['event']} "
        f"final_dist={stats['final_dist']:.1f}m "
        f"min_dist={stats['min_dist']:.1f}m"
    )

    # Physical diagnostics for failures
    if stats["event"] in ["defense_ground", "diverged", "timeout"]:
        print(
            f"  FAIL: def_z={env.defense_pos[2]:.1f}m "
            f"def_vz={env.defense_vel[2]:.1f}m/s "
            f"avg_action_mag={stats['avg_action_mag']:.3f} "
            f"time={env.t:.1f}s"
        )

    # Shaping components accumulated by the env
    if hasattr(env, 'sum_r_progress'):
        print(
            f"  shaping: r_progress={env.sum_r_progress:.2f} "
            f"r_close={env.sum_r_close:.2f}"
        )


def _print_pronav_summary(c, rewards, ep_lens, min_dists, final_dists):
    """Print aggregate statistics of the seeded ProNav sweep."""
    hit_rate = c.get("hit", 0) / NUM_EVAL_EPISODES

    print("\n" + "="*60)
    print(f"PRONAV BASELINE RESULTS ({NUM_EVAL_EPISODES} episodes)")
    print("="*60)
    print(f"Hit Rate: {hit_rate:.2%} ({c.get('hit', 0)}/{NUM_EVAL_EPISODES})")
    print(f"Avg Reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
    print(f"Avg Episode Length: {np.mean(ep_lens):.1f} ± {np.std(ep_lens):.1f}")
    print(f"Avg Min Distance: {np.mean(min_dists):.1f}m (median: {np.median(min_dists):.1f}m)")
    print(f"Avg Final Distance: {np.mean(final_dists):.1f}m")
    print("\nEvent Distribution:")
    for ev, count in c.most_common():
        print(f"  {ev}: {count} ({count/NUM_EVAL_EPISODES:.1%})")


def _print_failure_analysis(c, crash_altitudes, enemy_ground_times, t_max):
    """Print a breakdown of the non-hit termination modes."""
    print("\n" + "="*60)
    print("FAILURE MODE ANALYSIS")
    print("="*60)

    if crash_altitudes:
        print(f"Defense Ground Crashes: {len(crash_altitudes)}")
        print(f"  Avg crash altitude: {np.mean(crash_altitudes):.1f}m (should be ~0)")

    if enemy_ground_times:
        print(f"Enemy Ground Impacts: {len(enemy_ground_times)}")
        print(f"  Avg time to enemy ground: {np.mean(enemy_ground_times):.1f}s")

    diverged_count = c.get("diverged", 0)
    if diverged_count > 0:
        print(f"Diverged Episodes: {diverged_count}")
        print("  (Defense missile flew away from target)")

    timeout_count = c.get("timeout", 0)
    if timeout_count > 0:
        print(f"Timeout Episodes: {timeout_count}")
        print(f"  (Exceeded {t_max}s time limit)")

    print("="*60)


def _print_fixed_seed_evaluation():
    """Evaluate ProNav on ``EVAL_SEEDS`` via ``evaluate_policy`` and print the results."""
    num_seeds = len(EVAL_SEEDS)

    print("\n" + "="*60)
    print(f"FIXED-SEED EVALUATION ({num_seeds} seeds)")
    print("="*60)

    def pronav_policy(env, obs):
        return env.calculate_pronav()

    hit_rate_fixed, med_min_dist_fixed, counts_fixed = evaluate_policy(
        missile_interception_3d,
        pronav_policy,
        seeds=EVAL_SEEDS,
        verbose=True
    )

    num_hits = counts_fixed.get('hit', 0)
    print("\nFixed-Seed Results:")
    print(f"  Hit Rate: {hit_rate_fixed:.2%} ({num_hits}/{num_seeds})")
    print(f"  Median Min Distance: {med_min_dist_fixed:.1f}m")
    print("  Event Distribution:")
    for ev, count in counts_fixed.most_common():
        pct = count / num_seeds
        print(f"    {ev}: {count} ({pct:.1%})")
    print("="*60)


def test_pronav_baseline():
    """Benchmark the ProNav baseline (no trained model) and print diagnostics.

    1. Runs NUM_EVAL_EPISODES episodes seeded SEED, SEED+1, ..., printing
       per-episode results and animating the first 5.
    2. Prints summary statistics and a failure-mode breakdown.
    3. Re-evaluates ProNav on the fixed EVAL_SEEDS with ``evaluate_policy``.
    """
    events = []
    rewards = []
    ep_lens = []
    min_dists = []
    final_dists = []

    # Detailed failure-mode tracking
    crash_altitudes = []     # defense altitude at "defense_ground" termination
    enemy_ground_times = []  # sim time at "enemy_ground" termination
    t_max = None             # env time limit, reported in the timeout analysis

    for ep in range(NUM_EVAL_EPISODES):
        env, stats = _run_pronav_episode(SEED + ep)
        event = stats["event"]
        t_max = env.t_max

        events.append(event)
        rewards.append(stats["reward"])
        ep_lens.append(stats["steps"])
        min_dists.append(stats["min_dist"])
        final_dists.append(stats["final_dist"])

        if event == "defense_ground":
            crash_altitudes.append(float(env.defense_pos[2]))
        if event == "enemy_ground":
            enemy_ground_times.append(float(env.t))

        _print_pronav_episode(ep, env, stats)

        # Animate the first 5 episodes
        if ep < 5:
            print(f"  [Animating episode {ep+1}...]")
            display(animate_trajectories(env.enemy_path, env.defense_path))

    c = Counter(events)
    _print_pronav_summary(c, rewards, ep_lens, min_dists, final_dists)
    _print_failure_analysis(c, crash_altitudes, enemy_ground_times, t_max)
    _print_fixed_seed_evaluation()

# Run the ProNav baseline sweep (per-episode diagnostics, animations of the
# first episodes, summary, failure analysis, fixed-seed evaluation) when this
# cell executes.
test_pronav_baseline()

[ProNav EP 1/100] reward=11448.97 len=398 event=hit final_dist=121.3m min_dist=121.3m
  shaping: r_progress=0.00 r_close=0.00
  [Animating episode 1...]


[ProNav EP 2/100] reward=10670.13 len=189 event=hit final_dist=122.0m min_dist=122.0m
  shaping: r_progress=0.00 r_close=0.00
  [Animating episode 2...]


[ProNav EP 3/100] reward=11126.99 len=319 event=hit final_dist=128.8m min_dist=128.8m
  shaping: r_progress=0.00 r_close=0.00
  [Animating episode 3...]


[ProNav EP 4/100] reward=11704.38 len=493 event=hit final_dist=140.1m min_dist=140.1m
  shaping: r_progress=0.00 r_close=0.00
  [Animating episode 4...]


[ProNav EP 5/100] reward=10715.12 len=212 event=hit final_dist=125.3m min_dist=125.3m
  shaping: r_progress=0.00 r_close=0.00
  [Animating episode 5...]


[ProNav EP 6/100] reward=11530.64 len=426 event=hit final_dist=137.7m min_dist=137.7m
  shaping: r_progress=0.00 r_close=0.00
[ProNav EP 7/100] reward=11408.27 len=371 event=hit final_dist=137.5m min_dist=137.5m
  shaping: r_progress=0.00 r_close=0.00
[ProNav EP 8/100] reward=11230.54 len=367 event=hit final_dist=138.5m min_dist=138.5m
  shaping: r_progress=0.00 r_close=0.00
[ProNav EP 9/100] reward=10550.27 len=158 event=hit final_dist=133.4m min_dist=133.4m
  shaping: r_progress=0.00 r_close=0.00
[ProNav EP 10/100] reward=11025.79 len=284 event=hit final_dist=145.5m min_dist=145.5m
  shaping: r_progress=0.00 r_close=0.00
[ProNav EP 11/100] reward=11000.35 len=288 event=hit final_dist=148.3m min_dist=148.3m
  shaping: r_progress=0.00 r_close=0.00
[ProNav EP 12/100] reward=11288.10 len=358 event=hit final_dist=120.1m min_dist=120.1m
  shaping: r_progress=0.00 r_close=0.00
[ProNav EP 13/100] reward=11950.13 len=487 event=hit final_dist=118.7m min_dist=118.7m
  shaping: r_progress=0.00 r