In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

class missile_interception_3d(gym.Env):
    def __init__(self):
        # 1. Define Action Space (The Joystick: Left/Right, Up/Down)
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2, ), dtype=np.float32)
        
        # 2. Define Observation Space (expanded for better control)
        # Size = 3 (Rel Pos) + 3 (Rel Vel) + 2 (Curr Accel) + 2 (dist_n, vclose_n) + 2 (def_z, def_vz) = 12
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32)

        self.np_random = np.random.RandomState()
        
        # 3. Time Settings
        self.dt_act = 0.1             # AI decides every 0.1 seconds
        self.n_substeps = 10          # Physics runs 10x faster
        self.dt_sim = self.dt_act / self.n_substeps # 0.01 seconds
        self.t_max = 650.0            # Maximum simulation time (seconds)

        # 4. Physical Limits - SIMPLIFIED UFO PHYSICS
        self.a_max = 350.0   # Max G-force (m/s^2) ~35G
        self.da_max = 2500.0 # Jerk Limit (m/s^3)
        self.tau = 0.05      # Airframe Lag (seconds)
        self.g = 9.81        # Gravity
        self.collision_radius = 10.0  # Hit radius (meters)
        self.max_distance = 4_000_000.0  # Maximum distance before truncation

        # --- Mini curriculum (TEMPORARILY: easy only to prove baseline) ---
        self.p_easy = 1.0                   # 100% easy scenarios (disable hard until baseline works)
        self.range_min = 70_000.0           # 70 km (meters)
        self.range_easy_max = 200_000.0     # 200 km (meters)  [easy]
        self.range_hard_max = 1_000_000.0   # 1000 km (meters) [hard] - not used when p_easy=1.0

        self.targetbox_x_min = -15000
        self.targetbox_x_max = 15000
        self.targetbox_y_min = -15000
        self.targetbox_y_max = 15000

    def generate_enemy_missile(self):
        # --- Mini curriculum sampling ---
        # Most of the time: easier ranges; sometimes: full hard range
        if self.np_random.rand() < self.p_easy:
            self.range_max_used = self.range_easy_max
        else:
            self.range_max_used = self.range_hard_max

        range_min = self.range_min

        self.attack_target_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.attack_target_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)

        self.enemy_launch_angle = self.np_random.uniform(0, 2 * np.pi)
        self.enemy_theta = self.np_random.uniform(0.523599, 1.0472)  # 30 to 60 degrees

        # Make sure range_max_used is valid relative to range_min
        self.range_max_used = max(self.range_max_used, range_min + 1.0)

        lower_limit = np.sqrt((range_min * self.g) / np.sin(2 * self.enemy_theta))
        upper_limit = np.sqrt((self.range_max_used * self.g) / np.sin(2 * self.enemy_theta))

        self.enemy_initial_velocity = self.np_random.uniform(lower_limit, upper_limit)

        ground_range = (
            self.enemy_initial_velocity * np.cos(self.enemy_theta)
            * (2 * self.enemy_initial_velocity * np.sin(self.enemy_theta) / self.g)
        )

        self.enemy_launch_x = self.attack_target_x + ground_range * np.cos(self.enemy_launch_angle)
        self.enemy_launch_y = self.attack_target_y + ground_range * np.sin(self.enemy_launch_angle)
        self.enemy_z = 0

        self.enemy_x = self.enemy_launch_x
        self.enemy_y = self.enemy_launch_y
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)

        self.enemy_azimuth = (self.enemy_launch_angle + np.pi) % (2 * np.pi)

    def generate_defense_missile(self):
        self.defense_launch_x = self.np_random.uniform(self.targetbox_x_min, self.targetbox_x_max)
        self.defense_launch_y = self.np_random.uniform(self.targetbox_y_min, self.targetbox_y_max)
        
        # Point defense missile at enemy launch position (removes impossible geometry)
        dx = self.enemy_launch_x - self.defense_launch_x
        dy = self.enemy_launch_y - self.defense_launch_y
        self.defense_azimuth = np.arctan2(dy, dx)
        
        self.defense_theta = 0.785398
        
        # Scale defense velocity with scenario difficulty (fixes impossible hard scenarios)
        # Base velocity for easy scenarios (200km range)
        base_velocity = 3000.0
        if hasattr(self, 'range_max_used'):
            # Scale velocity proportionally to range, but cap at reasonable max
            # Hard scenarios (1000km) get ~1.5x velocity to make intercept feasible
            velocity_scale = min(self.range_max_used / self.range_easy_max, 1.5)
            self.defense_initial_velocity = base_velocity * velocity_scale
        else:
            # Fallback if called out of order (shouldn't happen)
            self.defense_initial_velocity = base_velocity
        
        self.defense_x = self.defense_launch_x
        self.defense_y = self.defense_launch_y
        self.defense_z = 0
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)

        self.defense_ax = 0
        self.defense_ay = 0
        self.defense_az = 0
    
    def calculate_pronav(self):
        r = self.enemy_pos - self.defense_pos
        v = self.enemy_vel - self.defense_vel
        r_mag = np.linalg.norm(r)
        omega = np.cross(v, r) / (r_mag ** 2 + 1e-9)

        vc = -np.dot(r, v) / (r_mag + 1e-9)
        N = 3.0 
        a_cmd = N * vc * np.cross(omega, r/r_mag)

        # === NEW: Convert to action space format ===
        
        # 1. Get the lateral basis (same as in step())
        right, up = self._compute_lateral_basis(self.defense_vel)
        
        # 2. Project the 3D acceleration onto the lateral plane
        a_right = np.dot(a_cmd, right)
        a_up = np.dot(a_cmd, up)
        
        # 3. Normalize by a_max to get [-1, 1] range
        action = np.array([
            a_right / self.a_max,
            a_up / self.a_max
        ], dtype=np.float32)
        
        # 4. Clip to action space bounds
        action = np.clip(action, -1.0, 1.0)
        
        return action
    
    def _rate_limit_norm(self, a_cmd, a_prev, da_max, dt):
        """Norm-based rate limiter: limits the magnitude of acceleration change."""
        delta = a_cmd - a_prev
        max_delta = da_max * dt
        dnorm = float(np.linalg.norm(delta))
        if dnorm <= max_delta or dnorm < 1e-9:
            return a_cmd
        return a_prev + delta * (max_delta / dnorm)
    
    def _segment_sphere_intersect(self, r0, r1, r_hit):
        """Check if line segment from r0 to r1 intersects sphere of radius r_hit."""
        dr = r1 - r0
        dr_norm_sq = float(np.dot(dr, dr))
        
        if dr_norm_sq < 1e-12:
            return float(np.dot(r0, r0)) <= r_hit * r_hit
        
        s_star = -float(np.dot(r0, dr)) / dr_norm_sq
        s_star = max(0.0, min(1.0, s_star))
        
        r_closest = r0 + s_star * dr
        
        return float(np.dot(r_closest, r_closest)) <= r_hit * r_hit
    
    def _get_obs(self):
        # Relative state
        rel_pos = self.enemy_pos - self.defense_pos          # (3,)
        rel_vel = self.enemy_vel - self.defense_vel          # (3,)

        dist = float(np.linalg.norm(rel_pos)) + 1e-6

        # Closing speed along line-of-sight (LOS)
        # v_close > 0 means closing, v_close < 0 means opening
        v_close = -float(np.dot(rel_pos, rel_vel)) / dist

        # Lateral basis from defense velocity
        right, up = self._compute_lateral_basis(self.defense_vel)

        # Current lateral accel in that basis (already bounded-ish)
        a_lat = np.array([
            float(np.dot(self.a_actual, right)) / (self.a_max + 1e-9),
            float(np.dot(self.a_actual, up)) / (self.a_max + 1e-9),
        ], dtype=np.float32)

        # --- Normalize / compress features to sane ranges ---
        # Position scale: use hard range max so typical values land ~[-1,1]
        pos_scale = self.range_hard_max  # 1_000_000.0
        # Velocity scale: your speeds are on the order of ~1-4 km/s
        vel_scale = 4000.0

        rel_pos_n = (rel_pos / pos_scale).astype(np.float32)
        rel_vel_n = (rel_vel / vel_scale).astype(np.float32)

        # --- Replace t_go with stable features (fixes saturation bug) ---
        # Distance normalized: clip to reasonable range
        dist_n = np.clip(dist / 1_000_000.0, 0.0, 4.0).astype(np.float32)
        # Closing speed normalized: preserve sign (negative = opening, positive = closing)
        vclose_n = np.clip(v_close / 3000.0, -2.0, 2.0).astype(np.float32)
        dist_vclose_feat = np.array([dist_n, vclose_n], dtype=np.float32)

        # --- Defense missile own state (critical for ground avoidance) ---
        # Altitude normalized (typical max ~100km = 100000m)
        def_z_n = np.clip(self.defense_pos[2] / 100_000.0, -1.0, 2.0).astype(np.float32)
        # Vertical velocity normalized (typical range -3000 to +3000 m/s)
        def_vz_n = np.clip(self.defense_vel[2] / 3000.0, -2.0, 2.0).astype(np.float32)
        def_state_feat = np.array([def_z_n, def_vz_n], dtype=np.float32)

        obs = np.concatenate([rel_pos_n, rel_vel_n, a_lat, dist_vclose_feat, def_state_feat], axis=0).astype(np.float32)
        return obs

    def _compute_lateral_basis(self, velocity):
        """Compute right/up basis vectors perpendicular to velocity (forward direction)."""
        speed = np.linalg.norm(velocity)
        if speed < 1.0:
            forward = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        else:
            forward = velocity / speed
        
        if abs(forward[2]) > 0.99:
            ref_axis = np.array([1.0, 0.0, 0.0], dtype=np.float32)
            if abs(np.dot(forward, ref_axis)) > 0.9:
                ref_axis = np.array([0.0, 1.0, 0.0], dtype=np.float32)
            right_raw = np.cross(forward, ref_axis)
            right = right_raw / (np.linalg.norm(right_raw) + 1e-6)
        else:
            world_up = np.array([0.0, 0.0, 1.0], dtype=np.float32)
            right_raw = np.cross(forward, world_up)
            right = right_raw / (np.linalg.norm(right_raw) + 1e-6)
        
        up_raw = np.cross(right, forward)
        up = up_raw / (np.linalg.norm(up_raw) + 1e-6)
        
        return right, up

    def step(self, action):
        if getattr(self, "done", False):
            obs = self._get_obs()
            return obs, 0.0, True, False, {"event": "called_step_after_done", "dist": self.relative_distances[-1] if self.relative_distances else float('inf'), "t": self.t}
        
        action = np.clip(action, -1.0, 1.0).astype(np.float32)
        mag = np.linalg.norm(action)
        if mag > 1.0:
            action = action / mag

        dist_before = float(np.linalg.norm(self.enemy_pos - self.defense_pos))

        terminated = False
        truncated = False
        event = "running"
        
        for _ in range(self.n_substeps):
            dt = self.dt_sim
            
            enemy_pos_old = self.enemy_pos.copy()
            defense_pos_old = self.defense_pos.copy()
            
            right, up = self._compute_lateral_basis(self.defense_vel)
            
            a_target_world = (action[0] * self.a_max * right) + (action[1] * self.a_max * up)
            
            self.a_cmd_prev = self._rate_limit_norm(a_target_world, self.a_cmd_prev, self.da_max, dt)
            
            self.a_actual += (self.a_cmd_prev - self.a_actual) * (dt / self.tau)
            
            # ============================================================
            # UFO PHYSICS: Just guidance + gravity (NO DRAG, NO THRUST)
            # ============================================================
            accel_gravity = np.array([0.0, 0.0, -self.g], dtype=np.float32)
            
            # Defense missile: AI control + gravity
            self.defense_vel += (self.a_actual + accel_gravity) * dt
            self.defense_pos += self.defense_vel * dt
            self.defense_x, self.defense_y, self.defense_z = self.defense_pos
            
            # Enemy missile: Pure ballistic
            self.enemy_vel += accel_gravity * dt
            self.enemy_pos += self.enemy_vel * dt
            self.enemy_x, self.enemy_y, self.enemy_z = self.enemy_pos
            
            self.t += dt
            
            # Collision checks
            r0 = enemy_pos_old - defense_pos_old
            r1 = self.enemy_pos - self.defense_pos
            if self._segment_sphere_intersect(r0, r1, self.collision_radius):
                self.success = True
                terminated = True
                self.done = True
                event = "hit"
                break
            
            dist = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
            if dist > self.max_distance:
                truncated = True
                self.done = True
                event = "diverged"
                break
                
            if self.defense_pos[2] < 0:
                terminated = True
                self.done = True
                event = "defense_ground"
                break
            
            if self.enemy_pos[2] < 0:
                terminated = True
                self.done = True
                event = "enemy_ground"
                break
            
            if self.t >= self.t_max:
                truncated = True
                self.done = True
                event = "timeout"
                break

        self.enemy_path.append(self.enemy_pos.copy())
        self.defense_path.append(self.defense_pos.copy())
        self.relative_distances.append(float(np.linalg.norm(self.enemy_pos - self.defense_pos)))
        self.times.append(self.t)
        
        obs = self._get_obs()
        
        dist_after = float(np.linalg.norm(self.enemy_pos - self.defense_pos))
        
        # Track closest approach (important for grading near-misses)
        self.min_dist = min(getattr(self, "min_dist", float("inf")), dist_after)
        
        # Relative vectors at END of step
        r = (self.enemy_pos - self.defense_pos).astype(np.float64)
        vrel = (self.enemy_vel - self.defense_vel).astype(np.float64)
        
        d = float(np.linalg.norm(r)) + 1e-9
        rhat = r / d
        
        # Range-rate (meters/sec). Negative = closing.
        d_dot = float(np.dot(rhat, vrel))
        
        # ------------------------------------------------------------
        # Dense shaping that does NOT saturate at long range
        # ------------------------------------------------------------
        
        # 1) Distance progress this action step (meters -> reward units)
        # Set d_scale so that ~100m progress ≈ +1 reward (since dt_act = 0.1)
        d_scale = 100.0
        r_progress = (dist_before - dist_after) / d_scale
        
        # 2) Closing speed signal (heavily nerfed to prevent loitering)
        # v_scale should be "typical" closing speed scale (m/s)
        v_scale = 1500.0
        r_close = np.tanh((-d_dot) / v_scale)
        
        # 3) Small time penalty to prevent endless loitering
        r_time = -0.001

        # Track reward components for debugging
        self.sum_r_progress += float(r_progress)
        self.sum_r_close += float(r_close)
        
        # Weighted sum: heavily nerf r_close to prevent loitering behavior
        # Focus on distance progress, not just closing speed
        reward = 1.0 * r_progress + 0.1 * r_close + r_time
        
        # ------------------------------------------------------------
        # Terminal shaping (fixed to make suicide unattractive)
        # ------------------------------------------------------------
        HIT_BONUS = 10000.0      # Much higher reward for success
        FAIL_PENALTY = 2000.0
        CRASH_PENALTY = 5000.0   # Much higher penalty for self-destruction
        
        if self.success:
            reward += HIT_BONUS
        else:
            if terminated or truncated:
                # Explicit crash penalty (teaches "don't kill yourself")
                if event == "defense_ground":
                    reward -= CRASH_PENALTY
                
                # Use closest approach to grade misses (smoothly)
                # This keeps the agent caring about "almost hit" vs "missed by 500km"
                miss = float(self.min_dist)
                # Convert miss meters into a bounded penalty
                reward -= min(FAIL_PENALTY, miss / 50.0)  # 50m miss -> -1, 100km -> -2000 cap
        
        # Compute v_close and t_go_raw for debugging
        v_close_debug = float(-np.dot(rhat, vrel))  # = -d_dot, positive when closing
        t_go_raw = float(dist_after / (abs(v_close_debug) + 1e-6))
        
        info = {
            "dist": self.relative_distances[-1] if self.relative_distances else float('inf'),
            "event": event,
            "t": self.t,
            
            # --- DEBUG SIGNALS ---
            "dist_before": dist_before,
            "dist_after": dist_after,
            "min_dist": float(self.min_dist),
            "d_dot": d_dot,                 # range-rate (m/s), negative means closing
            "v_close": v_close_debug,        # positive when closing
            "t_go_raw": t_go_raw,           # raw time-to-go estimate
            "r_progress": float(r_progress),
            "r_close": float(r_close),
            "reward_total": float(reward),
            "action_mag": float(np.linalg.norm(action)),
        }
        
        # Add episode-summed reward components and physical metrics at terminal state
        if terminated or truncated:
            info["sum_r_progress"] = float(self.sum_r_progress)
            info["sum_r_close"] = float(self.sum_r_close)
            
            # --- PHYSICAL CONTROL METRICS (for steering diagnosis) ---
            info["v_close_raw"] = float(v_close_debug)  # m/s, positive when closing
            info["def_z"] = float(self.defense_pos[2])  # altitude (m)
            info["def_vz"] = float(self.defense_vel[2])  # vertical velocity (m/s)
            info["a_norm"] = float(np.linalg.norm(self.a_actual) / (self.a_max + 1e-9))  # normalized accel [0,1]
        
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        if seed is not None:
            self.np_random = np.random.RandomState(seed)
        
        self.done = False
        self.success = False
        self.t = 0.0
        
        self.generate_enemy_missile()
        self.generate_defense_missile()
        
        self.defense_vel = np.array([
            self.defense_initial_velocity * np.cos(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_azimuth) * np.cos(self.defense_theta),
            self.defense_initial_velocity * np.sin(self.defense_theta)
        ], dtype=np.float32)
        
        self.enemy_vel = np.array([
            self.enemy_initial_velocity * np.cos(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_azimuth) * np.cos(self.enemy_theta),
            self.enemy_initial_velocity * np.sin(self.enemy_theta)
        ], dtype=np.float32)
        
        self.a_actual = np.zeros(3, dtype=np.float32)
        self.a_cmd_prev = np.zeros(3, dtype=np.float32)
        
        self.defense_pos = np.array([self.defense_x, self.defense_y, self.defense_z], dtype=np.float32)
        self.enemy_pos = np.array([self.enemy_x, self.enemy_y, self.enemy_z], dtype=np.float32)
        
        self.enemy_path = [self.enemy_pos.copy()]
        self.defense_path = [self.defense_pos.copy()]
        self.relative_distances = [float(np.linalg.norm(self.enemy_pos - self.defense_pos))]
        self.times = [self.t]
        
        # Track closest approach for better miss grading
        self.min_dist = float(self.relative_distances[-1])
        
        # Track reward components for debugging
        self.sum_r_progress = 0.0
        self.sum_r_close = 0.0
        
        return self._get_obs(), {}

    def render(self):
        enemy_path_array = np.array(self.enemy_path)
        defense_path_array = np.array(self.defense_path)
        
        xe, ye, ze = enemy_path_array[:, 0], enemy_path_array[:, 1], enemy_path_array[:, 2]
        xd, yd, zd = defense_path_array[:, 0], defense_path_array[:, 1], defense_path_array[:, 2]

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.plot(xe, ye, ze, label="Enemy missile trajectory", color='blue')
        ax.plot(xd, yd, zd, label="Defense missile trajectory", color='red')
        ax.scatter(xe[0], ye[0], ze[0], color='blue', s=50, label="Enemy Start")
        ax.scatter(xd[0], yd[0], zd[0], color='red', s=50, label="Defense Start")
        ax.set_xlabel('X coordinate (m)')
        ax.set_ylabel('Y coordinate (m)')
        ax.set_zlabel('Z coordinate (m)')
        ax.set_title('Missile Trajectory')
        ax.legend()
        plt.show()




