In [16]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

class missile_interception_3d(gym.Env):
    def __init__(self):
        # 1. Define Action Space (The Joystick: Left/Right, Up/Down)
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2, ), dtype=np.float32)
        
        # 2. Define Observation Space (20D ego-frame version)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(20,), dtype=np.float32)

        self.np_random = np.random.RandomState()
        
        # 3. Time Settings
        self.dt_act = 0.1             
        self.n_substeps = 10          
        self.dt_sim = self.dt_act / self.n_substeps 
        self.t_max = 650.0            

        # 4. Physical Limits
        self.a_max = 350.0   # Max G-force (m/s^2) ~35G
        self.da_max = 2500.0 # Jerk Limit (m/s^3)
        self.tau = 0.05      # Airframe Lag
        self.g = 9.81        
        self.collision_radius = 150.0  
        self.max_distance = 4_000_000.0 

        self.p_easy = 1.0                   
        self.range_min = 70_000.0           
        self.range_easy_max = 200_000.0     
        self.range_hard_max = 1_000_000.0   

        self.targetbox_x_min = -15000
        self.targetbox_x_max = 15000
        self.targetbox_y_min = -15000
        self.targetbox_y_max = 15000

        # ----------------------------
        # Potential-based shaping (ZEM_perp)
        # IMPORTANT: gamma_shape MUST match PPO's gamma
        # ----------------------------
        self.gamma_shape = 0.99      # set to your PPO gamma
        self.w_zem = 1.0             # shaping weight
        self.zem_scale = 50_000.0    # meters; tunes magnitude of phi
        self.tgo_max = 15.0          # seconds; clamp lookahead when closing
        self.tgo_fixed = 3.0         # seconds; lookahead when NOT closing
        self.vc_min = 1.0            # m/s; treat <= this as "not closing"

    def generate_enemy_missile(self):
        if self.np_random.rand() < self.p_easy:
            self.range_max_used = self.range_easy_max
        else:
            self.range_max_used = self.range_hard_max

        range_min = self.range_min
        self.attack_target_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.attack_target_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)
        self.enemy_launch_angle = self.np_random.uniform(0, 2 * np.pi)
        self.enemy_theta = self.np_random.uniform(0.523599, 1.0472) 

        self.range_max_used = max(self.range_max_used, range_min + 1.0)
        lower_limit = np.sqrt((range_min * self.g) / np.sin(2 * self.enemy_theta))
        upper_limit = np.sqrt((self.range_max_used * self.g) / np.sin(2 * self.enemy_theta))
        self.enemy_initial_velocity = self.np_random.uniform(lower_limit, upper_limit)

        ground_range = (
            self.enemy_initial_velocity * np.cos(self.enemy_theta)
            * (2 * self.enemy_initial_velocity * np.sin(self.enemy_theta) / self.g)
        )

        self.enemy_launch_x = self.attack_target_x + ground_range * np.cos(self.enemy_launch_angle)
        self.enemy_launch_y = self.attack_target_y + ground_range * np.sin(self.enemy_launch_angle)
        self.enemy_z = 0
        self.enemy_x = self.enemy_launch_x
        self.enemy_y = self.enemy_launch_y
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        self.enemy_azimuth = (self.enemy_launch_angle + np.pi) % (2 * np.pi)

    def generate_defense_missile(self):
        self.defense_launch_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.defense_launch_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)

        dx = self.enemy_launch_x - self.defense_launch_x
        dy = self.enemy_launch_y - self.defense_launch_y
        az_nominal = np.arctan2(dy, dx)

        # --- Misalignment (domain randomization of initial heading) ---
        # Mixture: most episodes small error, some episodes large az error
        p_misaligned = 0.35  # 35% "hard" starts
        if self.np_random.rand() < p_misaligned:
            # Hard: big azimuth error → strong RIGHT required
            az_noise = self.np_random.uniform(-np.deg2rad(60.0), np.deg2rad(60.0))
        else:
            # Easy: small azimuth error → gentle correction
            az_noise = self.np_random.uniform(-np.deg2rad(10.0), np.deg2rad(10.0))

        self.defense_azimuth = az_nominal + az_noise

        # Elevation noise: avoid always same vertical plane
        theta_nominal = 0.785398  # ~45 deg
        theta_noise_deg = 10.0
        theta_noise = self.np_random.uniform(
            -np.deg2rad(theta_noise_deg),
            +np.deg2rad(theta_noise_deg),
        )
        self.defense_theta = float(np.clip(theta_nominal + theta_noise,
                                           np.deg2rad(10.0),
                                           np.deg2rad(80.0)))
        # ---------------------------------------------------------------

        base_velocity = 3000.0
        if hasattr(self, 'range_max_used'):
            velocity_scale = min(self.range_max_used / self.range_easy_max, 1.5)
            self.defense_initial_velocity = base_velocity * velocity_scale
        else:
            self.defense_initial_velocity = base_velocity

        self.defense_x = self.defense_launch_x
        self.defense_y = self.defense_launch_y
        self.defense_z = 0.0
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)

        self.defense_ax = 0.0
        self.defense_ay = 0.0
        self.defense_az = 0.0
    
    def _smoothstep(self, x: float) -> float:
        """Smooth ramp 0->1 with zero slope at ends, clamps outside [0,1]"""
        x = float(np.clip(x, 0.0, 1.0))
        return x * x * (3.0 - 2.0 * x)
    
    def calculate_pronav(self):
        eps = 1e-9

        # Relative geometry (use float64 for stability)
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        v = self.defense_vel.astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)

        R = float(np.linalg.norm(r)) + eps
        V = float(np.linalg.norm(v)) + eps

        rhat = r / R
        vhat = v / V

        # Heading error alpha = angle between velocity direction and LOS direction
        cosang = float(np.clip(np.dot(vhat, rhat), -1.0, 1.0))
        alpha = float(np.arccos(cosang))  # radians

        # LOS angular rate omega (world frame)
        omega = np.cross(r, vrel) / (float(np.dot(r, r)) + eps)
        omega_mag = float(np.linalg.norm(omega))

        # Closing speed (positive => closing)
        vc = -float(np.dot(r, vrel)) / R

        # --- PN term ---
        N = 3.0
        a_pn = N * vc * np.cross(omega, rhat)  # lateral accel in world frame

        # --- Acquisition term (turn-to-LOS) ---
        # Perpendicular component of LOS relative to forward direction
        rhat_perp = rhat - float(np.dot(rhat, vhat)) * vhat
        nperp = float(np.linalg.norm(rhat_perp))

        if nperp < 1e-8:
            a_acq = np.zeros(3, dtype=np.float64)
        else:
            rhat_perp /= nperp  # unit sideways "turn toward LOS" direction

            # Curvature-based magnitude: ~k * V^2 / R, saturate later via a_max
            k_acq = 5.0  # try 3.0–8.0
            a_acq = k_acq * (V * V / R) * rhat_perp

        # --- Blend weight w: 0 => pure PN, 1 => pure acquisition ---

        # Alpha-based weight (dominant)
        alpha_on   = np.deg2rad(20.0)   # start blending earlier
        alpha_full = np.deg2rad(55.0)

        x_alpha = (alpha - alpha_on) / (alpha_full - alpha_on + eps)
        w_alpha = self._smoothstep(x_alpha)

        # Omega-based modifier (only boosts acquisition when PN is sleepy)
        omega_full = 0.00
        omega_on   = 0.05   # <-- key: less brittle than 0.02

        x_omega = (omega_on - omega_mag) / (omega_on - omega_full + eps)
        w_omega = self._smoothstep(x_omega)

        # Robust combine: alpha dominates; omega can't fully shut it off
        w = w_alpha * (0.25 + 0.75 * w_omega)

        # Optional: if not closing, force strong acquisition
        if vc <= 0.0:
            w = max(w, 0.9)

        a_ideal = (1.0 - w) * a_pn + w * a_acq

        # Project into your lateral control basis (right/up) and normalize by a_max
        # Note: Environment now handles gravity compensation internally,
        # so ProNav outputs desired NET lateral accel (same semantics as PPO)
        forward, right, up = self._compute_lateral_basis(self.defense_vel)
        a_right = float(np.dot(a_ideal, right))
        a_up    = float(np.dot(a_ideal, up))

        action = np.array([a_right / self.a_max, a_up / self.a_max], dtype=np.float32)
        return np.clip(action, -1.0, 1.0)
    
    def _rate_limit_norm(self, a_cmd, a_prev, da_max, dt):
        delta = a_cmd - a_prev
        max_delta = da_max * dt
        dnorm = float(np.linalg.norm(delta))
        if dnorm <= max_delta or dnorm < 1e-9:
            return a_cmd
        return a_prev + delta * (max_delta / dnorm)
    
    def _segment_sphere_intersect(self, r0, r1, r_hit):
        dr = r1 - r0
        dr_norm_sq = float(np.dot(dr, dr))
        if dr_norm_sq < 1e-12:
            return float(np.dot(r0, r0)) <= r_hit * r_hit
        s_star = -float(np.dot(r0, dr)) / dr_norm_sq
        s_star = max(0.0, min(1.0, s_star))
        r_closest = r0 + s_star * dr
        return float(np.dot(r_closest, r_closest)) <= r_hit * r_hit
    
    def _phi_zem_perp(self):
        """
        Potential for potential-based shaping:
          Phi(s) = - ||ZEM_perp|| / zem_scale

        Returns:
          phi, zem_perp_norm, Vc, tgo
        """
        eps = 1e-9

        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)

        R = float(np.linalg.norm(r)) + eps
        rhat = r / R

        # Closing speed (positive = closing)
        Vc = -float(np.dot(rhat, vrel))

        if Vc > self.vc_min:
            tgo = R / max(Vc, eps)
            tgo = float(np.clip(tgo, 0.0, self.tgo_max))
        else:
            tgo = float(self.tgo_fixed)

        zem = r + vrel * tgo
        zem_perp = zem - float(np.dot(zem, rhat)) * rhat
        zem_perp_norm = float(np.linalg.norm(zem_perp))

        phi = -zem_perp_norm / (float(self.zem_scale) + eps)
        return float(phi), zem_perp_norm, float(Vc), tgo
    
    
    def _get_obs(self):
        eps = 1e-9

        # World-frame relative state
        r_world = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel_world = (self.enemy_vel - self.defense_vel).astype(np.float64)

        # Local basis from defense velocity (world frame unit vectors)
        forward, right, up = self._compute_lateral_basis(self.defense_vel)

        # ===============================
        # 1) Ego-frame (body-frame) r and vrel
        # ===============================
        r_body = np.array([
            float(np.dot(r_world, forward)),
            float(np.dot(r_world, right)),
            float(np.dot(r_world, up)),
        ], dtype=np.float64)

        vrel_body = np.array([
            float(np.dot(vrel_world, forward)),
            float(np.dot(vrel_world, right)),
            float(np.dot(vrel_world, up)),
        ], dtype=np.float64)

        # Normalize r_body / vrel_body (keep your original scaling)
        pos_scale = float(self.range_hard_max)   # 1_000_000
        vel_scale = 4000.0

        r_body_n = (r_body / (pos_scale + eps)).astype(np.float32)
        vrel_body_n = (vrel_body / (vel_scale + eps)).astype(np.float32)

        # ===============================
        # 2) Actuator state in the same action frame
        # ===============================
        a_lat = np.array([
            float(np.dot(self.a_actual, right)) / (self.a_max + eps),
            float(np.dot(self.a_actual, up)) / (self.a_max + eps),
        ], dtype=np.float32)

        # NEW: hidden actuator state that affects transitions
        a_cmd_prev_lat = np.array([
            float(np.dot(self.a_cmd_prev, right)) / (self.a_max + eps),
            float(np.dot(self.a_cmd_prev, up)) / (self.a_max + eps),
        ], dtype=np.float32)

        # ===============================
        # 3) Scalar helpers (kept)
        # ===============================
        dist = float(np.linalg.norm(r_world)) + 1e-6
        v_close = -float(np.dot(r_world, vrel_world)) / dist  # positive when closing

        dist_n = np.float32(np.clip(dist / 1_000_000.0, 0.0, 4.0))
        vclose_n = np.float32(np.clip(v_close / 3000.0, -2.0, 2.0))
        dist_vclose_feat = np.array([dist_n, vclose_n], dtype=np.float32)

        # Defense own vertical state (keep for ground constraint)
        def_z_n = np.float32(np.clip(self.defense_pos[2] / 100_000.0, -1.0, 2.0))
        def_vz_n = np.float32(np.clip(self.defense_vel[2] / 3000.0, -2.0, 2.0))
        def_state_feat = np.array([def_z_n, def_vz_n], dtype=np.float32)

        # ===============================
        # 4) Keep your geometry features (consistent with ego-frame)
        # ===============================
        dist_body = float(np.linalg.norm(r_body)) + 1e-6

        # LOS lateral projections in body frame
        los_right = float(r_body[1] / dist_body)
        los_up    = float(r_body[2] / dist_body)

        # LOS rate omega in body frame: omega = (r x vrel)/||r||^2
        dist2_body = float(np.dot(r_body, r_body)) + eps
        omega_body = np.cross(r_body, vrel_body) / dist2_body

        omega_right = float(omega_body[1])
        omega_up    = float(omega_body[2])

        omega_scale = 2.0
        omega_right_n = float(np.clip(omega_right / omega_scale, -2.0, 2.0))
        omega_up_n    = float(np.clip(omega_up / omega_scale, -2.0, 2.0))

        geom_feat = np.array([los_right, los_up, omega_right_n, omega_up_n], dtype=np.float32)

        # ===============================
        # 5) NEW: kinematics garnish
        # ===============================
        V_def = float(np.linalg.norm(self.defense_vel))
        V_def_n = np.float32(np.clip(V_def / 3000.0, 0.0, 3.0))  # scale: 3000 m/s baseline
        forward_z = np.float32(float(forward[2]))               # dot(forward, world_up) since world_up=[0,0,1]

        kin_feat = np.array([V_def_n, forward_z], dtype=np.float32)

        # Final obs (20D)
        obs = np.concatenate(
            [r_body_n, vrel_body_n, a_lat, a_cmd_prev_lat, dist_vclose_feat, def_state_feat, geom_feat, kin_feat],
            axis=0
        ).astype(np.float32)

        # Optional sanity check while iterating
        # assert obs.shape == (20,), obs.shape

        return obs

    def _compute_lateral_basis(self, velocity):
        """
        Horizon-stable basis:
          forward = along velocity
          right   = world_up x forward  (horizontal right)
          up      = forward x right     (completes orthonormal frame)
        This keeps 'up' as close to world-up as possible and avoids weird twisting.
        """
        speed = float(np.linalg.norm(velocity))
        if speed < 1.0:
            forward = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        else:
            forward = (velocity / speed).astype(np.float32)

        world_up = np.array([0.0, 0.0, 1.0], dtype=np.float32)

        # right = world_up x forward
        right_raw = np.cross(world_up, forward)
        rnorm = float(np.linalg.norm(right_raw))

        # If forward is near world_up, right_raw ~ 0. Pick a consistent fallback.
        if rnorm < 1e-6:
            # Choose a fixed "north" axis in world XY and build right from that
            # This prevents random spinning when vertical.
            north = np.array([1.0, 0.0, 0.0], dtype=np.float32)
            right_raw = np.cross(north, forward)
            rnorm = float(np.linalg.norm(right_raw))
            if rnorm < 1e-6:
                north = np.array([0.0, 1.0, 0.0], dtype=np.float32)
                right_raw = np.cross(north, forward)
                rnorm = float(np.linalg.norm(right_raw))

        right = (right_raw / (rnorm + 1e-9)).astype(np.float32)

        # up = forward x right (not right x forward)
        up_raw = np.cross(forward, right)
        up = (up_raw / (float(np.linalg.norm(up_raw)) + 1e-9)).astype(np.float32)

        return forward, right, up

    def step(self, action):
        if getattr(self, "done", False):
            return self._get_obs(), 0.0, True, False, {"event": "called_step_after_done", "dist": self.relative_distances[-1]}
        
        action = np.clip(action, -1.0, 1.0).astype(np.float32)
        mag = float(np.linalg.norm(action))
        if mag > 1.0:
            action = action / mag
            mag = 1.0
        
        # Update episode trackers
        self.ep_max_action_mag = max(self.ep_max_action_mag, float(mag))
        self.ep_max_accel = max(self.ep_max_accel, float(np.linalg.norm(self.a_actual)))

        dist_before = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        
        # --- Shaping: compute phi BEFORE transition ---
        phi_before, zem_perp_before, Vc_before, tgo_before = self._phi_zem_perp()
        terminated = False
        truncated = False
        event = "running"
        
        for _ in range(self.n_substeps):
            dt = self.dt_sim
            enemy_pos_old = self.enemy_pos.copy()
            defense_pos_old = self.defense_pos.copy()
            
            forward, right, up = self._compute_lateral_basis(self.defense_vel)
            
            # Agent command = desired NET lateral accel (world frame)
            a_net_lat_cmd = (action[0] * self.a_max * right) + (action[1] * self.a_max * up)
            
            # Gravity (world frame)
            g_vec = np.array([0.0, 0.0, -self.g], dtype=np.float32)
            
            # Lateral component of gravity in the right/up plane
            g_lat = (np.dot(g_vec, right) * right) + (np.dot(g_vec, up) * up)
            
            # Fins must cancel lateral gravity to achieve commanded NET lateral accel
            a_fins_cmd = a_net_lat_cmd - g_lat
            
            # Apply rate limit + lag to fins acceleration
            self.a_cmd_prev = self._rate_limit_norm(a_fins_cmd, self.a_cmd_prev, self.da_max, dt)
            self.a_actual += (self.a_cmd_prev - self.a_actual) * (dt / self.tau)
            
            # Integrate translational dynamics
            self.defense_vel += (self.a_actual + g_vec) * dt
            self.defense_pos += self.defense_vel * dt
            self.defense_x, self.defense_y, self.defense_z = self.defense_pos
            
            # Enemy missile: pure ballistic (gravity only)
            self.enemy_vel += g_vec * dt
            self.enemy_pos += self.enemy_vel * dt
            self.enemy_x, self.enemy_y, self.enemy_z = self.enemy_pos
            self.t += dt
            
            r0 = enemy_pos_old - defense_pos_old
            r1 = self.enemy_pos - self.defense_pos
            if self._segment_sphere_intersect(r0, r1, self.collision_radius):
                self.success = True
                terminated = True
                self.done = True
                event = "hit"
                self.time_to_hit = float(self.t)
                self.terminal_event = "hit"
                break
            
            dist = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
            if dist > self.max_distance:
                truncated = True
                self.done = True
                event = "diverged"
                self.terminal_event = "diverged"
                break
            if self.defense_pos[2] < 0:
                terminated = True
                self.done = True
                event = "defense_ground"
                self.terminal_event = "defense_ground"
                break
            if self.enemy_pos[2] < 0:
                terminated = True
                self.done = True
                event = "enemy_ground"
                self.terminal_event = "enemy_ground"
                break
            if self.t >= self.t_max:
                truncated = True
                self.done = True
                event = "timeout"
                self.terminal_event = "timeout"
                break

        self.enemy_path.append(self.enemy_pos.copy())
        self.defense_path.append(self.defense_pos.copy())
        self.relative_distances.append(float(np.linalg.norm(self.enemy_pos - self.defense_pos)))
        self.times.append(self.t)
        
        obs = self._get_obs()
        dist_after = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        self.min_dist = min(getattr(self, "min_dist", float("inf")), dist_after)
        self.ep_min_dist = min(self.ep_min_dist, float(dist_after))
        
        # Reward calculation (same as before)
        r_progress = (dist_before - dist_after) / 100.0
        v_scale = 1500.0
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)
        d = float(np.linalg.norm(r)) + 1e-9
        rhat = r / d
        d_dot = float(np.dot(rhat, vrel))
        r_close = np.tanh((-d_dot) / v_scale)
        
        # --- Shaping: compute phi AFTER transition ---
        if terminated or truncated:
            # standard trick for episodic shaping: set terminal potential to 0
            phi_after = 0.0
            zem_perp_after = None
            Vc_after = None
            tgo_after = None
        else:
            phi_after, zem_perp_after, Vc_after, tgo_after = self._phi_zem_perp()

        r_zem = self.w_zem * (self.gamma_shape * phi_after - phi_before)
        
        # Accumulate shaping rewards for debugging
        self.sum_r_progress += float(r_progress)
        self.sum_r_close += float(r_close)
        self.sum_r_zem += float(r_zem)
        
        # Reward breakdown (named components)
        step_penalty = -0.001
        terminal_bonus = 0.0
        terminal_penalty = 0.0
        
        if self.success:
            terminal_bonus = 10000.0
        elif terminated or truncated:
            if event == "defense_ground":
                terminal_penalty += 5000.0
            terminal_penalty += min(2000.0, self.ep_min_dist / 50.0)
        
        reward = (1.0 * r_progress) + (0.1 * r_close) + step_penalty + terminal_bonus - terminal_penalty + r_zem
        
        info = {
            "event": event,
            "t": float(self.t),
            
            # reward pieces (training-related)
            "reward_terms": {
                "r_progress": float(r_progress),
                "r_close": float(r_close),
                "step_penalty": float(step_penalty),
                "terminal_bonus": float(terminal_bonus),
                "terminal_penalty": float(-terminal_penalty),  # negative contribution
                "r_zem": float(r_zem),
            },
            "reward": float(reward),
            
            # ZEM shaping debug info
            "zem_debug": {
                "phi_before": float(phi_before),
                "phi_after": float(phi_after),
                "zem_perp_before": float(zem_perp_before),
                "zem_perp_after": None if zem_perp_after is None else float(zem_perp_after),
                "Vc_before": float(Vc_before),
                "Vc_after": None if Vc_after is None else float(Vc_after),
                "tgo_before": float(tgo_before),
                "tgo_after": None if tgo_after is None else float(tgo_after),
            },
            
            # eval snapshots (NOT the full episode metrics yet)
            "eval_step": {
                "dist": float(dist_after),
                "action_mag": float(mag),
                "accel_norm": float(np.linalg.norm(self.a_actual)),
            },
        }
        return obs, reward, terminated, truncated, info


    def _snapshot(self):
        """Copy everything that affects future transitions + trackers used by reward."""
        return {
            "t": float(self.t),
            "enemy_pos": self.enemy_pos.copy(),
            "enemy_vel": self.enemy_vel.copy(),
            "defense_pos": self.defense_pos.copy(),
            "defense_vel": self.defense_vel.copy(),
            "a_actual": self.a_actual.copy(),
            "a_cmd_prev": self.a_cmd_prev.copy(),
            "done": bool(getattr(self, "done", False)),
            "success": bool(getattr(self, "success", False)),

            # reward-critical trackers
            "min_dist": float(getattr(self, "min_dist", float("inf"))),
            "ep_min_dist": float(getattr(self, "ep_min_dist", float("inf"))),

            # episode trackers (not strictly needed for the local test but cheap)
            "ep_max_action_mag": float(getattr(self, "ep_max_action_mag", 0.0)),
            "ep_max_accel": float(getattr(self, "ep_max_accel", 0.0)),
            "time_to_hit": getattr(self, "time_to_hit", None),
            "terminal_event": getattr(self, "terminal_event", "running"),

            # log lengths (so restore can truncate)
            "enemy_path_len": len(getattr(self, "enemy_path", [])),
            "defense_path_len": len(getattr(self, "defense_path", [])),
            "relative_distances_len": len(getattr(self, "relative_distances", [])),
            "times_len": len(getattr(self, "times", [])),
        }

    def _restore(self, snap):
        """Restore environment state from snapshot."""
        self.t = float(snap["t"])
        self.enemy_pos = snap["enemy_pos"].copy()
        self.enemy_vel = snap["enemy_vel"].copy()
        self.defense_pos = snap["defense_pos"].copy()
        self.defense_vel = snap["defense_vel"].copy()
        self.a_actual = snap["a_actual"].copy()
        self.a_cmd_prev = snap["a_cmd_prev"].copy()
        self.done = bool(snap["done"])
        self.success = bool(snap["success"])

        self.min_dist = float(snap["min_dist"])
        self.ep_min_dist = float(snap["ep_min_dist"])

        self.ep_max_action_mag = float(snap["ep_max_action_mag"])
        self.ep_max_accel = float(snap["ep_max_accel"])
        self.time_to_hit = snap["time_to_hit"]
        self.terminal_event = snap["terminal_event"]

        # truncate logs (avoid memory blow + keep things consistent)
        if hasattr(self, "enemy_path"):
            self.enemy_path = self.enemy_path[: snap["enemy_path_len"]]
        if hasattr(self, "defense_path"):
            self.defense_path = self.defense_path[: snap["defense_path_len"]]
        if hasattr(self, "relative_distances"):
            self.relative_distances = self.relative_distances[: snap["relative_distances_len"]]
        if hasattr(self, "times"):
            self.times = self.times[: snap["times_len"]]


    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        if seed is not None: self.np_random = np.random.RandomState(seed)
        self.done = False
        self.success = False
        self.t = 0.0
        self.generate_enemy_missile()
        self.generate_defense_missile()
        
        self.defense_vel = np.array([
            self.defense_initial_velocity * np.cos(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_theta)
        ], dtype=np.float32)
        
        self.enemy_vel = np.array([
            self.enemy_initial_velocity * np.cos(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_theta)
        ], dtype=np.float32)
        
        self.a_actual = np.zeros(3, dtype=np.float32)
        self.a_cmd_prev = np.zeros(3, dtype=np.float32)
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        self.enemy_path = [self.enemy_pos.copy()]
        self.defense_path = [self.defense_pos.copy()]
        self.relative_distances = [float(np.linalg.norm(self.enemy_pos - self.defense_pos))]
        self.times = [self.t]
        self.min_dist = float(self.relative_distances[-1])
        self.sum_r_progress = 0.0
        self.sum_r_close = 0.0
        self.sum_r_zem = 0.0
        
        # --- episode eval trackers (NOT used in reward) ---
        self.ep_min_dist = float("inf")
        self.ep_max_action_mag = 0.0
        self.ep_max_accel = 0.0          # optional: actual accel norm
        self.time_to_hit = None
        self.terminal_event = "running"
        
        return self._get_obs(), {}


