# In [None]:
# train_ppo_full_fixed.py
"""
PPO Training Script for Flight Landing Environment (fixed bookkeeping).
- Handles unknown 'outcome' values like "in-flight" without KeyError.
- Ensures runway_condition keys are normalized to Python floats when used as dict keys.
"""

import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces

class FlightEnv(gym.Env):
    """Simplified landing-approach environment with a 5-action discrete control.

    Instance state: ``altitude``, ``speed``, ``distance`` (remaining distance,
    counted down towards 0), ``angle`` (clipped to [-30, 30]),
    ``runway_condition`` (one of 0.0, 0.5, 1.0, drawn once per episode) and
    ``steps`` (step counter for the current episode).

    Actions:
        0: speed += 6
        1: speed -= 6
        2: altitude += 35 and angle += 1.5
        3: altitude -= 35 and angle -= 1.5
        4: no control input

    Episode outcomes reported in ``info["outcome"]``: "in-flight",
    "successful landing", "failed landing", "crash before runway",
    "stall midair" and "timeout".

    Args:
        start_alt: Altitude at the start of every episode.
        start_dist: Distance at the start of every episode.
        max_steps: Step count at which an unfinished episode ends with
            outcome "timeout".
    """

    # Per-step speed loss keyed by runway_condition; any other value uses
    # _DEFAULT_DRAG (which is also what runway_condition == 1.0 gets).
    _DRAG_BY_RUNWAY = {0.0: 0.6, 0.5: 0.4}
    _DEFAULT_DRAG = 0.25

    def __init__(self, start_alt=400.0, start_dist=800.0, max_steps=600):
        super().__init__()
        # NOTE: these bounds are expressed in raw magnitudes, whereas
        # _get_obs() returns scaled values (altitude/5000, speed/300, ...).
        # They are kept as-is so the declared space does not change.
        self.observation_space = spaces.Box(
            low=np.array([0, 0, 0, -30, 0], dtype=np.float32),
            high=np.array([5000, 300, 10000, 30, 1], dtype=np.float32),
            dtype=np.float32
        )
        self.action_space = spaces.Discrete(5)
        self.start_alt = start_alt
        self.start_dist = start_dist
        self.max_steps = max_steps
        # Populate the state attributes so step() is usable right away.
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        ``seed`` is forwarded to ``gym.Env.reset``. The initial speed, angle
        and runway condition are drawn from the global ``np.random`` state,
        so reproducibility is controlled with ``np.random.seed``.
        """
        super().reset(seed=seed)
        self.altitude = float(self.start_alt)
        self.speed = float(160.0 + np.random.uniform(-10, 10))
        self.distance = float(self.start_dist)
        self.prev_distance = self.distance
        self.angle = float(np.random.uniform(-2, 2))
        self.runway_condition = float(np.random.choice([0.0, 0.5, 1.0]))
        self.steps = 0
        return self._get_obs(), {}

    def step(self, action):
        """Apply ``action``, advance the simulation and score the new state.

        Returns:
            ``(obs, reward, done, False, info)``. The fourth element is always
            ``False``; a timeout is reported through ``done`` as well.
            ``info`` has the keys "success", "outcome" and "runway_condition".
        """
        self.steps += 1

        # Control input (action 4, or anything unrecognised, changes nothing).
        if action == 0:
            self.speed += 6.0
        elif action == 1:
            self.speed -= 6.0
        elif action == 2:
            self.altitude += 35.0
            self.angle += 1.5
        elif action == 3:
            self.altitude -= 35.0
            self.angle -= 1.5

        # Dynamics: progress depends on the speed *before* drag is applied.
        self.distance -= max(self.speed * 0.3, 1.0)
        self.altitude -= 8.0
        self.speed -= self._DRAG_BY_RUNWAY.get(self.runway_condition, self._DEFAULT_DRAG)

        # Keep the state as plain Python floats (np.clip returns NumPy scalars).
        self.angle = float(np.clip(self.angle, -30, 30))
        self.altitude = max(self.altitude, 0.0)
        self.speed = float(np.clip(self.speed, 0.0, 300.0))

        # Shaping reward.
        reward = 0.0
        reward += (self.prev_distance - self.distance) * 0.02
        self.prev_distance = self.distance
        reward -= 0.03
        reward -= 0.005 * abs(self.altitude - 100)
        reward -= 0.005 * abs(self.speed - 150)
        reward -= 0.01 * abs(self.angle)

        if self.distance < 400:
            reward += 0.8
        if 0 < self.altitude < 100 and 100 < self.speed < 200:
            reward += 1.5
        reward += (self.start_dist - self.distance) / max(1.0, self.start_dist)

        done = False
        success = False
        outcome = "in-flight"

        # Terminal conditions are mutually exclusive and checked in priority
        # order, so exactly one terminal reward/penalty is applied and the
        # outcome label always agrees with the `success` flag.
        if self.distance <= 0:
            if 0 <= self.altitude <= 50 and 100 <= self.speed <= 200 and abs(self.angle) < 10:
                reward += 200.0
                success = True
                outcome = "successful landing"
            else:
                reward -= 40.0
                outcome = "failed landing"
            done = True
        elif self.altitude <= 0:
            # distance > 0 is implied by the branch above.
            reward -= 40.0
            done = True
            outcome = "crash before runway"
        elif self.speed <= 20 and self.altitude > 100:
            reward -= 40.0
            done = True
            outcome = "stall midair"
        elif self.steps >= self.max_steps:
            done = True
            outcome = "timeout"

        info = {
            "success": success,
            "outcome": outcome,
            "runway_condition": self.runway_condition
        }

        return self._get_obs(), float(reward), bool(done), False, info

    def _get_obs(self):
        """Return the scaled state as a float32 vector of length 5."""
        return np.array([
            self.altitude / 5000.0,
            self.speed / 300.0,
            self.distance / 10000.0,
            (self.angle + 30.0) / 60.0,
            self.runway_condition
        ], dtype=np.float32)


class ActorCritic(nn.Module):
    """MLP with a shared two-layer ReLU trunk and separate policy/value heads.

    ``forward`` returns ``(logits, value)``: unnormalised action scores with
    ``action_dim`` entries and a single state-value estimate per input row.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        width = 256
        # Layers are created in the same order as they appear in the trunk so
        # parameter names (shared.0, shared.2, ...) stay stable for checkpoints.
        trunk = [
            nn.Linear(obs_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk)
        self.policy = nn.Linear(width, action_dim)
        self.value = nn.Linear(width, 1)

    def forward(self, x):
        features = self.shared(x)
        logits = self.policy(features)
        state_value = self.value(features)
        return logits, state_value


class PPOAgent:
    """PPO agent (clipped surrogate objective) for a discrete action space.

    Args:
        obs_dim: Length of the observation vector.
        action_dim: Number of discrete actions.
        lr: Adam learning rate.
        gamma: Discount factor.
        lam: GAE lambda.
        clip: Clipping range applied to the probability ratio.
        device: Torch device string the network and tensors live on.

    Attributes:
        net: The ``ActorCritic`` network.
        opt: The Adam optimizer over ``net``'s parameters.
    """

    def __init__(self, obs_dim, action_dim, lr=3e-4, gamma=0.99, lam=0.95, clip=0.2, device="cpu"):
        self.gamma = gamma
        self.lam = lam
        self.clip = clip
        self.device = torch.device(device)

        self.net = ActorCritic(obs_dim, action_dim).to(self.device)
        self.opt = optim.Adam(self.net.parameters(), lr=lr)

    def act(self, obs):
        """Sample an action for a single observation.

        Returns:
            ``(action, log_prob, value)`` as plain Python ``int``/``float``s.
        """
        obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
        # Only detached scalars leave this method, so there is no reason to
        # record an autograd graph for every environment step.
        with torch.no_grad():
            logits, value = self.net(obs_t)
            # Building the distribution from logits avoids the
            # softmax -> log round trip and is numerically more stable.
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return int(action.item()), float(log_prob.item()), float(value.item())

    def compute_advantages(self, rewards, values, dones, last_value):
        """Compute GAE(lambda) advantages for one rollout.

        Args:
            rewards: Per-step rewards.
            values: Per-step value estimates, same length as ``rewards``.
            dones: Per-step episode-end flags; a true flag stops both the
                bootstrap and the advantage accumulation at that step.
            last_value: Value estimate of the state following the final step.

        Returns:
            A list of advantages, one per step, in time order.
        """
        values = list(values) + [last_value]
        advantages = [0.0] * len(rewards)
        adv = 0.0
        for t in reversed(range(len(rewards))):
            not_done = 0.0 if dones[t] else 1.0
            td = rewards[t] + self.gamma * values[t + 1] * not_done - values[t]
            adv = td + self.gamma * self.lam * not_done * adv
            advantages[t] = adv
        return advantages

    def update(self, batch, epochs=10, batch_size=64):
        """Run several epochs of minibatch PPO updates on one rollout.

        Args:
            batch: Dict with equally long sequences under the keys "states",
                "actions", "log_probs", "returns" and "advantages".
            epochs: Number of passes over the rollout.
            batch_size: Minibatch size; the final minibatch may be smaller.
        """
        states = torch.tensor(np.array(batch["states"]), dtype=torch.float32, device=self.device)
        actions = torch.tensor(batch["actions"], dtype=torch.int64, device=self.device)
        old_log_probs = torch.tensor(batch["log_probs"], dtype=torch.float32, device=self.device)
        returns = torch.tensor(batch["returns"], dtype=torch.float32, device=self.device)
        advantages = torch.tensor(batch["advantages"], dtype=torch.float32, device=self.device)

        # The (unbiased) std of a single sample is NaN, which would poison the
        # loss and the weights; normalisation only makes sense with >1 sample.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        dataset_size = len(states)
        for _ in range(epochs):
            idxs = np.random.permutation(dataset_size)
            for start in range(0, dataset_size, batch_size):
                inds = idxs[start:start + batch_size]
                logits, value = self.net(states[inds])
                dist = torch.distributions.Categorical(logits=logits)

                log_p = dist.log_prob(actions[inds])
                ratio = torch.exp(log_p - old_log_probs[inds])

                surr1 = ratio * advantages[inds]
                surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * advantages[inds]
                policy_loss = -torch.min(surr1, surr2).mean()

                # squeeze(-1) keeps the batch dimension even for a minibatch
                # of one, so prediction and target shapes always match.
                value_loss = nn.functional.mse_loss(value.squeeze(-1), returns[inds])
                entropy = dist.entropy().mean()
                loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

                self.opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
                self.opt.step()


def parse_args(argv=None):
    """Parse the training options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``; passing a list, e.g. ``[]``,
            makes the function usable from notebooks and tests without
            touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``episodes``, ``lr``, ``rollout_steps``,
        ``gamma``, ``lam``, ``clip``, ``save_dir``, ``cpu`` and ``seed``.
    """
    p = argparse.ArgumentParser(description="PPO training for the flight landing environment")
    p.add_argument('--episodes', type=int, default=50000,
                   help='number of training iterations (one rollout plus one PPO update each)')
    p.add_argument('--lr', type=float, default=3e-4, help='Adam learning rate')
    p.add_argument('--rollout_steps', type=int, default=2048,
                   help='environment steps collected per rollout')
    p.add_argument('--gamma', type=float, default=0.99, help='discount factor')
    p.add_argument('--lam', type=float, default=0.95, help='GAE lambda')
    p.add_argument('--clip', type=float, default=0.2, help='PPO ratio clipping range')
    p.add_argument('--save_dir', type=str, default='./checkpoints',
                   help='directory where best_model.pt is written')
    p.add_argument('--cpu', action='store_true', help='use the CPU even if CUDA is available')
    p.add_argument('--seed', type=int, default=42, help='seed for torch, numpy and random')
    return p.parse_args(argv)


def _record_episode(info, outcomes, runway_counts, runway_success):
    """Add one finished episode's ``info`` dict to the running tallies."""
    outcome = info.get("outcome", "in-flight")
    # .get() keeps unknown outcome labels from raising KeyError.
    outcomes[outcome] = outcomes.get(outcome, 0) + 1

    # Normalise to a Python float so equal conditions share one dict key.
    runway_val = float(info.get("runway_condition", 0.0))
    runway_counts[runway_val] = runway_counts.get(runway_val, 0) + 1
    if info.get("success", False):
        runway_success[runway_val] = runway_success.get(runway_val, 0) + 1


def train(args):
    """Train a PPO agent on ``FlightEnv`` and checkpoint the best network.

    Each loop iteration collects ``args.rollout_steps`` environment steps
    (resetting the environment whenever an episode ends), computes GAE
    advantages, and runs one ``PPOAgent.update``. The loop runs
    ``args.episodes`` times, so that option counts rollout/update iterations
    rather than individual environment episodes.

    Outcome and per-runway-condition success tallies are updated for every
    episode that finishes inside a rollout; a summary is printed every 200
    iterations and at the end.

    Args:
        args: Namespace providing ``seed``, ``cpu``, ``lr``, ``gamma``,
            ``lam``, ``clip``, ``save_dir``, ``episodes`` and
            ``rollout_steps``.
    """
    device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    env = FlightEnv()
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = PPOAgent(obs_dim, action_dim, lr=args.lr, gamma=args.gamma,
                     lam=args.lam, clip=args.clip, device=device)

    os.makedirs(args.save_dir, exist_ok=True)

    # initialize with expected keys but allow unknown outcomes
    outcomes = {"successful landing": 0, "failed landing": 0,
                "crash before runway": 0, "stall midair": 0, "timeout": 0}
    runway_counts = {0.0: 0, 0.5: 0, 1.0: 0}
    runway_success = {0.0: 0, 0.5: 0, 1.0: 0}

    best_return = -1e9
    episode = 0  # counts rollout/update iterations

    while episode < args.episodes:
        rollout = {"states": [], "actions": [], "rewards": [],
                   "values": [], "log_probs": [], "dones": []}

        obs, _ = env.reset()

        for _ in range(args.rollout_steps):
            action, logp, value = agent.act(obs)
            next_obs, reward, done, _truncated, info = env.step(action)

            rollout["states"].append(obs)
            rollout["actions"].append(action)
            rollout["rewards"].append(reward)
            rollout["values"].append(value)
            rollout["log_probs"].append(logp)
            rollout["dones"].append(done)

            obs = next_obs

            if done:
                # Tally the outcome at the moment the episode actually ends.
                # Looking only at the info of the rollout's final step would
                # almost always see a flight still in progress and miss every
                # episode completed earlier in the rollout.
                _record_episode(info, outcomes, runway_counts, runway_success)
                # Reset and keep filling the rollout.
                obs, _ = env.reset()

        # compute last value for bootstrap
        with torch.no_grad():
            last_value = agent.net(
                torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            )[1].item()

        advantages = agent.compute_advantages(
            rollout["rewards"], rollout["values"], rollout["dones"], last_value
        )
        returns = [adv + val for adv, val in zip(advantages, rollout["values"])]

        batch = {
            "states": rollout["states"],
            "actions": rollout["actions"],
            "log_probs": rollout["log_probs"],
            "advantages": advantages,
            "returns": returns
        }

        agent.update(batch)

        episode += 1
        # Total reward over the whole fixed-length rollout (it may span many
        # environment episodes); used as the checkpointing criterion.
        ep_return = sum(rollout["rewards"])
        if ep_return > best_return:
            best_return = ep_return
            torch.save(agent.net.state_dict(), os.path.join(args.save_dir, "best_model.pt"))

        if episode % 200 == 0:
            runway_probs = {
                k: (runway_success.get(k, 0) / runway_counts[k] if runway_counts[k] > 0 else 0)
                for k in sorted(runway_counts)
            }
            print("\n--- Outcome Summary up to Episode", episode, "---")
            for k, v in outcomes.items():
                print(f"{k:<25}: {v}")
            print("Runway success probabilities:", runway_probs)
            print("------------------------------------------\n")

    print("Training complete.")
    print("Best Episode Return:", best_return)
    print("Final outcomes summary:")
    for k, v in outcomes.items():
        print(f"{k:<25}: {v}")


if __name__ == "__main__":
    import sys

    # Inside a Jupyter/IPython kernel sys.argv holds the kernel launcher's own
    # flags, which argparse would reject, so they are dropped there. When the
    # file is run as a normal script the command-line options are left intact
    # so flags such as --episodes or --cpu actually take effect.
    if "ipykernel" in sys.modules:
        sys.argv = [sys.argv[0]]
    args = parse_args()
    print("Starting PPO training with args:", args)
    train(args)
