In [17]:
import gymnasium as gym
import numpy as np

class missile_interception_3d(gym.Env):
    def __init__(self):
        # 1. Define Action Space (The Joystick: Left/Right, Up/Down)
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2, ), dtype=np.float32)
        
        # 2. Define Observation Space (16D ego-frame, no actuator state)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32)

        self.np_random = np.random.RandomState()
        
        # 3. Time Settings
        self.dt_act = 0.1             
        self.n_substeps = 10          
        self.dt_sim = self.dt_act / self.n_substeps 
        self.t_max = 650.0 # REVISAR IF THIS IS ENOF         

        # 4. Physical Limits
        self.a_max = 350.0   # Max G-force (m/s^2) ~35G
        self.g = 9.81        
        self.collision_radius = 150.0  
        self.max_distance = 4_000_000.0 

        # LET'S REMOVE THE "HARD CASES"

        self.p_easy = 1.0                   
        self.range_min = 70_000.0           
        self.range_easy_max = 200_000.0     
        self.range_hard_max = 1_000_000.0   

        self.targetbox_x_min = -15000
        self.targetbox_x_max = 15000
        self.targetbox_y_min = -15000
        self.targetbox_y_max = 15000

        # --- Reward shaping weights ---
        self.w_phi = 1.0          # main: absolute phi quality (0..1)
        self.w_ca = 0.1           # garnish: progress (phi_after - phi_before)
        self.w_close = 0.08     # tiny hint
        self.gamma_shape = 1.0
        self.dstar_scale = 10_000.0

        # terminal shaping
        self.r_hit_bonus = 400.0
        self.r_def_ground_pen = 200.0
        self.r_timeout_pen = 0.0
        self.r_diverged_pen = 1.0

        # ESTOS TOCA MIRARLOS GONORREA

    def generate_enemy_missile(self):
        if self.np_random.rand() < self.p_easy:
            self.range_max_used = self.range_easy_max
        else:
            self.range_max_used = self.range_hard_max

        range_min = self.range_min
        self.attack_target_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.attack_target_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)
        self.enemy_launch_angle = self.np_random.uniform(0, 2 * np.pi)
        self.enemy_theta = self.np_random.uniform(0.523599, 1.0472) 

        self.range_max_used = max(self.range_max_used, range_min + 1.0)
        lower_limit = np.sqrt((range_min * self.g) / np.sin(2 * self.enemy_theta))
        upper_limit = np.sqrt((self.range_max_used * self.g) / np.sin(2 * self.enemy_theta))
        self.enemy_initial_velocity = self.np_random.uniform(lower_limit, upper_limit)

        ground_range = (
            self.enemy_initial_velocity * np.cos(self.enemy_theta)
            * (2 * self.enemy_initial_velocity * np.sin(self.enemy_theta) / self.g)
        )

        self.enemy_launch_x = self.attack_target_x + ground_range * np.cos(self.enemy_launch_angle)
        self.enemy_launch_y = self.attack_target_y + ground_range * np.sin(self.enemy_launch_angle)
        self.enemy_z = 0
        self.enemy_x = self.enemy_launch_x
        self.enemy_y = self.enemy_launch_y
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        self.enemy_azimuth = (self.enemy_launch_angle + np.pi) % (2 * np.pi)

    def generate_defense_missile(self):
        self.defense_launch_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.defense_launch_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)

        dx = self.enemy_launch_x - self.defense_launch_x
        dy = self.enemy_launch_y - self.defense_launch_y
        az_nominal = np.arctan2(dy, dx)

        # --- Misalignment (domain randomization of initial heading) ---
        # Mixture: most episodes small error, some episodes large az error
        p_misaligned = 0.35  # 35% "hard" starts
        if self.np_random.rand() < p_misaligned:
            # Hard: big azimuth error → strong RIGHT required
            az_noise = self.np_random.uniform(-np.deg2rad(60.0), np.deg2rad(60.0))
        else:
            # Easy: small azimuth error → gentle correction
            az_noise = self.np_random.uniform(-np.deg2rad(10.0), np.deg2rad(10.0))

        self.defense_azimuth = az_nominal + az_noise

        # Elevation noise: avoid always same vertical plane
        theta_nominal = 0.785398  # ~45 deg
        theta_noise_deg = 10.0
        theta_noise = self.np_random.uniform(
            -np.deg2rad(theta_noise_deg),
            +np.deg2rad(theta_noise_deg),
        )
        self.defense_theta = float(np.clip(theta_nominal + theta_noise,
                                           np.deg2rad(10.0),
                                           np.deg2rad(80.0)))
        # ---------------------------------------------------------------

        base_velocity = 3000.0
        if hasattr(self, 'range_max_used'):
            velocity_scale = min(self.range_max_used / self.range_easy_max, 1.5)
            self.defense_initial_velocity = base_velocity * velocity_scale
        else:
            self.defense_initial_velocity = base_velocity

        self.defense_x = self.defense_launch_x
        self.defense_y = self.defense_launch_y
        self.defense_z = 0.0
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)
    
    def _smoothstep(self, x: float) -> float:
        """Smooth ramp 0->1 with zero slope at ends, clamps outside [0,1]"""
        x = float(np.clip(x, 0.0, 1.0))
        return x * x * (3.0 - 2.0 * x)

    def _phi_reward(self, phi: float, dstar: float) -> float:
        """Rational shaping: 1/(1+(dstar/d0)^p). Always in (0,1], no underflow."""
        d0 = 50_000.0
        p = 2.0
        x = float(dstar) / d0
        return float(1.0 / (1.0 + x**p))
    
    def calculate_pronav(self):
        eps = 1e-9

        # Relative geometry (use float64 for stability)
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        v = self.defense_vel.astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)

        R = float(np.linalg.norm(r)) + eps
        V = float(np.linalg.norm(v)) + eps

        rhat = r / R
        vhat = v / V

        # Heading error alpha = angle between velocity direction and LOS direction
        cosang = float(np.clip(np.dot(vhat, rhat), -1.0, 1.0))
        alpha = float(np.arccos(cosang))  # radians

        # LOS angular rate omega (world frame)
        omega = np.cross(r, vrel) / (float(np.dot(r, r)) + eps)
        omega_mag = float(np.linalg.norm(omega))

        # Closing speed (positive => closing)
        vc = -float(np.dot(r, vrel)) / R

        # --- PN term ---
        N = 3.0
        a_pn = N * vc * np.cross(omega, rhat)  # lateral accel in world frame

        # --- Acquisition term (turn-to-LOS) ---
        # Perpendicular component of LOS relative to forward direction
        rhat_perp = rhat - float(np.dot(rhat, vhat)) * vhat
        nperp = float(np.linalg.norm(rhat_perp))

        if nperp < 1e-8:
            a_acq = np.zeros(3, dtype=np.float64)
        else:
            rhat_perp /= nperp  # unit sideways "turn toward LOS" direction

            # Curvature-based magnitude: ~k * V^2 / R, saturate later via a_max
            k_acq = 5.0  # try 3.0–8.0
            a_acq = k_acq * (V * V / R) * rhat_perp

        # --- Blend weight w: 0 => pure PN, 1 => pure acquisition ---

        # Alpha-based weight (dominant)
        alpha_on   = np.deg2rad(20.0)   # start blending earlier
        alpha_full = np.deg2rad(55.0)

        x_alpha = (alpha - alpha_on) / (alpha_full - alpha_on + eps)
        w_alpha = self._smoothstep(x_alpha)

        # Omega-based modifier (only boosts acquisition when PN is sleepy)
        omega_full = 0.00
        omega_on   = 0.05   # <-- key: less brittle than 0.02

        x_omega = (omega_on - omega_mag) / (omega_on - omega_full + eps)
        w_omega = self._smoothstep(x_omega)

        # Robust combine: alpha dominates; omega can't fully shut it off
        w = w_alpha * (0.25 + 0.75 * w_omega)

        # Optional: if not closing, force strong acquisition
        if vc <= 0.0:
            w = max(w, 0.9)

        a_ideal = (1.0 - w) * a_pn + w * a_acq

        # Project into your lateral control basis (right/up) and normalize by a_max
        # Note: Environment now handles gravity compensation internally,
        # so ProNav outputs desired NET lateral accel (same semantics as PPO)
        forward, right, up = self._compute_lateral_basis(self.defense_vel)
        a_right = float(np.dot(a_ideal, right))
        a_up    = float(np.dot(a_ideal, up))

        action = np.array([a_right / self.a_max, a_up / self.a_max], dtype=np.float32)
        return np.clip(action, -1.0, 1.0)
    
    def _segment_sphere_intersect(self, r0, r1, r_hit):
        dr = r1 - r0
        dr_norm_sq = float(np.dot(dr, dr))
        if dr_norm_sq < 1e-12:
            return float(np.dot(r0, r0)) <= r_hit * r_hit
        s_star = -float(np.dot(r0, dr)) / dr_norm_sq
        s_star = max(0.0, min(1.0, s_star))
        r_closest = r0 + s_star * dr
        return float(np.dot(r_closest, r_closest)) <= r_hit * r_hit
    
    def _phi_closest_approach(self):
        """
        Potential based on closest-approach miss distance under
        constant relative velocity assumption.
        Returns:
            phi, d_star, t_star
        """
        eps = 1e-9

        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)

        vrel_norm_sq = float(np.dot(vrel, vrel)) + eps

        # t* = argmin ||r + vrel t|| for t >= 0
        t_star = -float(np.dot(r, vrel)) / vrel_norm_sq
        t_star = max(0.0, t_star)

        m_star = r + vrel * t_star
        d_star = float(np.linalg.norm(m_star))

        phi = -d_star / (self.dstar_scale + eps)
        return float(phi), d_star, float(t_star)
    
    def _get_obs(self):
        eps = 1e-9

        # World-frame relative state
        r_world = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel_world = (self.enemy_vel - self.defense_vel).astype(np.float64)

        # Local basis from defense velocity (world frame unit vectors)
        forward, right, up = self._compute_lateral_basis(self.defense_vel)

        # ===============================
        # 1) Ego-frame (body-frame) r and vrel
        # ===============================
        r_body = np.array([
            float(np.dot(r_world, forward)),
            float(np.dot(r_world, right)),
            float(np.dot(r_world, up)),
        ], dtype=np.float64)

        vrel_body = np.array([
            float(np.dot(vrel_world, forward)),
            float(np.dot(vrel_world, right)),
            float(np.dot(vrel_world, up)),
        ], dtype=np.float64)

        # Normalize r_body / vrel_body (easy range only)
        pos_scale = float(self.range_easy_max)
        vel_scale = 4000.0

        r_body_n = (r_body / (pos_scale + eps)).astype(np.float32)
        vrel_body_n = (vrel_body / (vel_scale + eps)).astype(np.float32)

        # ===============================
        # 2) Scalar helpers
        # ===============================
        dist = float(np.linalg.norm(r_world)) + 1e-6
        v_close = -float(np.dot(r_world, vrel_world)) / dist  # positive when closing

        dist_n = np.float32(np.clip(dist / 1_000_000.0, 0.0, 4.0))
        vclose_n = np.float32(np.clip(v_close / 3000.0, -2.0, 2.0))
        dist_vclose_feat = np.array([dist_n, vclose_n], dtype=np.float32)

        # Defense own vertical state (keep for ground constraint)
        def_z_n = np.float32(np.clip(self.defense_pos[2] / 100_000.0, -1.0, 2.0))
        def_vz_n = np.float32(np.clip(self.defense_vel[2] / 3000.0, -2.0, 2.0))
        def_state_feat = np.array([def_z_n, def_vz_n], dtype=np.float32)

        # ===============================
        # 4) Keep your geometry features (consistent with ego-frame)
        # ===============================
        dist_body = float(np.linalg.norm(r_body)) + 1e-6

        # LOS lateral projections in body frame
        los_right = float(r_body[1] / dist_body)
        los_up    = float(r_body[2] / dist_body)

        # LOS rate omega in body frame: omega = (r x vrel)/||r||^2
        dist2_body = float(np.dot(r_body, r_body)) + eps
        omega_body = np.cross(r_body, vrel_body) / dist2_body

        omega_right = float(omega_body[1])
        omega_up    = float(omega_body[2])

        omega_scale = 2.0
        omega_right_n = float(np.clip(omega_right / omega_scale, -2.0, 2.0))
        omega_up_n    = float(np.clip(omega_up / omega_scale, -2.0, 2.0))

        geom_feat = np.array([los_right, los_up, omega_right_n, omega_up_n], dtype=np.float32)

        # ===============================
        # 5) NEW: kinematics garnish
        # ===============================
        V_def = float(np.linalg.norm(self.defense_vel))
        V_def_n = np.float32(np.clip(V_def / 3000.0, 0.0, 3.0))  # scale: 3000 m/s baseline
        forward_z = np.float32(float(forward[2]))               # dot(forward, world_up) since world_up=[0,0,1]

        kin_feat = np.array([V_def_n, forward_z], dtype=np.float32)

        # Final obs (16D, no actuator state)
        obs = np.concatenate(
            [r_body_n, vrel_body_n, dist_vclose_feat, def_state_feat, geom_feat, kin_feat],
            axis=0
        ).astype(np.float32)

        return obs

    def _compute_lateral_basis(self, velocity):
        """
        Horizon-stable basis:
          forward = along velocity
          right   = world_up x forward  (horizontal right)
          up      = forward x right     (completes orthonormal frame)
        This keeps 'up' as close to world-up as possible and avoids weird twisting.
        """
        speed = float(np.linalg.norm(velocity))
        if speed < 1.0:
            forward = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        else:
            forward = (velocity / speed).astype(np.float32)

        world_up = np.array([0.0, 0.0, 1.0], dtype=np.float32)

        # right = world_up x forward
        right_raw = np.cross(world_up, forward)
        rnorm = float(np.linalg.norm(right_raw))

        # If forward is near world_up, right_raw ~ 0. Pick a consistent fallback.
        if rnorm < 1e-6:
            # Choose a fixed "north" axis in world XY and build right from that
            # This prevents random spinning when vertical.
            north = np.array([1.0, 0.0, 0.0], dtype=np.float32)
            right_raw = np.cross(north, forward)
            rnorm = float(np.linalg.norm(right_raw))
            if rnorm < 1e-6:
                north = np.array([0.0, 1.0, 0.0], dtype=np.float32)
                right_raw = np.cross(north, forward)
                rnorm = float(np.linalg.norm(right_raw))

        right = (right_raw / (rnorm + 1e-9)).astype(np.float32)

        # up = forward x right (not right x forward)
        up_raw = np.cross(forward, right)
        up = (up_raw / (float(np.linalg.norm(up_raw)) + 1e-9)).astype(np.float32)

        return forward, right, up

    def step(self, action, skip_projection=False):
        if getattr(self, "done", False):
            return self._get_obs(), 0.0, True, False, {"event": "called_step_after_done", "dist": float(np.linalg.norm(self.enemy_pos - self.defense_pos)), "proj_fired": 0.0, "mag_in": float(np.nan), "mag_exec": float(np.nan)}
        
        action = np.clip(action, -1.0, 1.0).astype(np.float32)
        action_in = action.copy()
        mag_in = float(np.linalg.norm(action_in))
        proj_fired = 0.0
        if not skip_projection and mag_in > 1.0:
            action = action_in / mag_in
            proj_fired = 1.0
        mag_exec = float(np.linalg.norm(action))
        
        # Update episode trackers
        self.ep_max_action_mag = max(self.ep_max_action_mag, mag_exec)

        dist_before = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        
        # --- Shaping: closest-approach potential BEFORE transition ---
        phi_before, dstar_before, tstar_before = self._phi_closest_approach()
        terminated = False
        truncated = False
        event = "running"
        
        for _ in range(self.n_substeps):
            dt = self.dt_sim
            enemy_pos_old = self.enemy_pos.copy()
            defense_pos_old = self.defense_pos.copy()
            
            forward, right, up = self._compute_lateral_basis(self.defense_vel)
            g_vec = np.array([0.0, 0.0, -self.g], dtype=np.float32)
            # Direct commanded lateral accel (no lag/jerk)
            a_lat = (action[0] * self.a_max) * right + (action[1] * self.a_max) * up
            self.defense_vel += (a_lat + g_vec) * dt
            self.defense_pos += self.defense_vel * dt
            self.defense_x, self.defense_y, self.defense_z = self.defense_pos
            
            # Enemy missile: pure ballistic (gravity only)
            self.enemy_vel += g_vec * dt
            self.enemy_pos += self.enemy_vel * dt
            self.enemy_x, self.enemy_y, self.enemy_z = self.enemy_pos
            self.t += dt
            
            r0 = enemy_pos_old - defense_pos_old
            r1 = self.enemy_pos - self.defense_pos
            if self._segment_sphere_intersect(r0, r1, self.collision_radius):
                self.success = True
                terminated = True
                self.done = True
                event = "hit"
                self.time_to_hit = float(self.t)
                self.terminal_event = "hit"
                break
            
            dist = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
            if dist > self.max_distance:
                truncated = True
                self.done = True
                event = "diverged"
                self.terminal_event = "diverged"
                break
            if self.defense_pos[2] < 0:
                terminated = True
                self.done = True
                event = "defense_ground"
                self.terminal_event = "defense_ground"
                break
            if self.enemy_pos[2] < 0:
                terminated = True
                self.done = True
                event = "enemy_ground"
                self.terminal_event = "enemy_ground"
                break
            if self.t >= self.t_max:
                truncated = True
                self.done = True
                event = "timeout"
                self.terminal_event = "timeout"
                break

        obs = self._get_obs()
        dist_after = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        self.min_dist = min(getattr(self, "min_dist", float("inf")), dist_after)
        self.ep_min_dist = min(self.ep_min_dist, float(dist_after))
        
        # Reward: ABS PHI (main) + r_ca (garnish) + tiny closing
        v_scale = 1500.0
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)
        d = float(np.linalg.norm(r)) + 1e-9
        rhat = r / d
        closing = -float(np.dot(rhat, vrel))   # m/s, positive => closing
        r_close = float(np.tanh(closing / v_scale))
        phi_after, dstar_after, tstar_after = self._phi_closest_approach()

        r_phi = self._phi_reward(phi_after, dstar_after)
        r_ca = float(self.w_ca * (phi_after - phi_before))
        r_cl = float(self.w_close * r_close)
        reward = float(self.w_phi * r_phi + r_ca + r_cl)

        if event == "hit":
            reward += self.r_hit_bonus
        elif event == "defense_ground":
            reward -= self.r_def_ground_pen
        elif event == "timeout":
            reward -= self.r_timeout_pen
        elif event == "diverged":
            reward -= self.r_diverged_pen

        final_event = event
        if terminated or truncated:
            self.terminal_event = event

        # --- Geometry diagnostics (reuses r, vrel, d, rhat from reward) ---
        v_dbg = self.defense_vel.astype(np.float64)
        V_dbg = float(np.linalg.norm(v_dbg)) + 1e-9
        vhat_dbg = v_dbg / V_dbg
        cos_alpha = float(np.clip(np.dot(vhat_dbg, rhat), -1.0, 1.0))
        alpha_deg = float(np.rad2deg(np.arccos(cos_alpha)))
        omega_vec = np.cross(r, vrel) / (float(np.dot(r, r)) + 1e-9)
        omega_mag = float(np.linalg.norm(omega_vec))
        fwd_z = float(self._compute_lateral_basis(self.defense_vel)[0][2])

        # Action distortion from projection
        ain = action_in.astype(np.float64)
        aex = action.astype(np.float64)
        ain_n = float(np.linalg.norm(ain))
        aex_n = float(np.linalg.norm(aex))
        act_dir_cos = float(np.dot(ain, aex) / (ain_n * aex_n)) if (ain_n > 1e-9 and aex_n > 1e-9) else np.nan
        da = aex - ain

        reward_terms = {
            "r_total": float(reward),
            "r_phi": float(r_phi),
            "r_ca": float(r_ca),
            "r_close_raw": float(r_close),
            "r_close": float(r_cl),
            "phi": float(phi_after),
            "dstar": float(dstar_after),
            "closing": float(closing),
        }

        info = {
            "event": final_event,
            "t": float(self.t),
            "dist": float(dist_after),
            "closing": float(closing),
            "phi": float(phi_after),
            "r_phi": float(r_phi),
            "r_ca": float(r_ca),
            "r_close_raw": float(r_close),
            "r_close": float(r_cl),
            "proj_fired": float(proj_fired),
            "mag_in": float(mag_in),
            "mag_exec": float(mag_exec),
            "alpha_deg": alpha_deg,
            "omega_mag": omega_mag,
            "phi_before": float(phi_before),
            "phi_after": float(phi_after),
            "dstar_before": float(dstar_before),
            "dstar_after": float(dstar_after),
            "tstar_before": float(tstar_before),
            "tstar_after": float(tstar_after),
            "defense_z": float(self.defense_pos[2]),
            "defense_vz": float(self.defense_vel[2]),
            "defense_speed": V_dbg,
            "forward_z": fwd_z,
            "act_dir_cos": float(act_dir_cos) if np.isfinite(act_dir_cos) else np.nan,
            "da0": float(da[0]),
            "da1": float(da[1]),
            "reward_terms": reward_terms,
        }
        if terminated or truncated:
            info["min_dist"] = float(self.ep_min_dist)
            info["max_action_mag"] = float(self.ep_max_action_mag)
            info["time_to_hit"] = float(self.time_to_hit) if self.time_to_hit is not None else np.nan
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        if seed is not None: self.np_random = np.random.RandomState(seed)
        self.done = False
        self.success = False
        self.t = 0.0
        self.generate_enemy_missile()
        self.generate_defense_missile()
        
        self.defense_vel = np.array([
            self.defense_initial_velocity * np.cos(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_theta)
        ], dtype=np.float32)
        
        self.enemy_vel = np.array([
            self.enemy_initial_velocity * np.cos(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_theta)
        ], dtype=np.float32)
        
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        d0 = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        self.min_dist = d0
        self.ep_min_dist = d0
        self.ep_max_action_mag = 0.0
        self.time_to_hit = None
        self.terminal_event = "running"
        
        return self._get_obs(), {}