# ==========================================
# EVALUATION FUNCTION (separate from training reward)
# ==========================================
def evaluate_policy(env, policy_fn, n_episodes=100, seed0=0):
    """
    Evaluate a policy and return episode-level metrics.
    This is separate from training reward - these metrics are what you actually care about.
    
    Args:
        env: missile_interception_3d environment instance
        policy_fn: Function that takes obs and returns action
        n_episodes: Number of episodes to evaluate
        seed0: Starting seed (episodes use seed0, seed0+1, ..., seed0+n_episodes-1)
    
    Returns:
        summary: Dict with aggregated metrics (hit_rate, min_dist stats, etc.)
        metrics: Dict with raw episode data
    """
    metrics = {
        "hits": 0,
        "ground_defense": 0,
        "ground_enemy": 0,
        "diverged": 0,
        "timeout": 0,
        "min_dist_list": [],
        "time_to_hit_list": [],
        "max_g_list": [],
    }

    for i in range(n_episodes):
        obs, _ = env.reset(seed=seed0 + i)
        done = False

        while not done:
            action = policy_fn(obs)  # your PPO policy OR env.calculate_pronav() baseline
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

        event = info["event"]
        metrics["min_dist_list"].append(env.ep_min_dist)
        metrics["max_g_list"].append(env.ep_max_action_mag)

        if event == "hit":
            metrics["hits"] += 1
            if env.time_to_hit is not None:
                metrics["time_to_hit_list"].append(env.time_to_hit)
        elif event == "defense_ground":
            metrics["ground_defense"] += 1
        elif event == "enemy_ground":
            metrics["ground_enemy"] += 1
        elif event == "diverged":
            metrics["diverged"] += 1
        elif event == "timeout":
            metrics["timeout"] += 1

    hit_rate = metrics["hits"] / n_episodes

    summary = {
        "hit_rate": hit_rate,
        "min_dist_mean": float(np.mean(metrics["min_dist_list"])),
        "min_dist_p50": float(np.median(metrics["min_dist_list"])),
        "min_dist_p10": float(np.percentile(metrics["min_dist_list"], 10)),
        "time_to_hit_mean": float(np.mean(metrics["time_to_hit_list"])) if metrics["time_to_hit_list"] else None,
        "max_g_mean": float(np.mean(metrics["max_g_list"])),
        "violations": {
            "defense_ground": metrics["ground_defense"],
            "enemy_ground": metrics["ground_enemy"],
            "diverged": metrics["diverged"],
            "timeout": metrics["timeout"],
        }
    }
    return summary, metrics