In [11]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from collections import deque, Counter
import os
import csv

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback

from IPython.display import HTML, display


# ==========================================
# YOUR ANIMATION FUNCTIONS (must be in scope)
# ==========================================
def update_paths(num, xe, ye, ze, xd, yd, zd, lines, ax):
    lines[0].set_data_3d(xe[:num], ye[:num], ze[:num])
    lines[1].set_data_3d(xd[:num], yd[:num], zd[:num])
    ax.view_init(elev=20, azim=-60 + (num * 0.2))
    return lines

def animate_trajectories(enemy_path, defense_path):
    enemy_path_array = np.array(enemy_path)
    defense_path_array = np.array(defense_path)

    xe, ye, ze = enemy_path_array[:, 0], enemy_path_array[:, 1], enemy_path_array[:, 2]
    xd, yd, zd = defense_path_array[:, 0], defense_path_array[:, 1], defense_path_array[:, 2]

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    all_x = np.concatenate((xe, xd))
    all_y = np.concatenate((ye, yd))
    all_z = np.concatenate((ze, zd))

    # avoid identical min/max which can crash set_xlim
    eps = 1e-6
    ax.set_xlim([np.min(all_x) - eps, np.max(all_x) + eps])
    ax.set_ylim([np.min(all_y) - eps, np.max(all_y) + eps])
    ax.set_zlim([0, max(eps, np.max(all_z) + eps)])

    line_enemy, = ax.plot([], [], [], 'b-', linewidth=2, label="Enemy (Ballistic)")
    line_defense, = ax.plot([], [], [], 'r-', linewidth=2, label="Defense (Interceptor)")

    ax.scatter(xe[0], ye[0], ze[0], color='blue', s=50, marker='o')
    ax.scatter(xd[0], yd[0], zd[0], color='red', s=50, marker='^')

    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')
    ax.set_zlabel('Altitude (m)')
    ax.set_title('Interception Simulation')
    ax.legend()

    total_steps = len(xe)
    step_size = max(1, total_steps // 200)
    frames = range(0, total_steps, step_size)

    ani = animation.FuncAnimation(
        fig,
        update_paths,
        frames=len(frames),
        fargs=(xe[::step_size], ye[::step_size], ze[::step_size],
               xd[::step_size], yd[::step_size], zd[::step_size],
               [line_enemy, line_defense], ax),
        interval=30,
        blit=False
    )

    plt.close(fig)
    return HTML(ani.to_html5_video())


# ==========================================
# SB3 env factory
# ==========================================
def make_env(seed=0):
    def _init():
        env = missile_interception_3d()
        env = Monitor(env)          # episode stats for SB3
        env.reset(seed=seed)
        return env
    return _init


# ==========================================
# CSV Logger Callback (with PPO internals + physical metrics)
# ==========================================
class CSVLoggerPlusCallback(BaseCallback):
    """
    Writes per-episode metrics to CSV including:
    - Episode metrics (reward, event, distances, shaping)
    - PPO health metrics (KL, clip_fraction, explained_variance, losses)
    - Physical control metrics (v_close, altitude, velocity, acceleration)
    
    It expects your env to put these keys in info at terminal:
      - event, dist, min_dist
      - sum_r_progress, sum_r_close
      - v_close_raw, def_z, def_vz, a_norm (physical metrics)
    And Monitor puts info["episode"] with r and l.
    """
    def __init__(
        self,
        csv_path: str = "./logs/train_plus.csv",
        flush_every_episodes: int = 10,
        flush_every_steps: int = 50_000,
        verbose: int = 0,
    ):
        super().__init__(verbose)
        self.csv_path = csv_path
        self.flush_every_episodes = int(flush_every_episodes)
        self.flush_every_steps = int(flush_every_steps)

        self._rows_buffer = []
        self._episode_count = 0
        self._next_step_flush = self.flush_every_steps

        # Episode metrics + PPO health + Physical control
        self._fieldnames = [
            # Episode basics
            "timesteps", "episode_idx", "ep_reward", "ep_len", "event",
            "final_dist", "min_dist",
            # Shaping rewards
            "sum_r_progress", "sum_r_close",
            # PPO health (may be NaN if not available at that moment)
            "approx_kl", "clip_fraction", "entropy_loss", "explained_variance",
            "value_loss", "policy_gradient_loss",
            # Physical control metrics
            "v_close_raw", "def_z", "def_vz", "a_norm",
        ]

    def _ensure_dir(self):
        d = os.path.dirname(self.csv_path)
        if d:
            os.makedirs(d, exist_ok=True)

    def _write_header_if_needed(self):
        if not os.path.exists(self.csv_path) or os.path.getsize(self.csv_path) == 0:
            with open(self.csv_path, "w", newline="") as f:
                csv.DictWriter(f, fieldnames=self._fieldnames).writeheader()

    def _flush(self):
        if not self._rows_buffer:
            return
        self._ensure_dir()
        self._write_header_if_needed()
        with open(self.csv_path, "a", newline="") as f:
            csv.DictWriter(f, fieldnames=self._fieldnames).writerows(self._rows_buffer)
        self._rows_buffer.clear()

    def _get_ppo_scalars(self):
        """Try to pull most recent scalars from SB3 logger."""
        nv = getattr(self.model, "logger", None)
        if nv is None:
            return {}
        name_to_value = getattr(nv, "name_to_value", {})
        
        def get(key):
            v = name_to_value.get(key, np.nan)
            try:
                return float(v)
            except Exception:
                return np.nan

        return {
            "approx_kl": get("train/approx_kl"),
            "clip_fraction": get("train/clip_fraction"),
            "entropy_loss": get("train/entropy_loss"),
            "explained_variance": get("train/explained_variance"),
            "value_loss": get("train/value_loss"),
            "policy_gradient_loss": get("train/policy_gradient_loss"),
        }

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", [])
        ppo = self._get_ppo_scalars()

        for info in infos:
            if "episode" in info:
                self._episode_count += 1

                row = {
                    # Episode basics
                    "timesteps": int(self.num_timesteps),
                    "episode_idx": int(self._episode_count),
                    "ep_reward": float(info["episode"].get("r", np.nan)),
                    "ep_len": int(info["episode"].get("l", -1)),
                    "event": str(info.get("event", "unknown")),
                    "final_dist": float(info.get("dist", np.nan)),
                    "min_dist": float(info.get("min_dist", info.get("dist", np.nan))),
                    # Shaping rewards
                    "sum_r_progress": float(info.get("sum_r_progress", np.nan)),
                    "sum_r_close": float(info.get("sum_r_close", np.nan)),
                    # PPO health
                    **ppo,
                    # Physical control metrics
                    "v_close_raw": float(info.get("v_close_raw", np.nan)),
                    "def_z": float(info.get("def_z", np.nan)),
                    "def_vz": float(info.get("def_vz", np.nan)),
                    "a_norm": float(info.get("a_norm", np.nan)),
                }
                self._rows_buffer.append(row)

                if self._episode_count % self.flush_every_episodes == 0:
                    self._flush()

        # Also flush by timesteps so you still get data even if episodes are long
        if self.num_timesteps >= self._next_step_flush:
            self._flush()
            self._next_step_flush += self.flush_every_steps

        return True

    def _on_training_end(self) -> None:
        self._flush()


# ==========================================
# Eval + animation callback
# ==========================================
class EvalAnimateCallback(BaseCallback):
    def __init__(self, every_steps: int, eval_seed: int = 123, verbose: int = 1, window_episodes: int = 50):
        super().__init__(verbose)
        self.every_steps = every_steps
        self.eval_seed = eval_seed
        self.next_trigger = every_steps

        # rolling window
        self.window_episodes = window_episodes
        self.events = deque(maxlen=window_episodes)
        self.ep_rewards = deque(maxlen=window_episodes)
        self.ep_lens = deque(maxlen=window_episodes)
        self.ep_min_dists = deque(maxlen=window_episodes)

    def _run_eval_episode(self):
        env = missile_interception_3d()
        obs, _ = env.reset(seed=self.eval_seed)

        done = False
        ep_rew = 0.0
        steps = 0
        last_info = {"event": "running"}

        while not done:
            action, _ = self.model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            ep_rew += float(reward)
            steps += 1
            last_info = info

            if steps > 10000:
                last_info = {"event": "eval_step_cap"}
                break

        dist = float(np.linalg.norm(env.enemy_pos - env.defense_pos))
        return env, ep_rew, steps, last_info.get("event", "unknown"), dist

    def _print_window(self):
        if len(self.events) == 0:
            return
        c = Counter(self.events)

        hit_rate = c.get("hit", 0) / len(self.events)
        avg_rew = float(np.mean(self.ep_rewards))
        avg_len = float(np.mean(self.ep_lens))
        avg_min = float(np.mean(self.ep_min_dists))

        # Useful: also show median min_dist (robust to outliers)
        med_min = float(np.median(self.ep_min_dists))

        print(
            f"[LAST {len(self.events)} EP] "
            f"hit_rate={hit_rate:.2f} "
            f"avg_rew={avg_rew:.1f} "
            f"avg_len={avg_len:.0f} "
            f"avg_min_dist={avg_min:.1f}m "
            f"med_min_dist={med_min:.1f}m "
            f"events={dict(c)}"
        )

    def _on_step(self) -> bool:
        # ---- TRAIN EPISODES END LOGGING ----
        for info in self.locals.get("infos", []):
            if "episode" in info:
                # terminal state for one env
                ev = info.get("event", "unknown")
                self.events.append(ev)
                self.ep_rewards.append(float(info["episode"]["r"]))
                self.ep_lens.append(int(info["episode"]["l"]))
                self.ep_min_dists.append(float(info.get("min_dist", info.get("dist", np.inf))))

                print(
                    f"[TRAIN EP END] "
                    f"reward={info['episode']['r']:.2f} "
                    f"len={info['episode']['l']} "
                    f"event={ev} "
                    f"final_dist={info.get('dist', -1):.1f}m "
                    f"min_dist={info.get('min_dist', -1):.1f}m"
                )
                
                # Print shaping reward sums if available
                if "sum_r_progress" in info:
                    print(
                        f"  shaping: r_progress={info.get('sum_r_progress', 0):.2f} "
                        f"r_close={info.get('sum_r_close', 0):.2f}"
                    )

                # print rolling summary occasionally (every 10 episodes)
                if len(self.events) % 10 == 0:
                    self._print_window()

        # ---- EVAL + ANIMATION ----
        if self.num_timesteps >= self.next_trigger:
            env, ep_rew, steps, event, dist = self._run_eval_episode()

            print(
                f"\n[Eval @ {self.num_timesteps} steps] "
                f"reward={ep_rew:.2f} "
                f"ep_len={steps} "
                f"event={event} "
                f"final_dist={dist:.1f}m "
                f"min_dist={getattr(env, 'min_dist', np.inf):.1f}m"
            )

            display(animate_trajectories(env.enemy_path, env.defense_path))
            self.next_trigger += self.every_steps

        return True


# ==========================================
# Train PPO
# ==========================================
SEED = 0
TOTAL_TIMESTEPS = 1_000_000

# 1e6 / 2e5 = 5 animations (good)
# 1e6 / 1e5 = 10 animations (more feedback but slower)
ANIMATE_EVERY = 100_000

venv = DummyVecEnv([make_env(SEED)])
venv = VecMonitor(venv)

model = PPO(
    "MlpPolicy",
    venv,
    verbose=1,
    seed=SEED,
    tensorboard_log="./tb_missile/",
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=256,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    ent_coef=0.0,
    vf_coef=0.5,
    max_grad_norm=0.5,
)

csv_plus = CSVLoggerPlusCallback(
    csv_path="./logs/train_plus.csv",
    flush_every_episodes=5,     # escribe cada 5 episodios
    flush_every_steps=25_000,   # o cada 25k pasos
)

callback = [csv_plus, EvalAnimateCallback(every_steps=ANIMATE_EVERY, eval_seed=123, verbose=1)]

model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=callback)
model.save("ppo_missile_ufo")
print("Saved: ppo_missile_ufo.zip")