# Same physics, no unit-ball projection (for A/B experiment)
class missile_interception_3d_no_proj(missile_interception_3d):
    def step(self, action):
        action = np.clip(action, -1.0, 1.0).astype(np.float32)
        return super().step(action, skip_projection=True)


# ==========================================
# EVALUATION FUNCTION (separate from training reward)
# ==========================================
def evaluate_policy(env, policy_fn, n_episodes=100, seed0=0):
    """
    Evaluate a policy and return episode-level metrics.
    This is separate from training reward - these metrics are what you actually care about.
    
    Args:
        env: missile_interception_3d environment instance
        policy_fn: Function that takes obs and returns action
        n_episodes: Number of episodes to evaluate
        seed0: Starting seed (episodes use seed0, seed0+1, ..., seed0+n_episodes-1)
    
    Returns:
        summary: Dict with aggregated metrics (hit_rate, min_dist stats, etc.)
        metrics: Dict with raw episode data
    """
    metrics = {
        "hits": 0,
        "ground_defense": 0,
        "ground_enemy": 0,
        "diverged": 0,
        "timeout": 0,
        "min_dist_list": [],
        "time_to_hit_list": [],
        "max_g_list": [],
    }

    for i in range(n_episodes):
        obs, _ = env.reset(seed=seed0 + i)
        done = False

        while not done:
            action = policy_fn(obs)  # your PPO policy OR env.calculate_pronav() baseline
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

        event = info["event"]
        metrics["min_dist_list"].append(env.ep_min_dist)
        metrics["max_g_list"].append(env.ep_max_action_mag)

        if event == "hit":
            metrics["hits"] += 1
            if env.time_to_hit is not None:
                metrics["time_to_hit_list"].append(env.time_to_hit)
        elif event == "defense_ground":
            metrics["ground_defense"] += 1
        elif event == "enemy_ground":
            metrics["ground_enemy"] += 1
        elif event == "diverged":
            metrics["diverged"] += 1
        elif event == "timeout":
            metrics["timeout"] += 1

    hit_rate = metrics["hits"] / n_episodes

    summary = {
        "hit_rate": hit_rate,
        "min_dist_mean": float(np.mean(metrics["min_dist_list"])),
        "min_dist_p50": float(np.median(metrics["min_dist_list"])),
        "min_dist_p10": float(np.percentile(metrics["min_dist_list"], 10)),
        "time_to_hit_mean": float(np.mean(metrics["time_to_hit_list"])) if metrics["time_to_hit_list"] else None,
        "max_g_mean": float(np.mean(metrics["max_g_list"])),
        "violations": {
            "defense_ground": metrics["ground_defense"],
            "enemy_ground": metrics["ground_enemy"],
            "diverged": metrics["diverged"],
            "timeout": metrics["timeout"],
        }
    }
    return summary, metrics