# ==========================================
# TEST PRONAV BASELINE (No Animation)
# ==========================================

# def run_baseline():
#     env = missile_interception_3d()
#     outcomes = []
#     min_distances = []
#     action_loads = [] # Track if we are saturating (maxing out fins)

#     N_EPISODES = 50
#     print(f"Running {N_EPISODES} episodes of Augmented ProNav...")

#     for i in range(N_EPISODES):
#         obs, _ = env.reset(seed=i)
#         done = False
#         ep_actions = []

#         while not done:
#             # 1. Ask ProNav for the move
#             action = env.calculate_pronav()
            
#             # 2. Track how hard it's pushing (0.0 to 1.0)
#             mag = np.linalg.norm(action)
#             ep_actions.append(mag)

#             obs, reward, terminated, truncated, info = env.step(action)
#             done = terminated or truncated
        
#         outcomes.append(info['event'])
#         min_distances.append(info['min_dist'])
#         avg_load = np.mean(ep_actions)
#         action_loads.append(avg_load)

#         print(f"Ep {i+1:02d} | Res: {info['event']:<14} | Min Dist: {info['min_dist']:.1f} m | Avg G-Load: {avg_load*100:.1f}%")

#     # Final Stats
#     hits = outcomes.count("hit")
#     print("\n--- SUMMARY ---")
#     print(f"Hit Rate: {hits}/{N_EPISODES} ({hits/N_EPISODES*100:.1f}%)")
#     print(f"Average Miss Distance (Non-hits): {np.mean([d for d, e in zip(min_distances, outcomes) if e != 'hit']):.2f} m")
#     print(f"Average G-Loading: {np.mean(action_loads)*100:.1f}% (If >90%, missile is physically too weak)")