Using cpu device
Logging to ./tb_missile/PPO_8
[TRAIN EP END] reward=-4184.66 len=1660 event=enemy_ground final_dist=412465.1m min_dist=69165.5m
  shaping: r_progress=-2713.12 r_close=-865.68
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 1.66e+03  |
|    ep_rew_mean     | -4.18e+03 |
| time/              |           |
|    fps             | 284       |
|    iterations      | 1         |
|    time_elapsed    | 7         |
|    total_timesteps | 2048      |
----------------------------------
[TRAIN EP END] reward=-4455.00 len=2029 event=enemy_ground final_dist=513482.8m min_dist=49149.8m
  shaping: r_progress=-3374.48 r_close=-954.97
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1.84e+03     |
|    ep_rew_mean          | -4.32e+03    |
| time/                   |              |
|    fps                  | 326          |
|    iterations           | 2            |
|    time_el

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.3e+03     |
|    ep_rew_mean          | -5.1e+03    |
| time/                   |             |
|    fps                  | 331         |
|    iterations           | 49          |
|    time_elapsed         | 302         |
|    total_timesteps      | 100352      |
| train/                  |             |
|    approx_kl            | 0.003748306 |
|    clip_fraction        | 0.00571     |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.89       |
|    explained_variance   | 0.0834      |
|    learning_rate        | 0.0003      |
|    loss                 | 4.03e+04    |
|    n_updates            | 480         |
|    policy_gradient_loss | -0.00213    |
|    std                  | 1.03        |
|    value_loss           | 1.26e+05    |
-----------------------------------------
[TRAIN EP END] reward=-4173.85 len=1656 event=enemy_ground final_dist=437752

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1.16e+03     |
|    ep_rew_mean          | -5.06e+03    |
| time/                   |              |
|    fps                  | 349          |
|    iterations           | 98           |
|    time_elapsed         | 574          |
|    total_timesteps      | 200704       |
| train/                  |              |
|    approx_kl            | 0.0001348936 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -3.01        |
|    explained_variance   | 0.118        |
|    learning_rate        | 0.0003       |
|    loss                 | 1.74e+05     |
|    n_updates            | 970          |
|    policy_gradient_loss | -0.000252    |
|    std                  | 1.09         |
|    value_loss           | 3.02e+05     |
------------------------------------------
[TRAIN EP END] reward=-5837.96 len=1848 event=defense_

