In [11]:
import gymnasium as gym
import numpy as np

class missile_interception_3d(gym.Env):
    """3-D ballistic-missile interception environment (Gymnasium API).

    An "enemy" missile flies a pure ballistic (gravity-only) arc from a
    launch point toward a random aim point inside a square target box. A
    "defense" interceptor is launched from inside the same box and is steered
    by a 2-D lateral-acceleration command.

    Action:      Box(2,) in [-1, 1]. Commanded lateral acceleration along the
                 ``right`` and ``up`` axes of ``_compute_lateral_basis``, as a
                 fraction of ``a_max``. Commands with norm > 1 are rescaled
                 to unit norm.
    Observation: Box(16,) float32 ego-frame features (see ``_get_obs``).

    An episode ends on a hit, on either missile reaching z < 0, on divergence
    (separation > ``max_distance``) or on timeout (``t >= t_max``).
    """

    def __init__(self):
        # 1. Action space (the "joystick"): lateral accel along right/up, in [-1, 1].
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

        # 2. Observation space (16-D ego-frame features, no actuator state).
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32)

        # Legacy RandomState API (rand/uniform) is used for scenario sampling;
        # reset() re-creates it from the seed when one is given.
        self.np_random = np.random.RandomState()

        # 3. Time settings: one agent action spans n_substeps integration steps.
        self.dt_act = 0.1             # seconds per agent action
        self.n_substeps = 10
        self.dt_sim = self.dt_act / self.n_substeps  # integration step (s)
        self.t_max = 650.0  # episode time limit (s); TODO: check this is long enough

        # 4. Physical limits
        self.a_max = 350.0   # max lateral accel (m/s^2), ~35 g
        self.g = 9.81        # gravity (m/s^2)
        self.collision_radius = 150.0    # hit radius (m)
        self.max_distance = 4_000_000.0  # separation (m) beyond which the episode is truncated

        # Scenario mix. p_easy = 1.0 means every episode uses the "easy" range
        # cap, i.e. the long-range "hard" cases are currently disabled.
        self.p_easy = 1.0
        self.range_min = 70_000.0
        self.range_easy_max = 200_000.0
        self.range_hard_max = 1_000_000.0

        # Target box (m). Both the enemy aim point and the defense launch
        # point are sampled uniformly inside it.
        self.targetbox_x_min = -15000
        self.targetbox_x_max = 15000
        self.targetbox_y_min = -15000
        self.targetbox_y_max = 15000

        # --- Closest-approach shaping ---
        self.gamma_shape = 0.9999      # set = PPO gamma when training
        self.w_ca = 1.0
        self.dstar_scale = 50_000.0   # meters; tune so r_ca ~ 0.05-0.15 under ProNav

    def generate_enemy_missile(self):
        """Sample the enemy's ballistic launch.

        Picks the range cap (easy/hard), an aim point in the target box, a
        launch bearing, an elevation in [~30 deg, ~60 deg] and a launch speed.
        The speed is chosen so that the flat-ground range v^2*sin(2*theta)/g
        lies between ``range_min`` and ``range_max_used``. The launch point is
        placed that ground range away from the aim point along the bearing, at
        z = 0. The flight azimuth points back toward the aim point.
        """
        if self.np_random.rand() < self.p_easy:
            self.range_max_used = self.range_easy_max
        else:
            self.range_max_used = self.range_hard_max

        range_min = self.range_min
        self.attack_target_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.attack_target_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)
        self.enemy_launch_angle = self.np_random.uniform(0, 2 * np.pi)
        self.enemy_theta = self.np_random.uniform(0.523599, 1.0472)  # elevation, ~30..60 deg

        # Invert R = v^2 sin(2 theta) / g to get the speed interval for the range interval.
        self.range_max_used = max(self.range_max_used, range_min + 1.0)
        lower_limit = np.sqrt((range_min * self.g) / np.sin(2 * self.enemy_theta))
        upper_limit = np.sqrt((self.range_max_used * self.g) / np.sin(2 * self.enemy_theta))
        self.enemy_initial_velocity = self.np_random.uniform(lower_limit, upper_limit)

        # Horizontal speed * time of flight = flat-ground range.
        ground_range = (
            self.enemy_initial_velocity * np.cos(self.enemy_theta)
            * (2 * self.enemy_initial_velocity * np.sin(self.enemy_theta) / self.g)
        )

        self.enemy_launch_x = self.attack_target_x + ground_range * np.cos(self.enemy_launch_angle)
        self.enemy_launch_y = self.attack_target_y + ground_range * np.sin(self.enemy_launch_angle)
        self.enemy_z = 0
        self.enemy_x = self.enemy_launch_x
        self.enemy_y = self.enemy_launch_y
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        # Fly back along the bearing, i.e. toward the aim point.
        self.enemy_azimuth = (self.enemy_launch_angle + np.pi) % (2 * np.pi)

    def generate_defense_missile(self):
        """Sample the interceptor's launch point, heading and speed.

        - Launch point: uniform in the target box, at z = 0.
        - Azimuth: the bearing to the enemy launch point plus noise. The noise
          is uniform in +-60 deg with probability 0.35, otherwise +-10 deg.
        - Elevation: 45 deg +- 10 deg, clipped to [10, 80] deg.
        - Speed: 3000 m/s, scaled by range_max_used / range_easy_max (capped
          at 1.5) when the enemy has already been generated.
        """
        self.defense_launch_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.defense_launch_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)

        dx = self.enemy_launch_x - self.defense_launch_x
        dy = self.enemy_launch_y - self.defense_launch_y
        az_nominal = np.arctan2(dy, dx)

        # --- Misalignment (domain randomization of initial heading) ---
        # Mixture: most episodes small error, some episodes large az error
        p_misaligned = 0.35  # 35% "hard" starts
        if self.np_random.rand() < p_misaligned:
            # Hard: large azimuth error -> strong lateral correction needed
            az_noise = self.np_random.uniform(-np.deg2rad(60.0), np.deg2rad(60.0))
        else:
            # Easy: small azimuth error -> gentle correction
            az_noise = self.np_random.uniform(-np.deg2rad(10.0), np.deg2rad(10.0))

        self.defense_azimuth = az_nominal + az_noise

        # Elevation noise: avoid always same vertical plane
        theta_nominal = 0.785398  # ~45 deg
        theta_noise_deg = 10.0
        theta_noise = self.np_random.uniform(
            -np.deg2rad(theta_noise_deg),
            +np.deg2rad(theta_noise_deg),
        )
        self.defense_theta = float(np.clip(theta_nominal + theta_noise,
                                           np.deg2rad(10.0),
                                           np.deg2rad(80.0)))
        # ---------------------------------------------------------------

        base_velocity = 3000.0
        if hasattr(self, 'range_max_used'):
            velocity_scale = min(self.range_max_used / self.range_easy_max, 1.5)
            self.defense_initial_velocity = base_velocity * velocity_scale
        else:
            self.defense_initial_velocity = base_velocity

        self.defense_x = self.defense_launch_x
        self.defense_y = self.defense_launch_y
        self.defense_z = 0.0
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)

    def _smoothstep(self, x: float) -> float:
        """Smooth ramp 0->1 with zero slope at ends, clamps outside [0,1]"""
        x = float(np.clip(x, 0.0, 1.0))
        return x * x * (3.0 - 2.0 * x)

    def calculate_pronav(self):
        """Baseline guidance command for the current state.

        Blends true proportional navigation (N = 3) with a "turn toward the
        LOS" acquisition term. The blend weight grows with the heading error
        alpha and with a small LOS rate, and is forced to >= 0.9 when not
        closing. The result is projected onto the same right/up basis that
        step() uses.

        Returns:
            float32 array of shape (2,), clipped to [-1, 1], in units of a_max.
            Requires reset() to have been called (reads positions/velocities).
        """
        eps = 1e-9

        # Relative geometry (use float64 for stability)
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        v = self.defense_vel.astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)

        R = float(np.linalg.norm(r)) + eps
        V = float(np.linalg.norm(v)) + eps

        rhat = r / R
        vhat = v / V

        # Heading error alpha = angle between velocity direction and LOS direction
        cosang = float(np.clip(np.dot(vhat, rhat), -1.0, 1.0))
        alpha = float(np.arccos(cosang))  # radians

        # LOS angular rate omega (world frame)
        omega = np.cross(r, vrel) / (float(np.dot(r, r)) + eps)
        omega_mag = float(np.linalg.norm(omega))

        # Closing speed (positive => closing)
        vc = -float(np.dot(r, vrel)) / R

        # --- PN term ---
        N = 3.0
        a_pn = N * vc * np.cross(omega, rhat)  # lateral accel in world frame

        # --- Acquisition term (turn-to-LOS) ---
        # Perpendicular component of LOS relative to forward direction
        rhat_perp = rhat - float(np.dot(rhat, vhat)) * vhat
        nperp = float(np.linalg.norm(rhat_perp))

        if nperp < 1e-8:
            a_acq = np.zeros(3, dtype=np.float64)
        else:
            rhat_perp /= nperp  # unit sideways "turn toward LOS" direction

            # Curvature-based magnitude: ~k * V^2 / R, saturate later via a_max
            k_acq = 5.0  # try 3.0–8.0
            a_acq = k_acq * (V * V / R) * rhat_perp

        # --- Blend weight w: 0 => pure PN, 1 => pure acquisition ---

        # Alpha-based weight (dominant)
        alpha_on   = np.deg2rad(20.0)   # start blending earlier
        alpha_full = np.deg2rad(55.0)

        x_alpha = (alpha - alpha_on) / (alpha_full - alpha_on + eps)
        w_alpha = self._smoothstep(x_alpha)

        # Omega-based modifier (only boosts acquisition when PN is sleepy)
        omega_full = 0.00
        omega_on   = 0.05   # <-- key: less brittle than 0.02

        x_omega = (omega_on - omega_mag) / (omega_on - omega_full + eps)
        w_omega = self._smoothstep(x_omega)

        # Robust combine: alpha dominates; omega can't fully shut it off
        w = w_alpha * (0.25 + 0.75 * w_omega)

        # If not closing, force strong acquisition
        if vc <= 0.0:
            w = max(w, 0.9)

        a_ideal = (1.0 - w) * a_pn + w * a_acq

        # Project into the lateral control basis (right/up) and normalize by a_max.
        # NOTE(review): step() adds gravity on top of the commanded lateral
        # accel and does not compensate for it; this command ignores gravity too.
        _, right, up = self._compute_lateral_basis(self.defense_vel)
        a_right = float(np.dot(a_ideal, right))
        a_up    = float(np.dot(a_ideal, up))

        action = np.array([a_right / self.a_max, a_up / self.a_max], dtype=np.float32)
        return np.clip(action, -1.0, 1.0)

    def _segment_min_dist_sq(self, r0, r1):
        """Squared minimum of ||r0 + s*(r1 - r0)|| over s in [0, 1].

        r0 and r1 are the relative positions (enemy - defense) at the start and
        end of one integration substep. Linear interpolation between them
        approximates the relative path within the substep.
        """
        dr = r1 - r0
        dr_norm_sq = float(np.dot(dr, dr))
        if dr_norm_sq < 1e-12:
            return float(np.dot(r0, r0))
        s_star = -float(np.dot(r0, dr)) / dr_norm_sq
        s_star = max(0.0, min(1.0, s_star))
        r_closest = r0 + s_star * dr
        return float(np.dot(r_closest, r_closest))

    def _segment_sphere_intersect(self, r0, r1, r_hit):
        """True if the relative segment r0 -> r1 passes within r_hit of the origin."""
        return self._segment_min_dist_sq(r0, r1) <= r_hit * r_hit

    def _phi_closest_approach(self):
        """
        Potential based on closest-approach miss distance under
        constant relative velocity assumption.
        Returns:
            phi, d_star, t_star
            (phi = -d_star / dstar_scale; t_star is clamped to >= 0)
        """
        eps = 1e-9

        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)

        vrel_norm_sq = float(np.dot(vrel, vrel)) + eps

        # t* = argmin ||r + vrel t|| for t >= 0
        t_star = -float(np.dot(r, vrel)) / vrel_norm_sq
        t_star = max(0.0, t_star)

        m_star = r + vrel * t_star
        d_star = float(np.linalg.norm(m_star))

        phi = -d_star / (self.dstar_scale + eps)
        return float(phi), d_star, float(t_star)

    def _get_obs(self):
        """Build the 16-D float32 observation.

        Layout (index: feature). The ego frame is (forward, right, up) from
        ``_compute_lateral_basis``.
          0-2   relative position in ego frame / range_easy_max
          3-5   relative velocity in ego frame / 4000
          6     distance / 1e6, clipped to [0, 4]
          7     closing speed / 3000, clipped to [-2, 2]
          8     defense z / 1e5, clipped to [-1, 2]
          9     defense vz / 3000, clipped to [-2, 2]
          10-11 LOS unit-vector components along right / up
          12-13 LOS-rate components about right / up, / 2.0, clipped to [-2, 2]
          14    defense speed / 3000, clipped to [0, 3]
          15    vertical component of the forward unit vector
        """
        eps = 1e-9

        # World-frame relative state
        r_world = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel_world = (self.enemy_vel - self.defense_vel).astype(np.float64)

        # Local basis from defense velocity (world frame unit vectors)
        forward, right, up = self._compute_lateral_basis(self.defense_vel)

        # ===============================
        # 1) Ego-frame (body-frame) r and vrel
        # ===============================
        r_body = np.array([
            float(np.dot(r_world, forward)),
            float(np.dot(r_world, right)),
            float(np.dot(r_world, up)),
        ], dtype=np.float64)

        vrel_body = np.array([
            float(np.dot(vrel_world, forward)),
            float(np.dot(vrel_world, right)),
            float(np.dot(vrel_world, up)),
        ], dtype=np.float64)

        # Normalize r_body / vrel_body (easy range only)
        pos_scale = float(self.range_easy_max)
        vel_scale = 4000.0

        r_body_n = (r_body / (pos_scale + eps)).astype(np.float32)
        vrel_body_n = (vrel_body / (vel_scale + eps)).astype(np.float32)

        # ===============================
        # 2) Scalar helpers
        # ===============================
        dist = float(np.linalg.norm(r_world)) + 1e-6
        v_close = -float(np.dot(r_world, vrel_world)) / dist  # positive when closing

        dist_n = np.float32(np.clip(dist / 1_000_000.0, 0.0, 4.0))
        vclose_n = np.float32(np.clip(v_close / 3000.0, -2.0, 2.0))
        dist_vclose_feat = np.array([dist_n, vclose_n], dtype=np.float32)

        # Defense own vertical state (keep for ground constraint)
        def_z_n = np.float32(np.clip(self.defense_pos[2] / 100_000.0, -1.0, 2.0))
        def_vz_n = np.float32(np.clip(self.defense_vel[2] / 3000.0, -2.0, 2.0))
        def_state_feat = np.array([def_z_n, def_vz_n], dtype=np.float32)

        # ===============================
        # 3) Geometry features (consistent with ego-frame)
        # ===============================
        dist_body = float(np.linalg.norm(r_body)) + 1e-6

        # LOS lateral projections in body frame
        los_right = float(r_body[1] / dist_body)
        los_up    = float(r_body[2] / dist_body)

        # LOS rate omega in body frame: omega = (r x vrel)/||r||^2
        dist2_body = float(np.dot(r_body, r_body)) + eps
        omega_body = np.cross(r_body, vrel_body) / dist2_body

        omega_right = float(omega_body[1])
        omega_up    = float(omega_body[2])

        omega_scale = 2.0
        omega_right_n = float(np.clip(omega_right / omega_scale, -2.0, 2.0))
        omega_up_n    = float(np.clip(omega_up / omega_scale, -2.0, 2.0))

        geom_feat = np.array([los_right, los_up, omega_right_n, omega_up_n], dtype=np.float32)

        # ===============================
        # 4) Kinematics garnish
        # ===============================
        V_def = float(np.linalg.norm(self.defense_vel))
        V_def_n = np.float32(np.clip(V_def / 3000.0, 0.0, 3.0))  # scale: 3000 m/s baseline
        forward_z = np.float32(float(forward[2]))               # dot(forward, world_up) since world_up=[0,0,1]

        kin_feat = np.array([V_def_n, forward_z], dtype=np.float32)

        # Final obs (16D, no actuator state)
        obs = np.concatenate(
            [r_body_n, vrel_body_n, dist_vclose_feat, def_state_feat, geom_feat, kin_feat],
            axis=0
        ).astype(np.float32)

        return obs

    def _compute_lateral_basis(self, velocity):
        """
        Horizon-stable basis:
          forward = along velocity (+x if speed < 1 m/s)
          right   = world_up x forward  (horizontal)
          up      = forward x right     (completes orthonormal frame)
        This keeps 'up' as close to world-up as possible and avoids weird twisting.

        NOTE(review): in a right-handed world frame with z up, world_up x forward
        points to the LEFT of the direction of travel (forward=+x gives
        right=+y). The name is kept because actions, observations and ProNav
        all use this same basis consistently.

        Returns:
            (forward, right, up) as float32 unit vectors.
        """
        speed = float(np.linalg.norm(velocity))
        if speed < 1.0:
            forward = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        else:
            forward = (velocity / speed).astype(np.float32)

        world_up = np.array([0.0, 0.0, 1.0], dtype=np.float32)

        # right = world_up x forward
        right_raw = np.cross(world_up, forward)
        rnorm = float(np.linalg.norm(right_raw))

        # If forward is near world_up, right_raw ~ 0. Pick a consistent fallback.
        if rnorm < 1e-6:
            # Choose a fixed "north" axis in world XY and build right from that
            # This prevents random spinning when vertical.
            north = np.array([1.0, 0.0, 0.0], dtype=np.float32)
            right_raw = np.cross(north, forward)
            rnorm = float(np.linalg.norm(right_raw))
            if rnorm < 1e-6:
                north = np.array([0.0, 1.0, 0.0], dtype=np.float32)
                right_raw = np.cross(north, forward)
                rnorm = float(np.linalg.norm(right_raw))

        right = (right_raw / (rnorm + 1e-9)).astype(np.float32)

        # up = forward x right (not right x forward)
        up_raw = np.cross(forward, right)
        up = (up_raw / (float(np.linalg.norm(up_raw)) + 1e-9)).astype(np.float32)

        return forward, right, up

    def step(self, action):
        """Advance the simulation by one agent action (``n_substeps`` Euler steps).

        Args:
            action: array-like of 2 values. Each component is clipped to
                [-1, 1], and the vector is rescaled to unit norm if its norm
                still exceeds 1.

        Returns:
            (obs, reward, terminated, truncated, info) per the Gymnasium API.
            ``terminated`` is set on a hit or on either missile reaching z < 0.
            ``truncated`` is set on divergence or timeout. On the terminal step
            ``info`` also carries ``min_dist``, ``max_action_mag`` and
            ``time_to_hit``. Calling step() after the episode ended returns
            the current obs with event "called_step_after_done".
        """
        if getattr(self, "done", False):
            return self._get_obs(), 0.0, True, False, {"event": "called_step_after_done", "dist": float(np.linalg.norm(self.enemy_pos - self.defense_pos))}

        action = np.clip(action, -1.0, 1.0).astype(np.float32)
        mag = float(np.linalg.norm(action))
        if mag > 1.0:
            action = action / mag
            mag = 1.0

        # Update episode trackers
        self.ep_max_action_mag = max(self.ep_max_action_mag, float(mag))

        # --- Shaping: closest-approach potential BEFORE transition ---
        phi_before, _, _ = self._phi_closest_approach()
        terminated = False
        truncated = False
        event = "running"

        # Loop invariants
        dt = self.dt_sim
        g_vec = np.array([0.0, 0.0, -self.g], dtype=np.float32)
        r_hit_sq = self.collision_radius * self.collision_radius

        for _ in range(self.n_substeps):
            enemy_pos_old = self.enemy_pos.copy()
            defense_pos_old = self.defense_pos.copy()

            # Defense: direct commanded lateral accel (no lag/jerk) plus gravity.
            _, right, up = self._compute_lateral_basis(self.defense_vel)
            a_lat = (action[0] * self.a_max) * right + (action[1] * self.a_max) * up
            self.defense_vel += (a_lat + g_vec) * dt
            self.defense_pos += self.defense_vel * dt
            self.defense_x, self.defense_y, self.defense_z = self.defense_pos

            # Enemy missile: pure ballistic (gravity only)
            self.enemy_vel += g_vec * dt
            self.enemy_pos += self.enemy_vel * dt
            self.enemy_x, self.enemy_y, self.enemy_z = self.enemy_pos
            self.t += dt

            # Closest approach within this substep. It drives both hit
            # detection and the min-distance trackers. Sampling only at action
            # boundaries would overestimate the true miss distance by up to
            # ~(relative speed * dt_act / 2), i.e. hundreds of metres.
            r0 = enemy_pos_old - defense_pos_old
            r1 = self.enemy_pos - self.defense_pos
            seg_dist_sq = self._segment_min_dist_sq(r0, r1)
            seg_dist = float(np.sqrt(seg_dist_sq))
            self.ep_min_dist = min(self.ep_min_dist, seg_dist)
            self.min_dist = min(getattr(self, "min_dist", float("inf")), seg_dist)

            if seg_dist_sq <= r_hit_sq:
                self.success = True
                terminated = True
                self.done = True
                event = "hit"
                self.time_to_hit = float(self.t)
                self.terminal_event = "hit"
                break

            dist = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
            if dist > self.max_distance:
                truncated = True
                self.done = True
                event = "diverged"
                self.terminal_event = "diverged"
                break
            if self.defense_pos[2] < 0:
                terminated = True
                self.done = True
                event = "defense_ground"
                self.terminal_event = "defense_ground"
                break
            if self.enemy_pos[2] < 0:
                terminated = True
                self.done = True
                event = "enemy_ground"
                self.terminal_event = "enemy_ground"
                break
            if self.t >= self.t_max:
                truncated = True
                self.done = True
                event = "timeout"
                self.terminal_event = "timeout"
                break

        obs = self._get_obs()
        dist_after = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        self.min_dist = min(getattr(self, "min_dist", float("inf")), dist_after)
        self.ep_min_dist = min(self.ep_min_dist, float(dist_after))

        # Reward: closest-approach potential shaping + closing-speed term only
        v_scale = 1500.0
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)
        d = float(np.linalg.norm(r)) + 1e-9
        rhat = r / d
        r_close = float(np.tanh(-float(np.dot(rhat, vrel)) / v_scale))
        phi_after, _, _ = self._phi_closest_approach()  # always (no free jackpot on termination)
        r_ca = self.w_ca * (self.gamma_shape * phi_after - phi_before)
        reward = float(0.5 * r_ca + 0.5 * 0.1 * r_close)
        if terminated or truncated:
            self.terminal_event = event
        info = {"event": event, "t": float(self.t), "dist": float(dist_after), "r_ca": float(r_ca), "r_close": float(r_close)}
        if terminated or truncated:
            info["min_dist"] = float(self.ep_min_dist)
            info["max_action_mag"] = float(self.ep_max_action_mag)
            info["time_to_hit"] = float(self.time_to_hit) if self.time_to_hit is not None else np.nan
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Start a new episode.

        If ``seed`` is given, Gymnasium's seeding runs and ``self.np_random``
        is replaced by ``RandomState(seed)``, so scenario sampling is
        reproducible. The enemy is sampled before the defense (the defense
        speed depends on the enemy's range cap). Initial velocity vectors are
        built from (speed, azimuth, elevation) and the episode trackers are
        cleared. ``options`` is ignored.

        Returns:
            (obs, {})
        """
        super().reset(seed=seed)
        if seed is not None:
            self.np_random = np.random.RandomState(seed)
        self.done = False
        self.success = False
        self.t = 0.0
        self.generate_enemy_missile()
        self.generate_defense_missile()

        self.defense_vel = np.array([
            self.defense_initial_velocity * np.cos(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_theta)
        ], dtype=np.float32)

        self.enemy_vel = np.array([
            self.enemy_initial_velocity * np.cos(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_theta)
        ], dtype=np.float32)

        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        self.min_dist = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        self.ep_min_dist = float("inf")
        self.ep_max_action_mag = 0.0
        self.time_to_hit = None
        self.terminal_event = "running"

        return self._get_obs(), {}


# ==========================================
# EVALUATION FUNCTION (separate from training reward)
# ==========================================
def evaluate_policy(env, policy_fn, n_episodes=100, seed0=0):
    """
    Evaluate a policy and return episode-level metrics.

    These metrics are separate from the training reward; they are what you
    actually care about (hit rate, miss distance, ...).

    Args:
        env: missile_interception_3d environment instance, or anything exposing
            reset/step plus ``ep_min_dist``, ``ep_max_action_mag`` and
            ``time_to_hit`` attributes.
        policy_fn: callable taking an observation and returning an action.
            For the ProNav baseline use ``lambda obs: env.calculate_pronav()``.
        n_episodes: number of episodes to evaluate; must be >= 1.
        seed0: starting seed (episodes use seed0, seed0+1, ..., seed0+n_episodes-1).

    Returns:
        summary: dict of aggregated metrics. Contains hit_rate, min_dist
            mean/median/p10, mean time to hit (None if there were no hits),
            ``max_g_mean`` and counts of the non-hit terminal events.
            ``max_g_mean`` is the mean over episodes of ``env.ep_max_action_mag``;
            for missile_interception_3d that is the largest action norm, as a
            fraction of a_max.
        metrics: dict with the raw per-episode data.

    Raises:
        ValueError: if n_episodes < 1 (the statistics would be undefined).
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    metrics = {
        "hits": 0,
        "ground_defense": 0,
        "ground_enemy": 0,
        "diverged": 0,
        "timeout": 0,
        "min_dist_list": [],
        "time_to_hit_list": [],
        "max_g_list": [],
    }
    # Terminal event name -> counter key in `metrics`.
    event_counters = {
        "hit": "hits",
        "defense_ground": "ground_defense",
        "enemy_ground": "ground_enemy",
        "diverged": "diverged",
        "timeout": "timeout",
    }

    for i in range(n_episodes):
        obs, _ = env.reset(seed=seed0 + i)
        done = False

        while not done:
            action = policy_fn(obs)
            obs, _reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

        event = info["event"]
        metrics["min_dist_list"].append(env.ep_min_dist)
        metrics["max_g_list"].append(env.ep_max_action_mag)

        counter = event_counters.get(event)
        if counter is not None:
            metrics[counter] += 1
        if event == "hit" and env.time_to_hit is not None:
            metrics["time_to_hit_list"].append(env.time_to_hit)

    hit_rate = metrics["hits"] / n_episodes

    summary = {
        "hit_rate": hit_rate,
        "min_dist_mean": float(np.mean(metrics["min_dist_list"])),
        "min_dist_p50": float(np.median(metrics["min_dist_list"])),
        "min_dist_p10": float(np.percentile(metrics["min_dist_list"], 10)),
        "time_to_hit_mean": float(np.mean(metrics["time_to_hit_list"])) if metrics["time_to_hit_list"] else None,
        "max_g_mean": float(np.mean(metrics["max_g_list"])),
        "violations": {
            "defense_ground": metrics["ground_defense"],
            "enemy_ground": metrics["ground_enemy"],
            "diverged": metrics["diverged"],
            "timeout": metrics["timeout"],
        }
    }
    return summary, metrics




# def run_baseline():
#     env = missile_interception_3d()
#     outcomes = []
#     min_distances = []
#     action_loads = [] # Track if we are saturating (maxing out fins)

#     N_EPISODES = 50
#     print(f"Running {N_EPISODES} episodes of Augmented ProNav...")

#     for i in range(N_EPISODES):
#         obs, _ = env.reset(seed=i)
#         done = False
#         ep_actions = []

#         while not done:
#             # 1. Ask ProNav for the move
#             action = env.calculate_pronav()
            
#             # 2. Track how hard it's pushing (0.0 to 1.0)
#             mag = np.linalg.norm(action)
#             ep_actions.append(mag)

#             obs, reward, terminated, truncated, info = env.step(action)
#             done = terminated or truncated
        
#         outcomes.append(info['event'])
#         min_distances.append(env.ep_min_dist)
#         avg_load = np.mean(ep_actions)
#         action_loads.append(avg_load)

#         print(f"Ep {i+1:02d} | Res: {info['event']:<14} | Min Dist: {env.ep_min_dist:.1f} m | Avg G-Load: {avg_load*100:.1f}%")

#     # Final Stats
#     hits = outcomes.count("hit")
#     non_hit = [d for d, e in zip(min_distances, outcomes) if e != "hit"]
#     print("\n--- SUMMARY ---")
#     print(f"Hit Rate: {hits}/{N_EPISODES} ({hits/N_EPISODES*100:.1f}%)")
#     print("Average Miss Distance (Non-hits):", f"{np.mean(non_hit):.2f} m" if non_hit else "N/A (all hits)")
#     print(f"Average G-Loading: {np.mean(action_loads)*100:.1f}% (If >90%, missile is physically too weak)")

# if __name__ == "__main__":
#     run_baseline()

In [12]:
import numpy as np
from collections import Counter

def summarize(x, name):
    """Print one line of summary statistics for the finite values in ``x``.

    NaN/inf entries are dropped first. The label ``name`` is left-padded to
    18 characters. If nothing finite remains, prints ``(empty)`` instead.
    Returns None.
    """
    values = np.asarray(x, dtype=float)
    finite = values[np.isfinite(values)]
    if not finite.size:
        print(f"{name:18s} (empty)")
        return

    stats = (
        ("mean", finite.mean()),
        ("p50", np.median(finite)),
        ("p10", np.percentile(finite, 10)),
        ("p90", np.percentile(finite, 90)),
        ("min", finite.min()),
        ("max", finite.max()),
    )
    body = "  ".join(f"{label}={value: .6f}" for label, value in stats)
    print(f"{name:18s} {body}")

def probe_reward_scales(env, policy="random", n_episodes=20, seed0=0, v_scale=1500.0, max_steps=20000):
    """
    Roll out episodes and log per-STEP scales of the reward and its parts.

    Per step it records:
      - reward (returned)
      - r_ca (info + recomputed), r_close (info + recomputed)
      - phi_before / phi_after (closest-approach potential)
      - dstar_before, tstar_before, dist_before / dist_after
    It then prints summary statistics, info-vs-recomputed consistency
    checks, and per-episode return components grouped by terminal event.

    Args:
        env: environment exposing reset/step and the internals read here
            (enemy_pos, defense_pos, enemy_vel, defense_vel, w_ca,
            gamma_shape, _phi_closest_approach, calculate_pronav,
            action_space).
        policy: "random" (env.action_space.sample()) or "pronav"
            (env.calculate_pronav()).
        n_episodes: number of episodes; seeds are seed0, seed0+1, ...
        v_scale: closing-speed scale used to recompute r_close. It should
            match the value hard-coded in the env's step().
        max_steps: safety cap on steps per episode. An episode that hits the
            cap is counted once, under the event "max_steps_break".

    Raises:
        ValueError: if ``policy`` is not "random" or "pronav" (raised on the
            first step of the first episode).
    """
    from collections import defaultdict

    events = Counter()

    S = {
        "reward": [],
        "action_mag": [],

        "r_ca_info": [],
        "r_ca_calc": [],
        "r_close_info": [],
        "r_close_calc": [],

        "phi_before": [],
        "phi_after": [],

        "dstar_before": [],
        "tstar_before": [],
        "dist_before": [],
        "dist_after": [],

        "abs_d_rca": [],
        "abs_d_rclose": [],
    }

    ep_sums = []  # (event, sum_r_ca, sum_01_r_close) per episode

    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed0 + ep)
        done = False
        steps = 0
        sum_r_ca_ep = 0.0
        sum_01_r_close_ep = 0.0
        # Initialised so a cap that fires before any step cannot leave it unbound.
        info = {}
        step_cap_hit = False

        while not done:
            steps += 1
            if steps > max_steps:
                # Safety breaker in case something goes infinite. The episode
                # is recorded under this event only (not also its last event).
                step_cap_hit = True
                break

            if policy == "random":
                action = env.action_space.sample()
            elif policy == "pronav":
                action = env.calculate_pronav()
            else:
                raise ValueError("policy must be 'random' or 'pronav'")

            # ----- PRE-STEP GEOMETRY (closest-approach) -----
            dist_b = float(np.linalg.norm(env.enemy_pos - env.defense_pos))
            phi_b, dstar_b, tstar_b = env._phi_closest_approach()

            # ----- STEP -----
            obs, reward, terminated, truncated, info = env.step(action)
            done = bool(terminated or truncated)

            # ----- POST-STEP GEOMETRY (always compute, same as env) -----
            dist_a = float(np.linalg.norm(env.enemy_pos - env.defense_pos))
            phi_a, _, _ = env._phi_closest_approach()

            # recompute shaping from geometry
            r_ca_calc = float(env.w_ca * (env.gamma_shape * phi_a - phi_b))

            # recompute closing reward from post-step state (should match the env)
            r = (env.enemy_pos - env.defense_pos).astype(np.float64)
            vrel = (env.enemy_vel - env.defense_vel).astype(np.float64)
            d = float(np.linalg.norm(r)) + 1e-9
            rhat = r / d
            r_close_calc = float(np.tanh(-float(np.dot(rhat, vrel)) / v_scale))

            # pull info values if present
            r_ca_info = float(info.get("r_ca", np.nan))
            r_close_info = float(info.get("r_close", np.nan))

            # log
            S["reward"].append(float(reward))
            S["action_mag"].append(float(np.linalg.norm(action)))

            S["r_ca_info"].append(r_ca_info)
            S["r_ca_calc"].append(r_ca_calc)
            S["r_close_info"].append(r_close_info)
            S["r_close_calc"].append(r_close_calc)

            S["phi_before"].append(float(phi_b))
            S["phi_after"].append(float(phi_a))

            S["dstar_before"].append(float(dstar_b))
            S["tstar_before"].append(float(tstar_b))
            S["dist_before"].append(dist_b)
            S["dist_after"].append(dist_a)

            S["abs_d_rca"].append(float(np.abs(r_ca_info - r_ca_calc)) if np.isfinite(r_ca_info) else np.nan)
            S["abs_d_rclose"].append(float(np.abs(r_close_info - r_close_calc)) if np.isfinite(r_close_info) else np.nan)

            if np.isfinite(r_ca_info):
                sum_r_ca_ep += r_ca_info
            if np.isfinite(r_close_info):
                sum_01_r_close_ep += 0.1 * r_close_info

        ev = "max_steps_break" if step_cap_hit else str(info.get("event", "unknown"))
        events[ev] += 1
        ep_sums.append((ev, sum_r_ca_ep, sum_01_r_close_ep))

    print("\n==============================")
    print(f"Policy: {policy} | Episodes: {n_episodes}")
    print("Events:", dict(events))

    # Core scales
    summarize(S["reward"],        "reward(step)")
    summarize(S["r_ca_calc"],     "r_ca_calc")
    summarize(S["r_ca_info"],     "r_ca_info")
    summarize(S["r_close_calc"],  "r_close_calc")
    summarize(S["r_close_info"],  "r_close_info")

    # Closest-approach geometry
    summarize(S["dstar_before"],   "dstar_before")
    summarize(S["tstar_before"],   "tstar_before")
    summarize(S["phi_before"],     "phi_before")
    summarize(S["dist_before"],   "dist_before")
    summarize(S["action_mag"],     "action_mag")

    # Consistency checks (should be ~0 if info matches math)
    summarize(S["abs_d_rca"],    "|r_ca_info-calc|")
    summarize(S["abs_d_rclose"], "|r_close_info-calc|")

    # Per-episode return components by terminal event (what PPO optimizes)
    by_event = defaultdict(lambda: {"sum_r_ca": [], "sum_01_r_close": []})
    for ev, s_ca, s_close in ep_sums:
        by_event[ev]["sum_r_ca"].append(s_ca)
        by_event[ev]["sum_01_r_close"].append(s_close)
    print("\n--- Per-episode sums by terminal event (return components) ---")
    for ev in sorted(by_event.keys()):
        ca_list = by_event[ev]["sum_r_ca"]
        close_list = by_event[ev]["sum_01_r_close"]
        n = len(ca_list)
        mean_ca = np.mean(ca_list) if ca_list else np.nan
        mean_close = np.mean(close_list) if close_list else np.nan
        mean_return = 0.5 * mean_ca + 0.5 * mean_close  # reward = 0.5*r_ca + 0.5*0.1*r_close
        print(f"  {ev:18s}  n={n:3d}  mean(sum_r_ca)={mean_ca: .4f}  mean(sum_0.1*r_close)={mean_close: .4f}  mean(return)={mean_return: .4f}")