# def run_baseline():
#     env = missile_interception_3d()
#     outcomes = []
#     min_distances = []
#     action_loads = [] # Track if we are saturating (maxing out fins)

#     N_EPISODES = 50
#     print(f"Running {N_EPISODES} episodes of Augmented ProNav...")

#     for i in range(N_EPISODES):
#         obs, _ = env.reset(seed=i)
#         done = False
#         ep_actions = []

#         while not done:
#             # 1. Ask ProNav for the move
#             action = env.calculate_pronav()
            
#             # 2. Track how hard it's pushing (0.0 to 1.0)
#             mag = np.linalg.norm(action)
#             ep_actions.append(mag)

#             obs, reward, terminated, truncated, info = env.step(action)
#             done = terminated or truncated
        
#         outcomes.append(info['event'])
#         min_distances.append(env.ep_min_dist)
#         avg_load = np.mean(ep_actions)
#         action_loads.append(avg_load)

#         print(f"Ep {i+1:02d} | Res: {info['event']:<14} | Min Dist: {env.ep_min_dist:.1f} m | Avg G-Load: {avg_load*100:.1f}%")

#     # Final Stats
#     hits = outcomes.count("hit")
#     non_hit = [d for d, e in zip(min_distances, outcomes) if e != "hit"]
#     print("\n--- SUMMARY ---")
#     print(f"Hit Rate: {hits}/{N_EPISODES} ({hits/N_EPISODES*100:.1f}%)")
#     print("Average Miss Distance (Non-hits):", f"{np.mean(non_hit):.2f} m" if non_hit else "N/A (all hits)")
#     print(f"Average G-Loading: {np.mean(action_loads)*100:.1f}% (If >90%, missile is physically too weak)")