[TRAIN EP END] reward=-2159.71 len=1753 event=enemy_ground final_dist=333402.8m min_dist=5117.8m
  shaping: r_progress=-1971.09 r_close=-845.12
[LAST 50 EP] hit_rate=0.00 avg_rew=-2984.9 avg_len=1022 avg_min_dist=39217.4m med_min_dist=30135.0m events={'defense_ground': 25, 'enemy_ground': 25}
[TRAIN EP END] reward=-5446.86 len=264 event=defense_ground final_dist=67283.2m min_dist=67283.2m
  shaping: r_progress=873.65 r_close=254.15
[LAST 50 EP] hit_rate=0.00 avg_rew=-2997.9 avg_len=1020 avg_min_dist=39829.3m med_min_dist=30135.0m events={'enemy_ground': 25, 'defense_ground': 25}
[TRAIN EP END] reward=-5129.48 len=239 event=defense_ground final_dist=43524.0m min_dist=43524.0m
  shaping: r_progress=718.74 r_close=225.01
[LAST 50 EP] hit_rate=0.00 avg_rew=-3090.6 avg_len=1002 avg_min_dist=40181.1m med_min_dist=35059.3m events={'defense_ground': 26, 'enemy_ground': 24}
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 907     

[TRAIN EP END] reward=-5735.95 len=218 event=defense_ground final_dist=65807.7m min_dist=65807.7m
  shaping: r_progress=561.44 r_close=189.75