# --- Run both ---
# Probe per-step reward scales under a random policy and under the ProNav
# baseline. Each run uses a fresh environment and the same seeds (0..19), so
# the two runs see the same sampled scenarios.
env = missile_interception_3d()
probe_reward_scales(env, policy="random", n_episodes=20, seed0=0)

env = missile_interception_3d()
probe_reward_scales(env, policy="pronav", n_episodes=20, seed0=0)


Policy: random | Episodes: 20
Events: {'enemy_ground': 20}
reward(step)       mean=-0.030056  p50=-0.049994  p10=-0.051672  p90= 0.041935  min=-0.052805  max= 0.063826
r_ca_calc          mean=-0.003830  p50=-0.005062  p10=-0.006148  p90= 0.000236  min=-0.033897  max= 0.031497
r_ca_info          mean=-0.003830  p50=-0.005062  p10=-0.006148  p90= 0.000236  min=-0.033897  max= 0.031497
r_close_calc       mean=-0.562815  p50=-0.949687  p10=-0.973627  p90= 0.884774  min=-0.987381  max= 0.982641
r_close_info       mean=-0.562815  p50=-0.949687  p10=-0.973627  p90= 0.884774  min=-0.987381  max= 0.982641
dstar_before       mean= 186486.449533  p50= 159883.410032  p10= 53745.385761  p90= 364399.151761  min= 22805.965797  max= 567419.871967
tstar_before       mean= 3.780076  p50= 0.000000  p10= 0.000000  p90= 16.917025  min= 0.000000  max= 54.858541
phi_before         mean=-3.729729  p50=-3.197668  p10=-7.287983  p90=-1.074908  min=-11.348397  max=-0.456119
dist_before        mean= 192287.22142