# if __name__ == "__main__":
#     run_baseline()


# def run_pronav_phi_debug(n_episodes=20, seed0=0):
#     env = missile_interception_3d()
#     print(f"Running {n_episodes} episodes of ProNav + phi tracking...")
#     for i in range(n_episodes):
#         obs, _ = env.reset(seed=seed0 + i)
#         done = False
#         phi_hist = []
#         dstar_hist = []
#         while not done:
#             action = env.calculate_pronav()
#             obs, rew, terminated, truncated, info = env.step(action)
#             done = terminated or truncated
#             phi_hist.append(info.get("phi", np.nan))
#             dstar_hist.append(info.get("dstar_after", np.nan))
#         phi_best = float(np.nanmax(phi_hist))
#         dstar_min = float(np.nanmin(dstar_hist))
#         print(f"ep {i:02d} | event={info['event']:<14} | min_dstar={dstar_min:8.1f} m | best_phi={phi_best:+.6f}")

# run_pronav_phi_debug(20, seed0=0)

In [18]:
import numpy as np

# -------------------------
# NEW reward terms (recomputed from env state)
# -------------------------
def phi_reward(dstar, d0=50_000.0, p=2.0):
    # rational shaping: 1/(1+(dstar/d0)^p), always in (0,1], no underflow
    x = float(dstar) / d0
    return float(1.0 / (1.0 + x**p))