# if __name__ == "__main__":
#     run_baseline()

In [17]:
import numpy as np
from collections import defaultdict

# -------------------------
# Policies
# -------------------------

def policy_teacher(env, obs):
    return env.calculate_pronav()

def policy_do_nothing(env, obs):
    return np.zeros(2, dtype=np.float32)

def policy_random(env, obs):
    # uniform random in [-1,1], then renorm to unit ball like env does
    a = env.np_random.uniform(-1.0, 1.0, size=(2,)).astype(np.float32)
    mag = float(np.linalg.norm(a))
    if mag > 1.0:
        a = a / mag
    return a

def policy_anti_teacher(env, obs):
    return -env.calculate_pronav()

class StudentPolicy:
    """
    Plug in your current student here.
    Supported:
      - SB3 model with .predict(obs, deterministic=True)
      - PyTorch model returning action tensor
      - or you can replace __call__ with your own logic
    """
    def __init__(self, model=None, deterministic=True):
        self.model = model
        self.deterministic = deterministic

    def __call__(self, env, obs):
        if self.model is None:
            # fallback: do nothing if you haven't loaded a student yet
            return np.zeros(2, dtype=np.float32)

        # Stable-Baselines3 style
        if hasattr(self.model, "predict"):
            action, _ = self.model.predict(obs, deterministic=self.deterministic)
            return np.asarray(action, dtype=np.float32)

        # Torch style (very generic)
        try:
            import torch
            with torch.no_grad():
                x = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
                a = self.model(x).squeeze(0).cpu().numpy().astype(np.float32)
            return a
        except Exception as e:
            raise RuntimeError(f"Student model interface not recognized: {e}")


# -------------------------
# Rollout / logging
# -------------------------