In [13]:
# ========== STEP-0 ACTION SENSITIVITY TEST ==========
# Snapshot/restore env so each action starts from the same state.

def get_env_state(env):
    """Return a restorable snapshot of ``env``'s per-episode dynamic state.

    Position/velocity vectors are copied so later in-place updates to the env
    cannot alter the snapshot. Optional bookkeeping attributes fall back to
    neutral defaults when the env does not define them. The per-axis scalars
    (``enemy_x`` ... ``defense_z``) fall back to the components of the
    matching position vector. Scalars are coerced to plain Python types.

    NOTE(review): only the attributes listed here are captured. If
    ``env.step`` depends on any other mutable field (for example a cached
    previous shaping potential or the ``np_random`` stream), restoring this
    snapshot may not reproduce an identical step -- confirm against the env.
    """
    enemy_pos = env.enemy_pos
    defense_pos = env.defense_pos
    state = {
        "enemy_pos": enemy_pos.copy(),
        "defense_pos": defense_pos.copy(),
        "enemy_vel": env.enemy_vel.copy(),
        "defense_vel": env.defense_vel.copy(),
        "t": float(env.t),
    }
    # Optional flags and counters, each with its fallback value.
    for name in ("done", "success"):
        state[name] = bool(getattr(env, name, False))
    for name, fallback in (("min_dist", np.inf), ("ep_min_dist", np.inf), ("ep_max_action_mag", 0.0)):
        state[name] = float(getattr(env, name, fallback))
    hit_time = getattr(env, "time_to_hit", None)
    state["time_to_hit"] = float(hit_time) if hit_time is not None else None
    state["terminal_event"] = str(getattr(env, "terminal_event", "running"))
    # Per-axis scalar mirrors of the two position vectors.
    for prefix, vec in (("enemy", enemy_pos), ("defense", defense_pos)):
        for axis_idx, axis in enumerate("xyz"):
            key = f"{prefix}_{axis}"
            state[key] = float(getattr(env, key, vec[axis_idx]))
    return state