def reward_terms_from_state(env, phi_before,
                            w_phi=1.0, w_ca=0.1, w_close=0.002,
                            v_scale=1500.0):
    # geometry
    r = (env.enemy_pos - env.defense_pos).astype(np.float64)
    vrel = (env.enemy_vel - env.defense_vel).astype(np.float64)
    d = float(np.linalg.norm(r)) + 1e-9
    rhat = r / d
    closing = -float(np.dot(rhat, vrel))  # positive => closing
    r_close_raw = float(np.tanh(closing / v_scale))

    phi_after, dstar_after, tstar_after = env._phi_closest_approach()

    r_phi = w_phi * phi_reward(dstar_after)
    r_ca  = w_ca * (phi_after - phi_before)
    r_cl  = w_close * r_close_raw

    r_total = float(r_phi + r_ca + r_cl)

    return {
        "r_total": r_total,
        "r_phi": float(r_phi),
        "r_ca": float(r_ca),
        "r_close": float(r_cl),
        "r_close_raw": float(r_close_raw),
        "closing": float(closing),
        "phi": float(phi_after),
        "dstar": float(dstar_after),
        "tstar": float(tstar_after),
    }

# -------------------------
# Run one episode, log per-step terms
# -------------------------
def rollout_text(env, policy_fn, seed,
                 max_steps=8000,
                 cfg=None,
                 terminal_cfg=None,
                 show_first_k=12):
    if cfg is None:
        cfg = dict(w_phi=1.0, w_ca=0.1, w_close=0.002, v_scale=1500.0)
    if terminal_cfg is None:
        terminal_cfg = dict(hit_bonus=2.0, def_ground_pen=2.0, timeout_pen=0.0, diverged_pen=1.0)

    obs, _ = env.reset(seed=seed)
    done = False
    t = 0

    sums = dict(r_phi=0.0, r_ca=0.0, r_close=0.0, r_total=0.0)
    best_phi = -np.inf
    min_dstar = np.inf
    min_dist = np.inf

    first_rows = []

    final_event = None
    terminal_bonus = 0.0

    while (not done) and (t < max_steps):
        phi_before, dstar_before, tstar_before = env._phi_closest_approach()

        a = policy_fn(obs)
        obs, rew_env, terminated, truncated, info = env.step(a)
        done = terminated or truncated

        terms = reward_terms_from_state(env, phi_before, **cfg)

        # terminal shaping (match your env.step logic)
        ev = info.get("event", "unknown")
        final_event = ev
        if done:
            if ev == "hit":
                terminal_bonus = terminal_cfg["hit_bonus"]
            elif ev == "defense_ground":
                terminal_bonus = -terminal_cfg["def_ground_pen"]
            elif ev == "timeout":
                terminal_bonus = -terminal_cfg["timeout_pen"]
            elif ev == "diverged":
                terminal_bonus = -terminal_cfg["diverged_pen"]
            else:
                terminal_bonus = 0.0

        # track sums (NOT including terminal bonus yet)
        for k in ("r_phi", "r_ca", "r_close", "r_total"):
            sums[k] += terms[k]

        # track episode metrics
        dist = float(np.linalg.norm(env.enemy_pos - env.defense_pos))
        min_dist = min(min_dist, dist)
        best_phi = max(best_phi, terms["phi"])
        min_dstar = min(min_dstar, terms["dstar"])

        if t < show_first_k:
            first_rows.append({
                "t": t,
                "phi": terms["phi"],
                "dstar": terms["dstar"],
                "closing": terms["closing"],
                "r_phi": terms["r_phi"],
                "r_ca": terms["r_ca"],
                "r_cl": terms["r_close"],
                "r_step": terms["r_total"],
                "event": ev,
            })

        t += 1

    # add terminal bonus at the end (like your env does)
    ret_no_terminal = sums["r_total"]
    ret_with_terminal = ret_no_terminal + terminal_bonus

    out = {
        "seed": seed,
        "steps": t,
        "final_event": final_event,
        "min_dist": min_dist,
        "min_dstar": min_dstar,
        "best_phi": best_phi,
        "sum_r_phi": sums["r_phi"],
        "sum_r_ca": sums["r_ca"],
        "sum_r_close": sums["r_close"],
        "return_no_terminal": ret_no_terminal,
        "terminal_bonus": terminal_bonus,
        "return_total": ret_with_terminal,
        "first_rows": first_rows,
    }
    return out

# -------------------------
# Multi-seed runner
# -------------------------
def run_many(env_ctor, policy_kind="pronav", seeds=(0,1,2,3,4),
             cfg=None, terminal_cfg=None, show_first_k=12):
    def make_policy(env):
        if policy_kind == "pronav":
            return lambda obs: env.calculate_pronav()
        elif policy_kind == "random":
            rng = np.random.RandomState(123)
            return lambda obs: rng.uniform(-1.0, 1.0, size=(2,)).astype(np.float32)
        else:
            raise ValueError("policy_kind must be 'pronav' or 'random'")

    results = []
    for s in seeds:
        env = env_ctor()
        pi = make_policy(env)
        r = rollout_text(env, pi, seed=s, cfg=cfg, terminal_cfg=terminal_cfg, show_first_k=show_first_k)
        results.append(r)

        # print per-episode summary
        print("\n==============================")
        print(f"seed={r['seed']} | event={r['final_event']} | steps={r['steps']}")
        print(f"min_dist={r['min_dist']:.3f} | min_dstar={r['min_dstar']:.6f} | best_phi={r['best_phi']:.6f}")
        print("---- return decomposition ----")
        print(f"sum(r_phi)   = {r['sum_r_phi']:+.6f}")
        print(f"sum(r_ca)    = {r['sum_r_ca']:+.6f}")
        print(f"sum(r_close) = {r['sum_r_close']:+.6f}")
        print(f"return(no terminal) = {r['return_no_terminal']:+.6f}")
        print(f"terminal_bonus      = {r['terminal_bonus']:+.6f}")
        print(f"return(total)       = {r['return_total']:+.6f}")

        # early-step table
        print("\nfirst steps (debug):")
        print(" t |    phi      dstar      closing |   r_phi    r_ca    r_cl   r_step | event")
        print("-------------------------------------------------------------------------------")
        for row in r["first_rows"]:
            print(f"{row['t']:2d} | {row['phi']:+.6f} {row['dstar']:9.1f} {row['closing']:9.1f} |"
                  f" {row['r_phi']:.4f} {row['r_ca']:+.6f} {row['r_cl']:+.6f} {row['r_step']:+.6f} | {row['event']}")

    return results

# -------------------------
# USAGE
# -------------------------
# IMPORTANT: replace this with your actual env class name in scope
# from your_module import missile_interception_3d

CFG = dict(w_phi=1.0, w_ca=0.1, w_close=0.002, v_scale=1500.0)
TERM = dict(hit_bonus=2.0, def_ground_pen=2.0, timeout_pen=0.0, diverged_pen=1.0)

# run on pronav for a few seeds
results = run_many(
    env_ctor=missile_interception_3d,
    policy_kind="pronav",
    seeds=(0,1,2,3,4),
    cfg=CFG,
    terminal_cfg=TERM,
    show_first_k=12
)


seed=0 | event=hit | steps=385
min_dist=122.826 | min_dstar=0.002116 | best_phi=-0.000000
---- return decomposition ----
sum(r_phi)   = +350.232038
sum(r_ca)    = +0.678169
sum(r_close) = +0.757890
return(no terminal) = +351.668098
terminal_bonus      = +2.000000
return(total)       = +353.668098

first steps (debug):
 t |    phi      dstar      closing |   r_phi    r_ca    r_cl   r_step | event