def run_episode(env, policy_fn, seed=None, max_steps=20000):
    """
    Runs one episode, returns:
      - total_return
      - terminal event
      - eval metrics (min_dist, time_to_hit, max_g, violations)
      - reward component sums
    """
    obs, _ = env.reset(seed=seed)
    done = False
    steps = 0

    total_return = 0.0
    # sum reward components over the episode
    comp_sums = defaultdict(float)

    while not done and steps < max_steps:
        action = policy_fn(env, obs)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        steps += 1

        total_return += float(reward)

        # reward breakdown is in info["reward_terms"]
        rt = info.get("reward_terms", {})
        for k, v in rt.items():
            comp_sums[k] += float(v)

    event = info.get("event", "unknown")

    # Episode-level eval metrics come from env trackers
    ep_eval = {
        "hit": (event == "hit"),
        "min_dist": float(getattr(env, "ep_min_dist", np.nan)),
        "time_to_hit": float(env.time_to_hit) if getattr(env, "time_to_hit", None) is not None else None,
        "max_g": float(getattr(env, "ep_max_action_mag", np.nan)),
        "ground_defense": int(event == "defense_ground"),
        "ground_enemy": int(event == "enemy_ground"),
        "diverged": int(event == "diverged"),
        "timeout": int(event == "timeout"),
    }

    return total_return, event, ep_eval, dict(comp_sums)


def evaluate_policy_zoo(env, policies, n_episodes=50, seed0=0, print_every=10):
    """
    policies: dict name -> policy_fn(env, obs) -> action
    returns: dict name -> aggregated stats
    """
    results = {}

    for name, pol in policies.items():
        print(f"\n--- Running policy: {name} ({n_episodes} episodes) ---")

        totals = []
        hits = 0
        min_dists = []
        times_to_hit = []
        max_gs = []
        violations = defaultdict(int)
        comp_totals = defaultdict(float)

        for i in range(n_episodes):
            ep_seed = seed0 + i
            total_return, event, ep_eval, comp_sums = run_episode(env, pol, seed=ep_seed)

            totals.append(total_return)
            hits += int(ep_eval["hit"])
            min_dists.append(ep_eval["min_dist"])
            max_gs.append(ep_eval["max_g"])

            if ep_eval["time_to_hit"] is not None:
                times_to_hit.append(ep_eval["time_to_hit"])

            violations["defense_ground"] += ep_eval["ground_defense"]
            violations["enemy_ground"] += ep_eval["ground_enemy"]
            violations["diverged"] += ep_eval["diverged"]
            violations["timeout"] += ep_eval["timeout"]

            for k, v in comp_sums.items():
                comp_totals[k] += float(v)

            # ---- progress print ----
            if (i + 1) % print_every == 0 or (i + 1) == n_episodes:
                hit_rate_so_far = hits / (i + 1)
                print(
                    f"  ep {i+1:4d}/{n_episodes} | seed {ep_seed:4d} | "
                    f"event={event:<14} | "
                    f"ret={total_return:9.2f} | "
                    f"min_d={ep_eval['min_dist']:9.1f} | "
                    f"HR={hit_rate_so_far*100:5.1f}%"
                )

        results[name] = {
            "hit_rate": hits / n_episodes,
            "return_mean": float(np.mean(totals)),
            "return_std": float(np.std(totals)),
            "min_dist_mean": float(np.mean(min_dists)),
            "min_dist_p50": float(np.median(min_dists)),
            "time_to_hit_mean": float(np.mean(times_to_hit)) if times_to_hit else None,
            "max_g_mean": float(np.mean(max_gs)),
            "violations": dict(violations),
            "reward_components_sum": dict(comp_totals),
        }

        print(f"--- Done: {name} | hit_rate={results[name]['hit_rate']*100:.1f}% | return_mean={results[name]['return_mean']:.2f} ---")

    return results


# -------------------------
# Pretty print + sanity ordering checks
# -------------------------

def print_policy_zoo_report(results):
    # rank by return_mean
    names_sorted = sorted(results.keys(), key=lambda n: results[n]["return_mean"], reverse=True)

    print("\n=== POLICY ZOO REPORT (sorted by mean return) ===")
    for n in names_sorted:
        r = results[n]
        print(f"\n[{n}]")
        print(f"  return_mean ± std : {r['return_mean']:.3f} ± {r['return_std']:.3f}")
        print(f"  hit_rate          : {r['hit_rate']*100:.1f}%")
        print(f"  min_dist_mean     : {r['min_dist_mean']:.1f} m | p50: {r['min_dist_p50']:.1f} m")
        print(f"  time_to_hit_mean  : {r['time_to_hit_mean'] if r['time_to_hit_mean'] is not None else 'None'}")
        print(f"  max_g_mean        : {r['max_g_mean']:.3f}")
        print(f"  violations        : {r['violations']}")
        print(f"  reward_components_sum: {r['reward_components_sum']}")

    # sanity checks (soft assertions)
    def get(name, key):
        return results[name][key] if name in results else None

    teacher = "teacher"
    anti = "anti_teacher"
    dn = "do_nothing"
    rnd = "random"

    warnings = []
    if teacher in results and anti in results:
        if results[teacher]["return_mean"] <= results[anti]["return_mean"]:
            warnings.append("Teacher does NOT beat Anti-teacher on return_mean -> reward likely wrong.")
        if results[teacher]["hit_rate"] <= results[anti]["hit_rate"]:
            warnings.append("Teacher does NOT beat Anti-teacher on hit_rate -> guidance or reward suspicious.")

    if teacher in results and dn in results:
        if results[teacher]["return_mean"] <= results[dn]["return_mean"]:
            warnings.append("Teacher does NOT beat Do-nothing on return_mean -> reward shaping might be inverted/weak.")

    if teacher in results and rnd in results:
        if results[teacher]["return_mean"] <= results[rnd]["return_mean"]:
            warnings.append("Teacher does NOT beat Random on return_mean -> reward is likely broken.")

    # Also: bad policies getting high returns is often due to shaping terms dominating terminal outcome
    # Flag if anti-teacher has surprisingly high return but terrible hit_rate
    if anti in results:
        if results[anti]["hit_rate"] < 0.05 and results[anti]["return_mean"] > 0.0:
            warnings.append("Anti-teacher has near-zero hit_rate but positive return -> reward hacking opportunity.")

    if warnings:
        print("\n=== WARNINGS ===")
        for w in warnings:
            print(" - " + w)
    else:
        print("\nNo ordering red flags detected (still inspect component sums).")



# -------------------------
# Local action-reward coupling test (Step 2 v3)
# -------------------------

def normalize_action(a: np.ndarray) -> np.ndarray:
    a = np.asarray(a, dtype=np.float32)
    a = np.clip(a, -1.0, 1.0)
    mag = float(np.linalg.norm(a))
    if mag > 1.0:
        a = a / mag
    return a

def rand_unit_2(rng: np.random.RandomState) -> np.ndarray:
    v = rng.randn(2).astype(np.float32)
    n = float(np.linalg.norm(v))
    if n < 1e-9:
        return np.array([1.0, 0.0], dtype=np.float32)
    return v / n


# -------------------------
# Policies (feedback)
# policy signature: policy(env, obs, rng) -> action
# -------------------------

def pol_teacher(env, obs, rng=None):
    return env.calculate_pronav()

def pol_anti_teacher(env, obs, rng=None):
    return -env.calculate_pronav()

def pol_do_nothing(env, obs, rng=None):
    return np.zeros(2, dtype=np.float32)

def pol_random(env, obs, rng):
    a = rng.uniform(-1.0, 1.0, size=(2,)).astype(np.float32)
    return normalize_action(a)

def make_pol_teacher_noisy(eps: float):
    # noise injected each step: a_t = norm(a*_t + eps * u_t)
    def _pol(env, obs, rng):
        a_star = env.calculate_pronav()
        u = rand_unit_2(rng)
        return normalize_action(a_star + eps * u)
    return _pol


# -------------------------
# Rollout utilities
# -------------------------

def rollout_k_steps_from_snapshot(env, snap, policy, K: int, rng: np.random.RandomState):
    """
    Restores snap, then rolls K steps using policy feedback.
    Returns:
      total_return, comp_sums, event_last, dist_last
    """
    env._restore(snap)
    obs = env._get_obs()

    total_r = 0.0
    comps = defaultdict(float)
    info_last = {}

    for _ in range(K):
        a = policy(env, obs, rng)
        obs, r, terminated, truncated, info = env.step(a)
        total_r += float(r)

        rt = info.get("reward_terms", {})
        for k, v in rt.items():
            comps[k] += float(v)

        info_last = info
        if terminated or truncated:
            break

    event = info_last.get("event", "unknown")
    dist = info_last.get("eval_step", {}).get("dist", None)
    return total_r, dict(comps), event, dist