def set_env_state(env, S):
    """Write a snapshot produced by ``get_env_state`` back onto ``env`` in place.

    Vector fields are copied so the same snapshot can be restored repeatedly
    without later env updates mutating it. Time, flags and distance counters
    are re-coerced to float/bool. The remaining fields are assigned as stored.
    """
    for key in ("enemy_pos", "defense_pos", "enemy_vel", "defense_vel"):
        setattr(env, key, S[key].copy())
    env.t = float(S["t"])
    for key in ("done", "success"):
        setattr(env, key, bool(S[key]))
    for key in ("min_dist", "ep_min_dist", "ep_max_action_mag"):
        setattr(env, key, float(S[key]))
    # Assigned verbatim: time_to_hit may legitimately be None.
    for key in (
        "time_to_hit",
        "terminal_event",
        "enemy_x",
        "enemy_y",
        "enemy_z",
        "defense_x",
        "defense_y",
        "defense_z",
    ):
        setattr(env, key, S[key])


def eval_one_step(env, action, v_scale=1500.0):
    """Step ``env`` once with ``action`` and report reward parts and geometry change.

    ``env._phi_closest_approach()`` (returning ``(phi, d*, t*)``) is sampled
    immediately before and after the step. Missing ``r_ca`` / ``r_close`` info
    entries become NaN, and a missing ``event`` becomes ``"unknown"``.
    ``v_scale`` is accepted for call-site compatibility but is not used here.
    """
    phi_b, dstar_b, tstar_b = env._phi_closest_approach()
    _, reward, terminated, truncated, info = env.step(action)
    finished = bool(terminated or truncated)
    phi_a, dstar_a, tstar_a = env._phi_closest_approach()

    result = {"reward": float(reward)}
    for key in ("r_ca", "r_close"):
        result[key] = float(info.get(key, np.nan))
    result["done"] = finished
    result["event"] = str(info.get("event", "unknown"))
    result.update(
        dstar_before=float(dstar_b),
        dstar_after=float(dstar_a),
        # Subtract first, then convert, so the dtype of the difference is unchanged.
        ddstar=float(dstar_a - dstar_b),
        tstar_before=float(tstar_b),
        tstar_after=float(tstar_a),
        phi_before=float(phi_b),
        phi_after=float(phi_a),
    )
    return result