-------------------------------------------------------------------------------
 0 | -6.678046   66780.5    2792.9 | 0.3592 +0.010364 +0.001906 +0.371485 | running
 1 | -6.575010   65750.1    2813.2 | 0.3664 +0.010304 +0.001908 +0.378615 | running
 2 | -6.472664   64726.6    2833.2 | 0.3737 +0.010235 +0.001911 +0.385863 | running
 3 | -6.371083   63710.8    2853.0 | 0.3812 +0.010158 +0.001913 +0.393223 | running
 4 | -6.270335   62703.4    2872.5 | 0.3887 +0.010075 +0.001915 +0.400689 | running
 5 | -6.170490   61704.9    2891.7 | 0.3964 +0.009984 +0.001917 +0.408255 | running
 6 | -6.071610  

In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random

# -------------------------
# Seeding utilities
# -------------------------
def set_global_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def reset_env_seeded(env, seed: int):
    obs, info = env.reset(seed=seed)
    try:
        env.action_space.seed(seed)
    except Exception:
        pass
    try:
        env.observation_space.seed(seed)
    except Exception:
        pass
    return obs, info

# -------------------------
# Policy + Value network
# -------------------------
class ActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden=128):
        super().__init__()
        self.pi = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, act_dim),
        )
        self.v = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 1),
        )
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def dist(self, obs):
        mu = self.pi(obs)
        std = torch.exp(self.log_std).clamp(1e-4, 10.0)
        return torch.distributions.Normal(mu, std)

    def act(self, obs):
        dist = self.dist(obs)
        a = dist.sample()
        logp = dist.log_prob(a).sum(-1)
        v = self.v(obs).squeeze(-1)
        return a, logp, v

    def eval_actions(self, obs, act):
        dist = self.dist(obs)
        logp = dist.log_prob(act).sum(-1)
        ent = dist.entropy().sum(-1)
        v = self.v(obs).squeeze(-1)
        return logp, ent, v

# -------------------------
# GAE / returns
# -------------------------
def compute_gae(rews, vals, dones, gamma=0.999, lam=0.95):
    T = len(rews)
    adv = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        delta = rews[t] + gamma * vals[t+1] * nonterminal - vals[t]
        gae = delta + gamma * lam * nonterminal * gae
        adv[t] = gae
    ret = adv + vals[:-1]
    return adv, ret

def env_action_from_policy_action(a_np):
    a = np.clip(a_np, -1.0, 1.0).astype(np.float32)
    return a

# -------------------------
# Rollout collection (with proj/mag instrumentation)
# -------------------------
@torch.no_grad()
def collect_rollout(env, ac, T=1024, device="cpu", reset_seed=None):
    if reset_seed is None:
        obs, _ = env.reset()
    else:
        obs, _ = reset_env_seeded(env, reset_seed)
    obs_buf = []
    act_buf = []
    logp_buf = []
    rew_buf = []
    done_buf = []
    val_buf = []
    proj_buf = []
    mag_in_buf = []
    mag_exec_buf = []

    r_phi_buf = []
    r_ca_buf = []
    r_close_buf = []

    for _ in range(T):
        obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        a_t, logp_t, v_t = ac.act(obs_t)
        a_np = a_t.squeeze(0).cpu().numpy()
        a_env = env_action_from_policy_action(a_np)

        obs2, rew, terminated, truncated, info = env.step(a_env)
        done = float(terminated or truncated)

        rt = info.get("reward_terms", {}) if isinstance(info, dict) else {}
        r_phi = float(rt.get("r_phi", 0.0)) if isinstance(rt, dict) else 0.0
        r_ca = float(rt.get("r_ca", 0.0)) if isinstance(rt, dict) else 0.0
        r_close = float(rt.get("r_close", 0.0)) if isinstance(rt, dict) else 0.0

        obs_buf.append(obs)
        act_buf.append(a_np)
        logp_buf.append(logp_t.item())
        rew_buf.append(float(rew))
        done_buf.append(done)
        val_buf.append(v_t.item())
        proj_buf.append(info.get("proj_fired", 0.0))
        mag_in_buf.append(info.get("mag_in", np.nan))
        mag_exec_buf.append(info.get("mag_exec", np.nan))
        r_phi_buf.append(r_phi)
        r_ca_buf.append(r_ca)
        r_close_buf.append(r_close)

        obs = obs2
        if done:
            if reset_seed is None:
                obs, _ = env.reset()
            else:
                reset_seed = reset_seed + 1
                obs, _ = reset_env_seeded(env, reset_seed)

    obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    v_last = ac.v(obs_t).item()
    val_buf.append(v_last)

    return {
        "obs": np.asarray(obs_buf, dtype=np.float32),
        "act": np.asarray(act_buf, dtype=np.float32),
        "logp": np.asarray(logp_buf, dtype=np.float32),
        "rew": np.asarray(rew_buf, dtype=np.float32),
        "done": np.asarray(done_buf, dtype=np.float32),
        "val": np.asarray(val_buf, dtype=np.float32),
        "proj": np.asarray(proj_buf, dtype=np.float32),
        "mag_in": np.asarray(mag_in_buf, dtype=np.float32),
        "mag_exec": np.asarray(mag_exec_buf, dtype=np.float32),
        "r_phi": np.asarray(r_phi_buf, dtype=np.float32),
        "r_ca": np.asarray(r_ca_buf, dtype=np.float32),
        "r_close": np.asarray(r_close_buf, dtype=np.float32),
    }

# -------------------------
# PPO update
# -------------------------
def ppo_update(ac, opt, data, *, gamma=0.999, lam=0.95, clip_eps=0.2, vf_coef=0.5, ent_coef=0.0, epochs=10, mb=256, device="cpu"):
    adv, ret = compute_gae(data["rew"], data["val"], data["done"], gamma=gamma, lam=lam)
    adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    obs = torch.tensor(data["obs"], dtype=torch.float32, device=device)
    act = torch.tensor(data["act"], dtype=torch.float32, device=device)
    logp_old = torch.tensor(data["logp"], dtype=torch.float32, device=device)
    adv_t = torch.tensor(adv, dtype=torch.float32, device=device)
    ret_t = torch.tensor(ret, dtype=torch.float32, device=device)
    N = obs.shape[0]
    idx = np.arange(N)
    last = {}
    for _ in range(epochs):
        np.random.shuffle(idx)
        for start in range(0, N, mb):
            j = idx[start:start+mb]
            logp, ent, v = ac.eval_actions(obs[j], act[j])
            ratio = torch.exp(logp - logp_old[j])
            surr1 = ratio * adv_t[j]
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv_t[j]
            pi_loss = -(torch.min(surr1, surr2)).mean()
            v_loss = 0.5 * (ret_t[j] - v).pow(2).mean()
            ent_loss = -ent.mean()
            loss = pi_loss + vf_coef * v_loss + ent_coef * ent_loss
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(ac.parameters(), 1.0)
            opt.step()
            with torch.no_grad():
                approx_kl = (logp_old[j] - logp).mean().item()
                clipfrac = (torch.abs(ratio - 1.0) > clip_eps).float().mean().item()
                last = {"pi_loss": pi_loss.item(), "v_loss": v_loss.item(), "ent": ent.mean().item(), "approx_kl": approx_kl, "clipfrac": clipfrac, "adv_std": float(adv_t.std().item())}
    return last