[LAST 50 EP] hit_rate=0.00 avg_rew=-4015.0 avg_len=1084 avg_min_dist=49766.8m med_min_dist=42850.3m events={'enemy_ground': 31, 'defense_ground': 19}
[TRAIN EP END] reward=-4800.48 len=233 event=defense_ground final_dist=21704.1m min_dist=17470.3m
  shaping: r_progress=533.55 r_close=156.17
[LAST 50 EP] hit_rate=0.00 avg_rew=-4049.0 avg_len=1063 avg_min_dist=49433.9m med_min_dist=42850.3m events={'enemy_ground': 30, 'defense_ground': 20}
[TRAIN EP END] reward=-4826.63 len=263 event=defense_ground final_dist=30257.3m min_dist=30201.2m
  shaping: r_progress=754.46 r_close=231.90
[LAST 50 EP] hit_rate=0.00 avg_rew=-4077.4 avg_len=1035 avg_min_dist=49774.4m med_min_dist=42850.3m events={'defense_ground': 21, 'enemy_ground': 29}
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.03e+03   |
|

[TRAIN EP END] reward=-3023.69 len=1503 event=enemy_ground final_dist=324259.2m min_dist=37562.7m
  shaping: r_progress=-2185.89 r_close=-850.49
[LAST 50 EP] hit_rate=0.00 avg_rew=-3876.7 avg_len=1311 avg_min_dist=46654.9m med_min_dist=41541.5m events={'defense_ground': 13, 'enemy_ground': 37}
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 1.23e+03      |
|    ep_rew_mean          | -3.84e+03     |
| time/                   |               |
|    fps                  | 369           |
|    iterations           | 245           |
|    time_elapsed         | 1358          |
|    total_timesteps      | 501760        |
| train/                  |               |
|    approx_kl            | 0.00035060078 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -3.01         |
|    explained_variance   | 0.22          |
|    learning_rate        | 0.0003        |
|

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.44e+03    |
|    ep_rew_mean          | -3.42e+03   |
| time/                   |             |
|    fps                  | 379         |
|    iterations           | 293         |
|    time_elapsed         | 1580        |
|    total_timesteps      | 600064      |
| train/                  |             |
|    approx_kl            | 0.012978803 |
|    clip_fraction        | 0.133       |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.01       |
|    explained_variance   | 0.957       |
|    learning_rate        | 0.0003      |
|    loss                 | 1.33e+03    |
|    n_updates            | 2920        |
|    policy_gradient_loss | -0.00307    |
|    std                  | 1.09        |
|    value_loss           | 1.2e+03     |
-----------------------------------------
[TRAIN EP END] reward=-2758.20 len=1809 event=enemy_ground final_dist=380840