def sample_actions_like_initial_ppo(n, sigma=1.0, rng=None):
    """Draw ``n`` two-component actions, intended to resemble an untrained PPO policy.

    Each component is zero-mean Gaussian with std ``sigma``, cast to float32
    and clipped to the action box [-1, 1]. When ``rng`` is omitted a fresh
    ``RandomState(0)`` is used, so repeated calls return the same batch.
    """
    generator = np.random.RandomState(0) if rng is None else rng
    # Cast before scaling so the multiply happens in float32.
    noise = generator.randn(n, 2).astype(np.float32)
    return np.clip(noise * float(sigma), -1.0, 1.0)


def scan_initial_state(env, n_actions=4096, sigma=1.0, seed=0, v_scale=1500.0, topk=10):
    """Score many candidate first actions from one freshly reset state.

    The env is reset with ``seed`` and snapshotted. Each candidate action is
    sampled by ``sample_actions_like_initial_ppo`` with RNG seed
    ``seed + 12345``. It is then applied for a single step from that same
    snapshot via ``eval_one_step``. The following are printed:

    * positive-fraction statistics;
    * p10/p50/p90 summaries;
    * the ``topk`` actions by total reward and by ``r_ca``.

    The env is left restored to the initial snapshot.

    Returns:
        dict with per-action arrays "actions", "reward", "r_ca", "r_close",
        "ddstar" and "done".
    """
    _obs, _info = env.reset(seed=seed)
    start_state = get_env_state(env)
    candidate_rng = np.random.RandomState(seed + 12345)
    actions = sample_actions_like_initial_ppo(n_actions, sigma=sigma, rng=candidate_rng)

    columns = ("reward", "r_ca", "r_close", "ddstar")
    per_action = {name: np.zeros(n_actions, dtype=np.float64) for name in columns}
    done = np.zeros(n_actions, dtype=bool)
    for i, action in enumerate(actions):
        # Every candidate starts from the identical post-reset snapshot.
        set_env_state(env, start_state)
        step = eval_one_step(env, action, v_scale=v_scale)
        for name in columns:
            per_action[name][i] = step[name]
        done[i] = step["done"]
    rewards = per_action["reward"]
    r_ca = per_action["r_ca"]
    r_close = per_action["r_close"]
    ddstar = per_action["ddstar"]

    def pct(x):
        return tuple(float(np.percentile(x, p)) for p in (10, 50, 90))

    close_pos = r_close > 0
    ca_pos = r_ca > 0
    frac_close_pos = float(np.mean(close_pos))
    frac_ca_pos = float(np.mean(ca_pos))
    # P(r_ca > 0 | r_close > 0); the floor avoids 0/0 when no action has r_close > 0.
    frac_ca_pos_given_close = float(np.mean(ca_pos & close_pos) / max(np.mean(close_pos), 1e-9))

    print("\n=== STEP-0 ACTION SCAN ===")
    print(f"seed={seed} | n_actions={n_actions} | sigma={sigma}")
    print(f"frac(r_close>0)={frac_close_pos:.3f}")
    print(f"frac(r_ca>0)={frac_ca_pos:.3f}")
    print(f"frac(r_ca>0 | r_close>0)={frac_ca_pos_given_close:.3f}")
    print("reward:   p10={:.4f}  p50={:.4f}  p90={:.4f}  max={:.4f}".format(*pct(rewards), float(rewards.max())))
    print("r_close:  p10={:.4f}  p50={:.4f}  p90={:.4f}  max={:.4f}".format(*pct(r_close), float(r_close.max())))
    print("r_ca:     p10={:.4f}  p50={:.4f}  p90={:.4f}  max={:.4f}".format(*pct(r_ca), float(r_ca.max())))
    print("Δd*:      p10={:.2f}  p50={:.2f}  p90={:.2f}  min={:.2f}".format(*pct(ddstar), float(ddstar.min())))

    print("\nTop actions by total reward:")
    for j in np.argsort(-rewards)[:topk]:
        print(f"  a={actions[j]}  reward={rewards[j]:+.4f}  r_ca={r_ca[j]:+.4f}  r_close={r_close[j]:+.4f}  Δd*={ddstar[j]:+.1f}")
    print("\nTop actions by geometry shaping r_ca:")
    for j in np.argsort(-r_ca)[:topk]:
        print(f"  a={actions[j]}  r_ca={r_ca[j]:+.4f}  reward={rewards[j]:+.4f}  r_close={r_close[j]:+.4f}  Δd*={ddstar[j]:+.1f}")

    set_env_state(env, start_state)
    return {"actions": actions, "reward": rewards, "r_ca": r_ca, "r_close": r_close, "ddstar": ddstar, "done": done}