# -------------------------
# Training loop (instrumented print)
# -------------------------
def _reward_decomp_stats(data):
    rew = data["rew"]
    r_phi = data.get("r_phi", np.zeros_like(rew))
    r_ca = data.get("r_ca", np.zeros_like(rew))
    r_close = data.get("r_close", np.zeros_like(rew))
    r_other = rew - r_phi - r_ca - r_close
    eps = 1e-12
    abs_total = float(np.mean(np.abs(rew)) + eps)
    return {
        "r_phi_mean": float(np.mean(r_phi)),
        "r_ca_mean": float(np.mean(r_ca)),
        "r_close_mean": float(np.mean(r_close)),
        "r_other_mean": float(np.mean(r_other)),
        "phi_share_abs": float(np.mean(np.abs(r_phi)) / abs_total),
        "ca_share_abs": float(np.mean(np.abs(r_ca)) / abs_total),
        "close_share_abs": float(np.mean(np.abs(r_close)) / abs_total),
        "other_share_abs": float(np.mean(np.abs(r_other)) / abs_total),
    }

def train_micro_ppo(env, total_updates=50, T=1024, lr=3e-4, device="cpu", seed=None):
    if seed is not None:
        set_global_seed(seed)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    ac = ActorCritic(obs_dim, act_dim).to(device)
    opt = optim.Adam(ac.parameters(), lr=lr)
    for u in range(total_updates):
        rollout_seed = None if seed is None else (seed * 100000 + u)
        data = collect_rollout(env, ac, T=T, device=device, reset_seed=rollout_seed)
        stats = ppo_update(ac, opt, data, device=device)
        proj_rate = float(data["proj"].mean())
        mag_in_mean = float(np.nanmean(data["mag_in"]))
        mag_exec_mean = float(np.nanmean(data["mag_exec"]))
        rd = _reward_decomp_stats(data)
        print(
            f"upd {u:03d} | rew={data['rew'].mean():+.4f} "
            f"| rphi={rd['r_phi_mean']:+.4f} | rca={rd['r_ca_mean']:+.4f} | rcl={rd['r_close_mean']:+.4f} | roth={rd['r_other_mean']:+.4f} "
            f"| abs(phi/ca/cl/oth)=({rd['phi_share_abs']:.2f}/{rd['ca_share_abs']:.2f}/{rd['close_share_abs']:.2f}/{rd['other_share_abs']:.2f}) "
            f"| proj={proj_rate:.3f} | mag={mag_exec_mean:.3f} "
            f"| kl={stats['approx_kl']:+.2e} | clip={stats['clipfrac']:.3f} | ent={stats['ent']:+.3f}"
        )
    return ac

# -------------------------
# Eval by physics (min_dist, hit_rate, events) — for A/B experiment
# -------------------------
@torch.no_grad()
def evaluate_policy_ac(env, ac, episodes=20, device="cpu", seed=None):
    if seed is not None:
        set_global_seed(seed)
    min_dists = []
    hits = 0
    events = {}
    for ep in range(episodes):
        if seed is None:
            obs, _ = env.reset()
        else:
            obs, _ = reset_env_seeded(env, seed + ep)
        done = False
        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            dist = ac.dist(obs_t)
            action = dist.mean
            a_np = np.clip(action.squeeze(0).cpu().numpy(), -1.0, 1.0).astype(np.float32)
            obs, rew, terminated, truncated, info = env.step(a_np)
            done = terminated or truncated
            if done:
                event = info.get("event", "unknown")
                events[event] = events.get(event, 0) + 1
                if event == "hit":
                    hits += 1
        min_dists.append(float(env.ep_min_dist))
    return {
        "hit_rate": hits / episodes,
        "min_dist_p50": float(np.percentile(min_dists, 50)),
        "min_dist_p10": float(np.percentile(min_dists, 10)),
        "events": events,
    }


def run_ab_experiment(env_class=missile_interception_3d, updates=30, T=2048, eval_episodes=30, device="cpu", seed=0):
    print(f"\n===== SEED {seed} =====")
    set_global_seed(seed)
    print("=== A: WITH PROJECTION ===")
    env_a = env_class()
    ac_a = train_micro_ppo(env_a, total_updates=updates, T=T, device=device, seed=seed)
    eval_a = evaluate_policy_ac(env_a, ac_a, episodes=eval_episodes, device=device, seed=seed + 10000)
    print("Eval A:", eval_a)
    print("\n=== B: WITHOUT PROJECTION ===")
    set_global_seed(seed)
    env_b = missile_interception_3d_no_proj()
    ac_b = train_micro_ppo(env_b, total_updates=updates, T=T, device=device, seed=seed)
    eval_b = evaluate_policy_ac(env_b, ac_b, episodes=eval_episodes, device=device, seed=seed + 10000)
    print("Eval B:", eval_b)
    return eval_a, eval_b, ac_a, ac_b


# Debug run: few updates, longer rollouts for stable GAE
env_ppo = missile_interception_3d()
# train_micro_ppo(env_ppo, total_updates=10, T=2048, lr=3e-4, device="cpu")