[TRAIN EP END] reward=-2442.90 len=1081 event=enemy_ground final_dist=266558.9m min_dist=16882.1m
  shaping: r_progress=-2032.93 r_close=-712.47
[LAST 50 EP] hit_rate=0.00 avg_rew=-2502.3 avg_len=1501 avg_min_dist=20084.1m med_min_dist=17092.9m events={'enemy_ground': 50}
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1.53e+03     |
|    ep_rew_mean          | -2.74e+03    |
| time/                   |              |
|    fps                  | 383          |
|    iterations           | 342          |
|    time_elapsed         | 1828         |
|    total_timesteps      | 700416       |
| train/                  |              |
|    approx_kl            | 0.0064849546 |
|    clip_fraction        | 0.047        |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.85        |
|    explained_variance   | 0.986        |
|    learning_rate        | 0.0003       |
|    loss                 | 142        

[TRAIN EP END] reward=-1360.35 len=1380 event=enemy_ground final_dist=213301.5m min_dist=5899.2m
  shaping: r_progress=-1180.29 r_close=-606.89
[LAST 50 EP] hit_rate=0.00 avg_rew=-2431.1 avg_len=1599 avg_min_dist=18484.3m med_min_dist=15233.5m events={'enemy_ground': 50}
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.54e+03    |
|    ep_rew_mean          | -2.52e+03   |
| time/                   |             |
|    fps                  | 390         |
|    iterations           | 391         |
|    time_elapsed         | 2048        |
|    total_timesteps      | 800768      |
| train/                  |             |
|    approx_kl            | 0.006608465 |
|    clip_fraction        | 0.0439      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.78       |
|    explained_variance   | 0.993       |
|    learning_rate        | 0.0003      |
|    loss                 | 16.5        |
|    n_updates

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.63e+03    |
|    ep_rew_mean          | -1.7e+03    |
| time/                   |             |
|    fps                  | 396         |
|    iterations           | 440         |
|    time_elapsed         | 2271        |
|    total_timesteps      | 901120      |
| train/                  |             |
|    approx_kl            | 0.004468957 |
|    clip_fraction        | 0.0146      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.996       |
|    learning_rate        | 0.0003      |
|    loss                 | 22.9        |
|    n_updates            | 4390        |
|    policy_gradient_loss | -0.00193    |
|    std                  | 0.892       |
|    value_loss           | 76.5        |
-----------------------------------------
[TRAIN EP END] reward=-2985.17 len=2078 event=enemy_ground final_dist=328431