def scan_many_initial_states(env, seeds, n_actions=2048, sigma=1.0):
    """Run ``scan_initial_state`` for every seed and tabulate the results.

    Each seed's full step-0 scan is printed by ``scan_initial_state``, then a
    one-line-per-seed summary follows. For each seed the returned list holds:

    * the fraction of sampled actions with positive ``r_close`` and ``r_ca``;
    * the p50 and p90 of reward;
    * the p90-p50 gaps of reward and ``r_ca`` (spread across the sampled actions).
    """
    stats = []
    for seed in seeds:
        scan = scan_initial_state(env, n_actions=n_actions, sigma=sigma, seed=seed)
        reward = scan["reward"]
        r_ca = scan["r_ca"]
        reward_p90 = np.percentile(reward, 90)
        reward_p50 = np.percentile(reward, 50)
        ca_p90 = np.percentile(r_ca, 90)
        ca_p50 = np.percentile(r_ca, 50)
        stats.append(
            dict(
                seed=seed,
                frac_close_pos=float(np.mean(scan["r_close"] > 0)),
                frac_ca_pos=float(np.mean(r_ca > 0)),
                reward_p90=float(reward_p90),
                reward_p50=float(reward_p50),
                reward_gap_p90_p50=float(reward_p90 - reward_p50),
                ca_gap_p90_p50=float(ca_p90 - ca_p50),
            )
        )

    print("\n=== SUMMARY ACROSS INITIAL STATES ===")
    for row in stats:
        print(f"seed={row['seed']:>3d} | frac_close>0={row['frac_close_pos']:.2f} | frac_ca>0={row['frac_ca_pos']:.2f} | reward(p90-p50)={row['reward_gap_p90_p50']:.4f} | r_ca(p90-p50)={row['ca_gap_p90_p50']:.4f}")
    return stats