def get_snapshot_from_trajectory(env, seed: int, snap_step: int, traj_policy, rng: np.random.RandomState, max_steps=20000):
    """
    Runs traj_policy until snap_step, returns (ok, snap, obs_at_snap).
    traj_policy is used only to generate the trajectory you sample the snapshot from.
    """
    obs, _ = env.reset(seed=seed)

    for t in range(min(snap_step, max_steps)):
        a = traj_policy(env, obs, rng)
        obs, r, terminated, truncated, info = env.step(a)
        if terminated or truncated:
            return False, None, None, info.get("event", "terminated_early")

    snap = env._snapshot()
    return True, snap, obs, "ok"


def find_valid_snapshot(env, base_seed: int, snap_step: int, snapshot_source: str, max_tries=20):
    """
    snapshot_source: "teacher" or "random"
    Tries multiple seeds until the trajectory survives to snap_step.
    """
    for k in range(max_tries):
        seed = base_seed + k
        traj_rng = np.random.RandomState(10_000 + seed)

        if snapshot_source == "teacher":
            traj_policy = pol_teacher
        elif snapshot_source == "random":
            traj_policy = lambda e, o, rr: pol_random(e, o, rr)
        else:
            raise ValueError("snapshot_source must be 'teacher' or 'random'")

        ok, snap, obs, reason = get_snapshot_from_trajectory(
            env, seed=seed, snap_step=snap_step, traj_policy=traj_policy, rng=traj_rng
        )
        if ok:
            return True, seed, snap
    return False, None, None


# -------------------------
# Main local coupling test (v3)
# -------------------------

def local_action_reward_sanity_v3(
    env,
    base_seed=0,
    snap_step=50,
    K=20,
    snapshot_source="random",         # "teacher" or "random"
    eps_list=(0.05, 0.15, 0.30),
    n_noisy=64,
    n_random=64,
    print_top_bottom=8,
):
    """
    Compares short-horizon K-step returns starting from the exact same snapshot:
      - teacher feedback baseline
      - anti-teacher feedback
      - do-nothing
      - random feedback
      - teacher+noise feedback (per-step noise) for multiple eps magnitudes

    Reports advantage distribution vs teacher: A = R - R_teacher.
    """

    ok, seed_used, snap = find_valid_snapshot(env, base_seed, snap_step, snapshot_source=snapshot_source)
    if not ok:
        print(f"[FAIL] Could not find valid snapshot: base_seed={base_seed}, snap_step={snap_step}, source={snapshot_source}")
        return {"ok": False}

    # baseline teacher feedback
    rng_base = np.random.RandomState(12345)
    R_teacher, C_teacher, evT, distT = rollout_k_steps_from_snapshot(env, snap, pol_teacher, K, rng_base)

    # evaluate some fixed baselines
    baselines = {
        "teacher": pol_teacher,
        "anti_teacher": pol_anti_teacher,
        "do_nothing": pol_do_nothing,
    }

    records = []  # (name, R, A, comps, event)
    for name, pol in baselines.items():
        rng = np.random.RandomState(20000 + hash(name) % 10_000)
        R, comps, ev, dist = rollout_k_steps_from_snapshot(env, snap, pol, K, rng)
        records.append((name, R, R - R_teacher, comps, ev))

    # random policies
    for j in range(n_random):
        rng = np.random.RandomState(30000 + j)
        R, comps, ev, dist = rollout_k_steps_from_snapshot(env, snap, lambda e,o,rr: pol_random(e,o,rr), K, rng)
        records.append((f"random_{j}", R, R - R_teacher, comps, ev))

    # teacher + per-step noise at different eps
    for eps in eps_list:
        pol_noisy = make_pol_teacher_noisy(eps)
        for j in range(n_noisy):
            rng = np.random.RandomState(40000 + int(1000*eps) + j)
            R, comps, ev, dist = rollout_k_steps_from_snapshot(env, snap, pol_noisy, K, rng)
            records.append((f"teacher_noise_eps{eps:.2f}_{j}", R, R - R_teacher, comps, ev))

    # summarize by groups
    def group_of(name: str):
        if name == "teacher": return "teacher"
        if name == "anti_teacher": return "anti_teacher"
        if name == "do_nothing": return "do_nothing"
        if name.startswith("random_"): return "random"
        if name.startswith("teacher_noise_"): return "teacher+noise"
        return "other"

    groups = defaultdict(list)
    for name, R, A, comps, ev in records:
        groups[group_of(name)].append((name, R, A, comps, ev))

    def summarize_adv(items):
        adv = np.array([x[2] for x in items], dtype=np.float64)
        return {
            "n": int(len(adv)),
            "mean": float(np.mean(adv)),
            "p50": float(np.median(adv)),
            "p90": float(np.percentile(adv, 90)),
            "max": float(np.max(adv)),
            "min": float(np.min(adv)),
        }

    # print report
    print("\n" + "="*78)
    print("=== LOCAL ACTION-REWARD SANITY v3 (feedback + advantage dist) ===")
    print(f"snapshot_source={snapshot_source} | seed_used={seed_used} | snap_step={snap_step} | K={K}")
    print(f"R_teacher = {R_teacher:.6f} | event_last={evT}")
    print("-"*78)

    for gname in ["teacher", "teacher+noise", "random", "do_nothing", "anti_teacher"]:
        if gname not in groups: 
            continue
        s = summarize_adv(groups[gname])
        print(f"{gname:12s}  A=R-R_teacher: mean={s['mean']:+.6f}  p50={s['p50']:+.6f}  p90={s['p90']:+.6f}  max={s['max']:+.6f}  min={s['min']:+.6f}  n={s['n']}")

    # top/bottom by advantage (who beats teacher most / loses most)
    records_sorted = sorted(records, key=lambda x: x[2], reverse=True)

    print("\nTop candidates by advantage (beating teacher):")
    for name, R, A, comps, ev in records_sorted[:print_top_bottom]:
        # component deltas (against teacher)
        d_prog = comps.get("r_progress", 0.0) - C_teacher.get("r_progress", 0.0)
        d_close = comps.get("r_close", 0.0) - C_teacher.get("r_close", 0.0)
        print(f"  {name:28s}  A={A:+.6f}  R={R:.6f}  Δprog={d_prog:+.4f}  Δclose={d_close:+.4f}  ev={ev}")

    print("\nBottom candidates by advantage (worst):")
    for name, R, A, comps, ev in records_sorted[-print_top_bottom:]:
        d_prog = comps.get("r_progress", 0.0) - C_teacher.get("r_progress", 0.0)
        d_close = comps.get("r_close", 0.0) - C_teacher.get("r_close", 0.0)
        print(f"  {name:28s}  A={A:+.6f}  R={R:.6f}  Δprog={d_prog:+.4f}  Δclose={d_close:+.4f}  ev={ev}")

    # verdict logic:
    # - "good": random p90 advantage is negative (most random are worse than teacher)
    # - "warning": random p90 advantage positive (many random beat teacher)
    verdict = "PASS"
    if "random" in groups:
        sR = summarize_adv(groups["random"])
        if sR["p90"] > 0.0:
            verdict = "WARNING: random p90 advantage > 0 (many random beat teacher locally)"
    if "teacher+noise" in groups:
        sN = summarize_adv(groups["teacher+noise"])
        # if mild noise often beats teacher, basin is very flat
        if sN["p90"] > 0.0 and verdict == "PASS":
            verdict = "WARNING: teacher+noise p90 advantage > 0 (basin very flat / teacher not locally optimal for this shaping)"

    print("\nVERDICT:", verdict)
    print("="*78)

    return {
        "ok": True,
        "snapshot_source": snapshot_source,
        "seed_used": seed_used,
        "snap_step": snap_step,
        "K": K,
        "R_teacher": R_teacher,
        "C_teacher": C_teacher,
        "groups_summary": {k: summarize_adv(v) for k, v in groups.items()},
        "records": records,
        "verdict": verdict,
    }