[TRAIN EP END] reward=-1906.92 len=1892 event=enemy_ground final_dist=313460.0m min_dist=29891.6m
  shaping: r_progress=-1251.47 r_close=-557.21
[LAST 50 EP] hit_rate=0.00 avg_rew=-1855.8 avg_len=1602 avg_min_dist=18413.1m med_min_dist=17378.6m events={'enemy_ground': 50}
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1.62e+03     |
|    ep_rew_mean          | -1.7e+03     |
| time/                   |              |
|    fps                  | 401          |
|    iterations           | 489          |
|    time_elapsed         | 2491         |
|    total_timesteps      | 1001472      |
| train/                  |              |
|    approx_kl            | 0.0037175203 |
|    clip_fraction        | 0.016        |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.46        |
|    explained_variance   | 0.996        |
|    learning_rate        | 0.0003       |
|    loss                 | 29.8       

In [None]:
# ==========================================
# Test ProNav Baseline (NO TRAINING)
# ==========================================
SEED = 0
NUM_EVAL_EPISODES = 100  # Test on 100 episodes to get statistics

def test_pronav_baseline():
    """Run ProNav on multiple episodes and collect statistics."""
    
    events = []
    rewards = []
    ep_lens = []
    min_dists = []
    
    for ep in range(NUM_EVAL_EPISODES):
        env = missile_interception_3d()
        obs, _ = env.reset(seed=SEED + ep)  # Different seed per episode
        
        done = False
        ep_reward = 0.0
        steps = 0
        
        while not done:
            # USE PRONAV INSTEAD OF MODEL
            action = env.calculate_pronav()
            
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            ep_reward += reward
            steps += 1
            
            if steps > 10000:
                info["event"] = "timeout"
                break
        
        # Collect stats
        event = info.get("event", "unknown")
        final_dist = float(np.linalg.norm(env.enemy_pos - env.defense_pos))
        min_dist = getattr(env, "min_dist", final_dist)
        
        events.append(event)
        rewards.append(ep_reward)
        ep_lens.append(steps)
        min_dists.append(min_dist)
        
        print(
            f"[ProNav EP {ep+1}/{NUM_EVAL_EPISODES}] "
            f"reward={ep_reward:.2f} "
            f"len={steps} "
            f"event={event} "
            f"final_dist={final_dist:.1f}m "
            f"min_dist={min_dist:.1f}m"
        )
        
        # Animate first 5 episodes
        if ep < 5:
            display(animate_trajectories(env.enemy_path, env.defense_path))
    
    # Print summary statistics
    c = Counter(events)
    hit_rate = c.get("hit", 0) / NUM_EVAL_EPISODES
    
    print("\n" + "="*60)
    print(f"PRONAV BASELINE RESULTS ({NUM_EVAL_EPISODES} episodes)")
    print("="*60)
    print(f"Hit Rate: {hit_rate:.2%}")
    print(f"Avg Reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
    print(f"Avg Episode Length: {np.mean(ep_lens):.1f} ± {np.std(ep_lens):.1f}")
    print(f"Avg Min Distance: {np.mean(min_dists):.1f}m (median: {np.median(min_dists):.1f}m)")
    print(f"Event Distribution: {dict(c)}")
    print("="*60)

# Run the baseline test
test_pronav_baseline()