# ---------- K-step scan: 1 step with candidate action, then K-1 steps ProNav ----------
def eval_one_step_then_pronav_k_steps(env, action, K, v_scale=1500.0):
    """Apply ``action`` for one step, then let ProNav fly up to K-1 more steps.

    The rollout stops early as soon as a step reports ``terminated`` or
    ``truncated``. The caller must restore the starting snapshot before each
    call. ``v_scale`` is accepted for signature parity with the other scan
    helpers but is not used.

    Returns:
        dict with these entries:

        * the summed reward;
        * ``env.ep_min_dist`` after the rollout (``inf`` if the attribute is absent);
        * d* from ``env._phi_closest_approach()`` at the final state;
        * the done flag;
        * the last step's ``event``;
        * the number of steps taken;
        * the ``r_ca`` reported by the candidate step.
    """
    _, reward, term, trunc, info = env.step(action)
    total_reward = 0.0
    total_reward += float(reward)
    r_ca_first = float(info.get("r_ca", np.nan))
    done = bool(term or trunc)
    steps_taken = 1
    # Hand control to ProNav for the remainder of the K-step horizon.
    while not done and steps_taken < K:
        _, reward, term, trunc, info = env.step(env.calculate_pronav())
        total_reward += float(reward)
        done = bool(term or trunc)
        steps_taken += 1

    closest_miss = float(getattr(env, "ep_min_dist", np.inf))
    _, dstar_final, _ = env._phi_closest_approach()
    return {
        "total_reward": total_reward,
        "min_dist": closest_miss,
        "dstar_final": float(dstar_final),
        "done": done,
        "event": str(info.get("event", "unknown")),
        "steps_taken": steps_taken,
        "r_ca_first": r_ca_first,
    }