# -------------------------
# Coupling curve: advantage vs epsilon (win_rate + moments)
# -------------------------

def rollout_feedback_teacher(env, snap, K):
    env._restore(snap)
    total_r = 0.0
    comps = defaultdict(float)
    ev_last = "running"

    obs = env._get_obs()
    for _ in range(K):
        a = env.calculate_pronav()
        obs, r, terminated, truncated, info = env.step(a)
        total_r += float(r)
        ev_last = info.get("event", ev_last)

        rt = info.get("reward_terms", {})
        for k, v in rt.items():
            comps[k] += float(v)

        if terminated or truncated:
            break

    return total_r, dict(comps), ev_last


def rollout_feedback_teacher_plus_noise(env, snap, K, eps, rng: np.random.RandomState):
    env._restore(snap)
    total_r = 0.0
    comps = defaultdict(float)
    ev_last = "running"

    obs = env._get_obs()
    for _ in range(K):
        a_star = env.calculate_pronav()

        noise = rng.normal(size=a_star.shape).astype(np.float32)
        noise /= (float(np.linalg.norm(noise)) + 1e-8)
        a = a_star + float(eps) * noise
        a = normalize_action(a)

        obs, r, terminated, truncated, info = env.step(a)
        total_r += float(r)
        ev_last = info.get("event", ev_last)

        rt = info.get("reward_terms", {})
        for k, v in rt.items():
            comps[k] += float(v)

        if terminated or truncated:
            break

    return total_r, dict(comps), ev_last


def win_stats(A):
    A = np.asarray(A, dtype=np.float64)
    if A.size == 0:
        return {
            "n": 0,
            "win_rate": 0.0,
            "mean": 0.0,
            "p50": 0.0,
            "p90": 0.0,
            "max": 0.0,
            "min": 0.0,
            "mean_pos": 0.0,
            "mean_neg": 0.0,
        }

    win = A > 0
    return {
        "n": int(A.size),
        "win_rate": float(np.mean(win)),
        "mean": float(np.mean(A)),
        "p50": float(np.percentile(A, 50)),
        "p90": float(np.percentile(A, 90)),
        "max": float(np.max(A)),
        "min": float(np.min(A)),
        "mean_pos": float(np.mean(A[win])) if np.any(win) else 0.0,
        "mean_neg": float(np.mean(A[~win])) if np.any(~win) else 0.0,
    }


def local_coupling_curve(
    env,
    snapshot_source="random",
    base_seed=0,
    snap_step=50,
    K=50,
    eps_list=(0.05, 0.10, 0.15, 0.30),
    n_trials=256,
    max_tries=20,
):
    """
    Build a snapshot (from random or teacher trajectory), then measure advantage
    distribution A = R_noisy - R_teacher for teacher+noise feedback policies.

    Uses a consistent restore+truncate snapshot to compare apples-to-apples.
    """
    ok, seed_used, snap = find_valid_snapshot(env, base_seed, snap_step, snapshot_source=snapshot_source, max_tries=max_tries)
    if not ok:
        print(f"[FAIL] Could not find valid snapshot: base_seed={base_seed}, snap_step={snap_step}, source={snapshot_source}")
        return {"ok": False}

    # teacher baseline (feedback)
    R_teacher, comps_teacher, ev_teacher = rollout_feedback_teacher(env, snap, K)

    A_by_eps = {}
    for eps in eps_list:
        A = []
        for j in range(n_trials):
            rng = np.random.RandomState(12345 + 100000 * int(100 * eps) + j)
            R_noisy, comps_noisy, ev_noisy = rollout_feedback_teacher_plus_noise(env, snap, K, eps, rng)
            A.append(R_noisy - R_teacher)
        A_by_eps[float(eps)] = A

    print("\n=== LOCAL COUPLING CURVE ===")
    print(f"source={snapshot_source} seed_used={seed_used} snap_step={snap_step} K={K}")
    print(f"R_teacher={R_teacher:.6f} event_last={ev_teacher}")

    means = []
    for eps in eps_list:
        s = win_stats(A_by_eps[float(eps)])
        means.append(s["mean"])
        print(
            f"eps={eps:0.2f} | win_rate={s['win_rate']:.3f} | mean={s['mean']:+.4f} | "
            f"p50={s['p50']:+.4f} | p90={s['p90']:+.4f} | max={s['max']:+.4f} | min={s['min']:+.4f}"
        )

    means = np.asarray(means, dtype=np.float64)
    monotone_more_negative = bool(np.all(np.diff(means) <= 1e-6))
    print(f"monotone_more_negative_with_eps: {monotone_more_negative}")

    return {
        "ok": True,
        "snapshot_source": snapshot_source,
        "seed_used": seed_used,
        "snap_step": snap_step,
        "K": K,
        "R_teacher": float(R_teacher),
        "A_by_eps": A_by_eps,
    }


def local_one_step_coupling_curve(
    env,
    snapshot_source="teacher",   # "teacher" or "random"
    base_seed=0,
    snap_step=50,
    n_trials=256,
    eps_list=(0.05, 0.10, 0.15, 0.30),
):
    rng = np.random.RandomState(base_seed)

    # Collect deltas: reward(noisy) - reward(base)
    deltas = {eps: [] for eps in eps_list}

    for i in range(n_trials):
        obs, _ = env.reset(seed=base_seed + i)

        # Roll forward to snap_step using teacher or random
        for _ in range(snap_step):
            if snapshot_source == "teacher":
                a = env.calculate_pronav()
            else:
                a = env.action_space.sample()
            obs, r, term, trunc, info = env.step(a)
            if term or trunc:
                break

        if env.done:
            continue  # skip early-terminated episodes

        snap = env._snapshot()

        # Base action at snapshot
        if snapshot_source == "teacher":
            a_base = env.calculate_pronav()
        else:
            a_base = env.action_space.sample()

        # Base one-step reward
        env._restore(snap)
        _, r_base, term, trunc, info_base = env.step(a_base)
        if term or trunc:
            continue

        # Perturbed actions
        for eps in eps_list:
            noise = rng.normal(loc=0.0, scale=eps, size=(2,)).astype(np.float32)
            a_noisy = np.clip(a_base + noise, -1.0, 1.0).astype(np.float32)

            env._restore(snap)
            _, r_noisy, term, trunc, info_noisy = env.step(a_noisy)
            if term or trunc:
                continue

            deltas[eps].append(float(r_noisy - r_base))

    # Print summary like your other curves
    print("\n=== ONE-STEP LOCAL COUPLING CURVE ===")
    print(f"source={snapshot_source} snap_step={snap_step} n_trials={n_trials}")
    prev_mean = None
    monotone = True

    for eps in eps_list:
        arr = np.array(deltas[eps], dtype=np.float64)
        if len(arr) == 0:
            print(f"eps={eps:.2f} | no data")
            monotone = False
            continue
        mean = float(arr.mean())
        p50 = float(np.median(arr))
        p90 = float(np.percentile(arr, 90))
        mx = float(arr.max())
        mn = float(arr.min())
        win_rate = float((arr < 0.0).mean())  # fraction noise makes reward worse

        print(
            f"eps={eps:.2f} | win_rate={win_rate:.3f} | mean={mean:+.4f} | "
            f"p50={p50:+.4f} | p90={p90:+.4f} | max={mx:+.4f} | min={mn:+.4f}"
        )

        if prev_mean is not None and mean > prev_mean + 1e-6:
            monotone = False
        prev_mean = mean

    print(f"monotone_more_negative_with_eps: {monotone}")
    return deltas