# -------------------------
# Diagnostic: episode turning-point detector
# -------------------------
@torch.no_grad()
def debug_episode_turning_point(env, ac, device="cpu", seed=None, patience=5, print_window=8):
    if seed is None:
        obs, _ = env.reset()
    else:
        obs, _ = reset_env_seeded(env, seed)

    traj = []
    done = False
    t = 0

    while not done:
        obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        action = ac.dist(obs_t).mean
        a_np = np.clip(action.squeeze(0).cpu().numpy(), -1.0, 1.0).astype(np.float32)

        obs, rew, terminated, truncated, info = env.step(a_np)
        done = terminated or truncated

        rt = info.get("reward_terms", {})
        traj.append({
            "t": t,
            "rew": float(rew),
            "dist": float(info.get("dist", np.nan)),
            "closing": float(info.get("closing", np.nan)),
            "event": info.get("event", None),
            "proj_fired": float(info.get("proj_fired", np.nan)),
            "mag_in": float(info.get("mag_in", np.nan)),
            "mag_exec": float(info.get("mag_exec", np.nan)),
            "r_total": float(rt.get("r_total", rew)) if isinstance(rt, dict) else float(rew),
            "r_phi": float(rt.get("r_phi", np.nan)) if isinstance(rt, dict) else np.nan,
            "r_ca": float(rt.get("r_ca", np.nan)) if isinstance(rt, dict) else np.nan,
            "r_close": float(rt.get("r_close", np.nan)) if isinstance(rt, dict) else np.nan,
            "alpha_deg": float(info.get("alpha_deg", np.nan)),
            "omega_mag": float(info.get("omega_mag", np.nan)),
            "phi_before": float(info.get("phi_before", np.nan)),
            "phi_after": float(info.get("phi_after", np.nan)),
            "dstar_before": float(info.get("dstar_before", np.nan)),
            "dstar_after": float(info.get("dstar_after", np.nan)),
            "tstar_before": float(info.get("tstar_before", np.nan)),
            "tstar_after": float(info.get("tstar_after", np.nan)),
            "defense_z": float(info.get("defense_z", np.nan)),
            "defense_vz": float(info.get("defense_vz", np.nan)),
            "defense_speed": float(info.get("defense_speed", np.nan)),
            "forward_z": float(info.get("forward_z", np.nan)),
            "act_dir_cos": float(info.get("act_dir_cos", np.nan)),
            "da0": float(info.get("da0", np.nan)),
            "da1": float(info.get("da1", np.nan)),
        })
        t += 1

    d = np.array([x["dist"] for x in traj], dtype=np.float32)
    if np.isnan(d).all():
        print("No info['dist'] found. Add info['dist'] in env.step() first.")
        print("Final event:", traj[-1]["event"] if len(traj) else "unknown")
        return traj

    best_so_far = np.minimum.accumulate(d)
    turning_idx = None
    for i in range(1, len(d) - patience):
        if d[i] <= best_so_far[i] + 1e-9:
            future = d[i+1:i+1+patience]
            if np.all(future > d[i]):
                turning_idx = i
                break

    # --- Episode-level summary metrics ---
    t_min = int(np.nanargmin(d))
    xm = traj[t_min]
    n_steps = len(traj)
    closings = np.array([x["closing"] for x in traj], dtype=np.float32)
    rews = np.array([x["rew"] for x in traj], dtype=np.float32)
    projs = np.array([x["proj_fired"] for x in traj], dtype=np.float32)
    mags = np.array([x["mag_exec"] for x in traj], dtype=np.float32)

    r_phi_arr = np.array([x["r_phi"] for x in traj], dtype=np.float32)
    r_ca_arr = np.array([x["r_ca"] for x in traj], dtype=np.float32)
    r_close_arr = np.array([x["r_close"] for x in traj], dtype=np.float32)
    abs_rew_mean = float(np.nanmean(np.abs(rews)) + 1e-12)
    abs_phi_share = float(np.nanmean(np.abs(r_phi_arr)) / abs_rew_mean)
    abs_ca_share = float(np.nanmean(np.abs(r_ca_arr)) / abs_rew_mean)
    abs_close_share = float(np.nanmean(np.abs(r_close_arr)) / abs_rew_mean)

    print(f"\n===== EPISODE DIAGNOSTIC (seed={seed}) =====")
    print(f"steps={n_steps} | final_event={traj[-1]['event']} | min_dist={np.nanmin(d):.1f} @ t={t_min}")
    print(f"  at min_dist: alpha={xm['alpha_deg']:.1f}° closing={xm['closing']:.1f} m/s dstar={xm['dstar_after']:.1f}")
    print(f"  reward_decomp: mean(r_phi)={np.nanmean(r_phi_arr):+.4f} | mean(r_ca)={np.nanmean(r_ca_arr):+.4f} | mean(r_close)={np.nanmean(r_close_arr):+.4f} | abs_share(phi/ca/cl)=({abs_phi_share:.2f}/{abs_ca_share:.2f}/{abs_close_share:.2f})")
    print(f"  fracs: closing>0 {np.nanmean(closings>0)*100:.0f}% | rew>0 {np.nanmean(rews>0)*100:.0f}% | proj {np.nanmean(projs)*100:.0f}%")
    print(f"  action: mean_mag={np.nanmean(mags):.3f} max_mag={np.nanmax(mags):.3f}")

    def _print_window(lo, hi, label, highlight_idx=None):
        print(f"\n--- {label} (t={lo}..{hi-1}) ---")
        print(f" {'t':>3s} | {'dist':>9s} | {'closing':>8s} | {'alpha':>6s} | {'dstar':>8s} | {'rew':>7s} | {'r_phi':>7s} | {'r_ca':>7s} | {'r_cl':>7s} | {'proj':>4s} | {'mag_ex':>6s}")
        print("-" * 105)
        for k in range(lo, hi):
            x = traj[k]
            mark = " <--" if k == highlight_idx else ""
            def _f(v, w, p): return f"{v:{w}.{p}f}" if np.isfinite(v) else " " * (w - 3) + "nan"
            print(
                f"{x['t']:3d} | "
                f"{_f(x['dist'],9,1)} | {_f(x['closing'],8,1)} | {_f(x['alpha_deg'],6,1)} | "
                f"{_f(x['dstar_after'],8,1)} | {_f(x['rew'],7,4)} | "
                f"{_f(x['r_phi'],7,4)} | {_f(x['r_ca'],7,4)} | {_f(x['r_close'],7,4)} | "
                f"{x['proj_fired']:4.0f} | {_f(x['mag_exec'],6,3)}{mark}"
            )

    if turning_idx is None:
        print("No clear sustained turning point detected.")
    else:
        print(f"TURNING POINT at t={turning_idx} | dist={traj[turning_idx]['dist']:.1f}")
        lo = max(0, turning_idx - print_window)
        hi = min(n_steps, turning_idx + print_window + 1)
        _print_window(lo, hi, "TURNING POINT", highlight_idx=turning_idx)

    # Terminal window (last steps)
    term_window = min(12, n_steps)
    _print_window(n_steps - term_window, n_steps, "TERMINAL")

    return traj


# A/B experiment:
eval_a, eval_b, ac_a, ac_b = run_ab_experiment(updates=15, T=2048, eval_episodes=20, seed=0)

# Diagnostic: inspect turning points for a few seeds
for s in [0, 1, 2, 3, 4]:
    print(f"\n### A (with projection) seed={s} ###")
    debug_episode_turning_point(missile_interception_3d(), ac_a, seed=s)
    print(f"\n### B (no projection) seed={s} ###")
    debug_episode_turning_point(missile_interception_3d_no_proj(), ac_b, seed=s)


===== SEED 0 =====
=== A: WITH PROJECTION ===
upd 000 | rew_mean=-0.0359 | rew_std=0.0465 | rca=-0.0278 | rcl=-0.4410 | roth=+0.4328 | abs(ca/cl/oth)=(0.76/14.80/14.29) | proj=0.613 | mag_in=0.963 | mag_exec=0.855 | kl=+2.35e-03 | clip=0.074 | ent=+2.826
upd 001 | rew_mean=-0.0456 | rew_std=0.0428 | rca=-0.0358 | rcl=-0.5531 | roth=+0.5434 | abs(ca/cl/oth)=(0.72/14.15/13.72) | proj=0.611 | mag_in=0.965 | mag_exec=0.858 | kl=+1.09e-02 | clip=0.055 | ent=+2.816
upd 002 | rew_mean=-0.0244 | rew_std=0.0537 | rca=-0.0251 | rcl=-0.2361 | roth=+0.2368 | abs(ca/cl/oth)=(0.93/14.83/14.25) | proj=0.591 | mag_in=0.946 | mag_exec=0.845 | kl=+1.10e-03 | clip=0.023 | ent=+2.806
upd 003 | rew_mean=-0.0370 | rew_std=0.0484 | rca=-0.0289 | rcl=-0.4515 | roth=+0.4433 | abs(ca/cl/oth)=(0.82/14.55/14.03) | proj=0.614 | mag_in=0.965 | mag_exec=0.854 | kl=+3.01e-04 | clip=0.027 | ent=+2.799
upd 004 | rew_mean=-0.0341 | rew_std=0.0513 | rca=-0.0277 | rcl=-0.4037 | roth=+0.3974 | abs(ca/cl/oth)=(0.79/14.40/1