def scan_initial_state_k_steps(env, n_actions=2048, K=20, sigma=1.0, seed=0, topk=10, v_scale=1500.0):
    """Rank candidate first actions by their effect over a short ProNav rollout.

    The env is reset with ``seed`` and snapshotted. Each of ``n_actions``
    candidate actions is sampled by ``sample_actions_like_initial_ppo`` with
    RNG seed ``seed + 12345``. For each one the snapshot is restored and
    ``eval_one_step_then_pronav_k_steps`` applies the candidate for one step,
    followed by up to K-1 ProNav steps.

    Printed output:

    * terminal-event counts and p10/p50/p90 summaries;
    * the ``topk`` actions by total reward, by ``min_dist`` and by final d*;
    * the correlation between the step-0 ``r_ca`` and ``min_dist``.

    The env is left restored to the initial snapshot.

    Returns:
        dict with per-action arrays "actions", "total_reward", "min_dist",
        "dstar_final", "r_ca_first" and "done".
    """
    env.reset(seed=seed)
    S0 = get_env_state(env)
    rng = np.random.RandomState(seed + 12345)
    actions = sample_actions_like_initial_ppo(n_actions, sigma=sigma, rng=rng)
    total_reward = np.zeros(n_actions, dtype=np.float64)
    min_dist = np.zeros(n_actions, dtype=np.float64)
    dstar_final = np.zeros(n_actions, dtype=np.float64)
    # NaN marks rollouts whose first-step info did not report r_ca.
    r_ca_first = np.full(n_actions, np.nan, dtype=np.float64)
    done_arr = np.zeros(n_actions, dtype=bool)
    event_counts = {}
    for i in range(n_actions):
        # Every candidate starts from the identical post-reset snapshot.
        set_env_state(env, S0)
        out = eval_one_step_then_pronav_k_steps(env, actions[i], K=K, v_scale=v_scale)
        total_reward[i] = out["total_reward"]
        min_dist[i] = out["min_dist"]
        dstar_final[i] = out["dstar_final"]
        done_arr[i] = out["done"]
        r_ca_first[i] = out["r_ca_first"]
        event = out["event"]
        event_counts[event] = event_counts.get(event, 0) + 1
    # Leave the env at the shared start state, not at the last rollout's end.
    set_env_state(env, S0)

    def q(x, p):
        return float(np.percentile(x, p))

    print("\n=== K-STEP ACTION SCAN (1 step candidate + {} steps ProNav) ===".format(K - 1))
    print("seed={} | n_actions={} | K={} | sigma={}".format(seed, n_actions, K, sigma))
    print("Terminal events:", event_counts)
    print("total_reward: p10={:.4f}  p50={:.4f}  p90={:.4f}  max={:.4f}".format(q(total_reward, 10), q(total_reward, 50), q(total_reward, 90), float(total_reward.max())))
    print("min_dist:     p10={:.1f}  p50={:.1f}  p90={:.1f}  min={:.1f}".format(q(min_dist, 10), q(min_dist, 50), q(min_dist, 90), float(min_dist.min())))
    print("dstar_final:  p10={:.1f}  p50={:.1f}  p90={:.1f}  min={:.1f}".format(q(dstar_final, 10), q(dstar_final, 50), q(dstar_final, 90), float(dstar_final.min())))

    idx_best_return = np.argsort(-total_reward)[:topk]
    idx_best_min_dist = np.argsort(min_dist)[:topk]
    idx_best_dstar = np.argsort(dstar_final)[:topk]
    # Each rollout is at most K steps total (1 candidate + up to K-1 ProNav).
    print("\nTop actions by total reward (1 candidate + up to {} ProNav steps):".format(K - 1))
    for j in idx_best_return:
        print("  a={}  total_reward={:+.4f}  min_dist={:.1f}  d*={:.1f}".format(actions[j], total_reward[j], min_dist[j], dstar_final[j]))
    print("\nTop actions by min_dist (lower is better):")
    for j in idx_best_min_dist:
        print("  a={}  min_dist={:.1f}  total_reward={:+.4f}  d*={:.1f}".format(actions[j], min_dist[j], total_reward[j], dstar_final[j]))
    print("\nTop actions by final d* (lower is better):")
    for j in idx_best_dstar:
        print("  a={}  d*={:.1f}  total_reward={:+.4f}  min_dist={:.1f}".format(actions[j], dstar_final[j], total_reward[j], min_dist[j]))

    valid_ca = np.isfinite(r_ca_first)
    if np.any(valid_ca):
        # Drop rows whose min_dist is non-finite (e.g. the inf fallback); they
        # would make the correlation NaN.
        pair = valid_ca & np.isfinite(min_dist)
        ca_vals = r_ca_first[pair]
        miss_vals = min_dist[pair]
        # np.corrcoef divides by both standard deviations. A single sample or a
        # constant column would give NaN plus a RuntimeWarning, so report NaN directly.
        if ca_vals.size > 1 and np.ptp(ca_vals) > 0 and np.ptp(miss_vals) > 0:
            corr_ca_min = float(np.corrcoef(ca_vals, miss_vals)[0, 1])
        else:
            corr_ca_min = np.nan
        print("\nCorrelation(r_ca at step 0, min_dist after K steps): {:.4f} (negative = good r_ca -> lower miss)".format(corr_ca_min))
    return {"actions": actions, "total_reward": total_reward, "min_dist": min_dist, "dstar_final": dstar_final, "r_ca_first": r_ca_first, "done": done_arr}


# Single initial state: score 4096 candidate step-0 actions from seed 0.
env = missile_interception_3d()
scan_initial_state(env, seed=0, sigma=1.0, n_actions=4096)

# Ten initial states (seeds 0-9): each prints its full scan, then a summary table follows.
scan_many_initial_states(env, sigma=1.0, n_actions=2048, seeds=[*range(10)])

# K-step scan: do the best step-0 actions stay good after 2 s of flight?
# (1 candidate step + 19 ProNav steps at dt_act = 0.1 s.)
scan_initial_state_k_steps(env, seed=0, sigma=1.0, K=20, n_actions=2048)


=== STEP-0 ACTION SCAN ===
seed=0 | n_actions=4096 | sigma=1.0
frac(r_close>0)=1.000
frac(r_ca>0)=0.494
frac(r_ca>0 | r_close>0)=0.494
reward:   p10=0.0369  p50=0.0473  p90=0.0582  max=0.0604
r_close:  p10=0.9500  p50=0.9514  p90=0.9529  max=0.9532
r_ca:     p10=-0.0212  p50=-0.0005  p90=0.0210  max=0.0255
Δd*:      p10=-1043.98  p50=32.01  p90=1066.80  min=-1269.69

Top actions by total reward:
  a=[ 0.11187453 -1.        ]  reward=+0.0604  r_ca=+0.0255  r_close=+0.9532  Δd*=-1269.7
  a=[ 0.1135906 -1.       ]  reward=+0.0604  r_ca=+0.0255  r_close=+0.9532  Δd*=-1269.7
  a=[ 0.10361419 -1.        ]  reward=+0.0604  r_ca=+0.0255  r_close=+0.9532  Δd*=-1269.7
  a=[ 0.11148106 -1.        ]  reward=+0.0604  r_ca=+0.0255  r_close=+0.9532  Δd*=-1269.7
  a=[ 0.10172243 -1.        ]  reward=+0.0604  r_ca=+0.0255  r_close=+0.9532  Δd*=-1269.6
  a=[ 0.10507097 -1.        ]  reward=+0.0604  r_ca=+0.0255  r_close=+0.9532  Δd*=-1269.6
  a=[ 0.10067383 -1.        ]  reward=+0.0604  r_ca=+0.0255  r

[{'seed': 0,
  'frac_close_pos': 1.0,
  'frac_ca_pos': 0.47998046875,
  'reward_p90': 0.057961663303557134,
  'reward_p50': 0.04690812801554631,
  'reward_gap_p90_p50': 0.011053535288010827,
  'ca_gap_p90_p50': 0.0219626521167243},
 {'seed': 1,
  'frac_close_pos': 1.0,
  'frac_ca_pos': 0.4931640625,
  'reward_p90': 0.05299899088856686,
  'reward_p50': 0.04788358455548042,
  'reward_gap_p90_p50': 0.005115406333086443,
  'ca_gap_p90_p50': 0.010118706780408469},
 {'seed': 2,
  'frac_close_pos': 1.0,
  'frac_ca_pos': 0.5068359375,
  'reward_p90': 0.05571018787214409,
  'reward_p50': 0.04686736244438779,
  'reward_gap_p90_p50': 0.008842825427756294,
  'ca_gap_p90_p50': 0.017492544198548596},
 {'seed': 3,
  'frac_close_pos': 1.0,
  'frac_ca_pos': 0.50927734375,
  'reward_p90': 0.06345264024546382,
  'reward_p50': 0.049075707932783263,
  'reward_gap_p90_p50': 0.014376932312680557,
  'ca_gap_p90_p50': 0.028704914869653174},
 {'seed': 4,
  'frac_close_pos': 1.0,
  'frac_ca_pos': 0.49658203125,