def local_one_step_coupling_curve_teacher(
    env,
    base_seed=0,
    snap_step=50,
    n_trials=256,
    eps_list=(0.05, 0.10, 0.15, 0.30),
):
    rng = np.random.RandomState(123)

    deltas = {eps: [] for eps in eps_list}

    for i in range(n_trials):
        obs, _ = env.reset(seed=base_seed + i)

        # roll to snap_step using TEACHER (important)
        for _ in range(snap_step):
            a = env.calculate_pronav()
            obs, r, term, trunc, info = env.step(a)
            if term or trunc:
                break
        if env.done:
            continue

        snap = env._snapshot()

        # base teacher action at snapshot
        env._restore(snap)
        a_base = normalize_action(env.calculate_pronav())

        # base one-step reward
        env._restore(snap)
        _, r_base, term, trunc, _ = env.step(a_base)
        if term or trunc:
            continue

        for eps in eps_list:
            u = rand_unit_2(rng)              # unit direction
            a_noisy = normalize_action(a_base + float(eps) * u)

            env._restore(snap)
            _, r_noisy, term, trunc, _ = env.step(a_noisy)
            if term or trunc:
                continue

            deltas[eps].append(float(r_noisy - r_base))

    print("\n=== ONE-STEP COUPLING (teacher-centered) ===")
    prev_mean = None
    monotone = True
    for eps in eps_list:
        arr = np.array(deltas[eps], dtype=np.float64)
        if arr.size == 0:
            print(f"eps={eps:.2f} | no data")
            monotone = False
            continue
        mean = float(arr.mean())
        p50 = float(np.median(arr))
        p90 = float(np.percentile(arr, 90))
        win_rate = float((arr < 0.0).mean())  # noise makes reward worse
        print(f"eps={eps:.2f} | win_rate={win_rate:.3f} | mean={mean:+.6f} | p50={p50:+.6f} | p90={p90:+.6f}")

        if prev_mean is not None and mean > prev_mean + 1e-9:
            monotone = False
        prev_mean = mean

    print(f"monotone_more_negative_with_eps: {monotone}")
    return deltas
# -------------------------
# Main entry: run the zoo
# -------------------------

def run_policy_zoo(student_model=None, n_episodes=50, seed0=0, print_every=10):
    env = missile_interception_3d()

    student = StudentPolicy(model=student_model, deterministic=True)

    policies = {
        "teacher": policy_teacher,
        "do_nothing": policy_do_nothing,
        "random": policy_random,
        "anti_teacher": policy_anti_teacher,
        "student": student.__call__,
    }

    print("Evaluating policy zoo...")
    results = evaluate_policy_zoo(env, policies, n_episodes=n_episodes, seed0=seed0, print_every=print_every)
    print_policy_zoo_report(results)
    return results


if __name__ == "__main__":
    run_policy_zoo(student_model=None, n_episodes=50, seed0=0)


Evaluating policy zoo...

--- Running policy: teacher (50 episodes) ---
  ep   10/50 | seed    9 | event=hit            | ret= 11025.86 | min_d=    127.3 | HR=100.0%
  ep   20/50 | seed   19 | event=hit            | ret= 10702.13 | min_d=    134.1 | HR=100.0%
  ep   30/50 | seed   29 | event=hit            | ret= 11591.47 | min_d=    143.0 | HR=100.0%
  ep   40/50 | seed   39 | event=hit            | ret= 11484.13 | min_d=    135.1 | HR=100.0%
  ep   50/50 | seed   49 | event=hit            | ret= 11186.65 | min_d=    115.7 | HR=100.0%
--- Done: teacher | hit_rate=100.0% | return_mean=11301.01 ---

--- Running policy: do_nothing (50 episodes) ---
  ep   10/50 | seed    9 | event=enemy_ground   | ret= -2955.92 | min_d=  46595.9 | HR=  0.0%
  ep   20/50 | seed   19 | event=enemy_ground   | ret= -3108.64 | min_d=  29574.8 | HR=  0.0%
  ep   30/50 | seed   29 | event=enemy_ground   | ret= -3570.39 | min_d=  60892.3 | HR=  0.0%
  ep   40/50 | seed   39 | event=enemy_ground   | ret= -4123.06

In [18]:
# ==========================================
# LOCAL ACTION-REWARD COUPLING TEST (Step 2 v3)
# ==========================================
# This tests whether the reward provides good local gradients
# around the teacher action, which is critical for PPO learning.
# Uses feedback policies (recomputes action each step) and reports
# advantage distributions vs teacher.

env = missile_interception_3d()

print("Testing local coupling (feedback policies) ...")
print("=" * 70)

for snapshot_source in ["random", "teacher"]:
    for K in [20, 50]:
        for snap_step in [20, 50, 80]:
            out = local_action_reward_sanity_v3(
                env,
                base_seed=0,
                snap_step=snap_step,
                K=K,
                snapshot_source=snapshot_source,
                eps_list=(0.05, 0.15, 0.30),
                n_noisy=64,
                n_random=64,
                print_top_bottom=6,
            )
            if not out.get("ok", False):
                print(f"\nWARNING: Test failed for snapshot_source={snapshot_source}, K={K}, snap_step={snap_step}")
            print()


Testing local coupling (feedback policies) ...

=== LOCAL ACTION-REWARD SANITY v3 (feedback + advantage dist) ===
snapshot_source=random | seed_used=0 | snap_step=20 | K=20
R_teacher = 62.875733 | event_last=running
------------------------------------------------------------------------------
teacher       A=R-R_teacher: mean=+0.000000  p50=+0.000000  p90=+0.000000  max=+0.000000  min=+0.000000  n=1
teacher+noise  A=R-R_teacher: mean=-0.028105  p50=-0.017502  p90=+0.132971  max=+0.337522  min=-0.453564  n=192
random        A=R-R_teacher: mean=-3.556433  p50=-3.677414  p90=-2.220628  max=-1.541327  min=-5.706475  n=64
do_nothing    A=R-R_teacher: mean=-3.733072  p50=-3.733072  p90=-3.733072  max=-3.733072  min=-3.733072  n=1
anti_teacher  A=R-R_teacher: mean=-7.764379  p50=-7.764379  p90=-7.764379  max=-7.764379  min=-7.764379  n=1

Top candidates by advantage (beating teacher):
  teacher_noise_eps0.30_16      A=+0.337522  R=63.213255  Δprog=+0.3273  Δclose=+0.0143  ev=running
  teache

In [14]:
# # ==========================================
# # LOCAL COUPLING CURVE (advantage vs epsilon)
# # ==========================================
# # This is the "PPO usability" test: does expected return drop as you add noise?

# env = missile_interception_3d()

# for source in ["random", "teacher"]:
#     for snap_step in [20, 50, 80]:
#         _ = local_coupling_curve(
#             env,
#             snapshot_source=source,
#             base_seed=0,
#             snap_step=snap_step,
#             K=50,
#             eps_list=(0.05, 0.10, 0.15, 0.30),
#             n_trials=256,
#             max_tries=20,
#         )
#         print()


In [15]:
env = missile_interception_3d()

for source in ["random", "teacher"]:
    for snap_step in [20, 50, 80]:
        _ = local_one_step_coupling_curve(
            env,
            snapshot_source=source,
            base_seed=0,
            snap_step=snap_step,
            n_trials=256,
            eps_list=(0.05, 0.10, 0.15, 0.30),
        )
        print()



=== ONE-STEP LOCAL COUPLING CURVE ===
source=random snap_step=20 n_trials=256
eps=0.05 | win_rate=0.469 | mean=+0.0000 | p50=+0.0000 | p90=+0.0002 | max=+0.0006 | min=-0.0005
eps=0.10 | win_rate=0.512 | mean=-0.0001 | p50=-0.0000 | p90=+0.0003 | max=+0.0011 | min=-0.0016
eps=0.15 | win_rate=0.496 | mean=+0.0000 | p50=+0.0000 | p90=+0.0006 | max=+0.0038 | min=-0.0028
eps=0.30 | win_rate=0.465 | mean=+0.0001 | p50=+0.0000 | p90=+0.0014 | max=+0.0032 | min=-0.0049
monotone_more_negative_with_eps: False


=== ONE-STEP LOCAL COUPLING CURVE ===
source=random snap_step=50 n_trials=256
eps=0.05 | win_rate=0.488 | mean=+0.0000 | p50=+0.0000 | p90=+0.0002 | max=+0.0008 | min=-0.0009
eps=0.10 | win_rate=0.516 | mean=-0.0000 | p50=-0.0000 | p90=+0.0003 | max=+0.0017 | min=-0.0012
eps=0.15 | win_rate=0.492 | mean=+0.0000 | p50=+0.0000 | p90=+0.0005 | max=+0.0026 | min=-0.0021
eps=0.30 | win_rate=0.449 | mean=+0.0000 | p50=+0.0000 | p90=+0.0010 | max=+0.0053 | min=-0.0044
monotone_more_negative_wit