# In [None]:  (Jupyter cell marker left over from the notebook export)
# train_reinforce.py
"""
REINFORCE (Monte-Carlo Policy Gradient) training script for Flight Landing.
- Self-contained FlightEnv (5-dim observation, 5 discrete actions).
- Prints per-episode progress and runway type/outcome.
- Tracks per-runway success probabilities.
- Saves best policy.
- Tweaked dynamics & rewards to improve successful landing frequency.
"""

import argparse
import os
import random
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces

# -----------------------
# Environment
# -----------------------
class FlightEnv(gym.Env):
    """Simplified landing-approach environment with 5 discrete actions.

    Internal state: altitude, speed, remaining distance to the runway, pitch
    angle and a runway-condition code drawn from {0.0, 0.5, 1.0}. The runway
    condition is exposed in the observation and in ``info`` but does not
    influence the dynamics or the reward.

    Observations are the state scaled by fixed constants (see ``_get_obs``).

    NOTE(review): ``observation_space`` is declared in raw, un-scaled bounds
    while ``_get_obs`` returns scaled values, and the scaled distance can dip
    slightly below 0 on the touchdown step. Left unchanged here to keep the
    public attribute stable -- confirm whether downstream code relies on it.
    """

    # Becomes True on an instance the first time ``reset`` receives a seed.
    _use_env_rng = False

    def __init__(self, start_alt=300.0, start_dist=600.0):
        """Create the environment and reset it once.

        Args:
            start_alt: Altitude every episode starts from.
            start_dist: Distance to the runway every episode starts from.
        """
        super().__init__()
        self.observation_space = spaces.Box(
            low=np.array([0, 0, 0, -30, 0], dtype=np.float32),
            high=np.array([5000, 300, 10000, 30, 1], dtype=np.float32),
            dtype=np.float32
        )
        self.action_space = spaces.Discrete(5)
        self.start_alt = start_alt
        self.start_dist = start_dist
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed. Once a seed has been passed, this and all
                later resets draw from the environment's own ``np_random``
                generator, so the episode sequence is reproducible from the
                seed alone. If no seed has ever been passed, draws come from
                the global ``np.random`` state (so ``np.random.seed`` still
                controls them, as before).
            options: Accepted for API compatibility; unused.
        """
        super().reset(seed=seed)
        if seed is not None:
            self._use_env_rng = True
        rng = self.np_random if self._use_env_rng else np.random

        self.altitude = float(self.start_alt)
        self.speed = float(160.0 + rng.uniform(-8, 8))
        self.distance = float(self.start_dist)
        self.prev_distance = self.distance
        self.angle = float(rng.uniform(-2, 2))
        self.runway_condition = float(rng.choice([0.0, 0.5, 1.0]))
        self.steps = 0
        return self._get_obs(), {}

    def step(self, action):
        """Apply ``action``, advance the dynamics one step and score it.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False``; the 800-step timeout is
            reported through ``terminated``. ``info`` carries ``success``
            (bool), ``outcome`` (str label) and ``runway`` (condition code).
        """
        self.steps += 1
        # Actions: 0 throttle+, 1 throttle-, 2 pitch+, 3 pitch-, 4 noop
        if action == 0:
            self.speed += 4.0   # gentler throttle increments
        elif action == 1:
            self.speed -= 4.0
        elif action == 2:
            self.altitude += 25.0
            self.angle += 1.2
        elif action == 3:
            self.altitude -= 25.0
            self.angle -= 1.2
        # else 4: maintain

        # --- Dynamics (gentler descent & drag) ---
        # Distance always shrinks by at least 1.0, even at zero speed.
        self.distance -= max(self.speed * 0.22, 1.0)  # slower approach per step
        self.altitude -= 5.0                          # gentler gravity
        self.speed -= 0.3                             # lighter drag
        self.angle = np.clip(self.angle, -30, 30)
        self.altitude = max(self.altitude, 0.0)
        self.speed = np.clip(self.speed, 0.0, 300.0)

        # --- Reward shaping (stronger, denser feedback) ---
        reward = 0.0
        # progress-based reward (larger)
        reward += (self.prev_distance - self.distance) * 0.05
        self.prev_distance = self.distance

        # small step penalty (keeps episodes efficient)
        reward -= 0.02

        # altitude and speed shaping (softer penalties)
        reward -= 0.003 * abs(self.altitude - 100)
        reward -= 0.003 * abs(self.speed - 150)

        # pitch penalty
        reward -= 0.008 * abs(self.angle)

        # near runway bonus (earlier)
        if self.distance < 500:
            reward += 1.0

        # bonus for ideal landing profile
        if 0 < self.altitude < 100 and 100 < self.speed < 200:
            reward += 2.0

        # proportional progress reward
        reward += 1.0 * ((self.start_dist - self.distance) / max(1.0, self.start_dist))

        done = False
        success = False
        outcome = "in-flight"

        # --- Landing event (much stronger landing reward) ---
        if self.distance <= 0:
            # success window widened a touch
            if 0 <= self.altitude <= 60 and 95 <= self.speed <= 210 and abs(self.angle) < 12:
                reward += 400.0  # boosted landing bonus to strongly reinforce correct touchdown
                success = True
                outcome = "successful landing"
            else:
                reward -= 40.0
                outcome = "failed landing"
            done = True

        # --- Crash / stall conditions ---
        if self.altitude <= 0 and self.distance > 0:
            reward -= 60.0
            done = True
            outcome = "crash before runway"

        # Not exclusive with a failed landing: both penalties can apply and
        # the later label wins.
        if self.speed <= 20 and self.altitude > 100:
            reward -= 60.0
            done = True
            outcome = "stall midair"

        # Allow longer episodes to give the agent more time to correct.
        # Only label a timeout when nothing else ended the episode on this
        # step; otherwise a landing/crash on step 800 would be mislabelled
        # (e.g. success=True with outcome "timeout").
        if not done and self.steps >= 800:
            done = True
            outcome = "timeout"

        return self._get_obs(), float(reward), bool(done), False, {"success": success, "outcome": outcome,
                                                                    "runway": self.runway_condition}

    def _get_obs(self):
        """Return the state as a float32 vector scaled by fixed constants."""
        return np.array([
            self.altitude / 5000.0,
            self.speed / 300.0,
            self.distance / 10000.0,
            (self.angle + 30.0) / 60.0,   # maps the clipped [-30, 30] range onto [0, 1]
            self.runway_condition
        ], dtype=np.float32)


# -----------------------
# Policy network
# -----------------------
class PolicyNet(nn.Module):
    """Two-hidden-layer ReLU MLP that maps observations to action logits.

    The layers live in ``self.net`` (an ``nn.Sequential``), with the linear
    layers at indices 0, 2 and 4.
    """

    def __init__(self, obs_dim, n_actions, hidden=256):
        super().__init__()
        widths = (obs_dim, hidden, hidden, n_actions)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer emits raw logits, so drop the trailing ReLU
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return unnormalised logits with ``n_actions`` entries on the last axis."""
        logits = self.net(x)
        return logits


class ValueNet(nn.Module):
    """One-hidden-layer ReLU MLP producing a scalar value per observation.

    The layers live in ``self.net`` (an ``nn.Sequential``), with the linear
    layers at indices 0 and 2.
    """

    def __init__(self, obs_dim, hidden=128):
        super().__init__()
        encoder = nn.Linear(obs_dim, hidden)
        head = nn.Linear(hidden, 1)
        self.net = nn.Sequential(encoder, nn.ReLU(), head)

    def forward(self, x):
        """Return the value estimate with the trailing size-1 axis removed."""
        out = self.net(x)
        return out.squeeze(-1)


# -----------------------
# REINFORCE Training
# -----------------------
def parse_args(argv=None):
    """Parse the training hyper-parameters.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``; pass ``[]`` to get pure
            defaults, e.g. from a notebook.

    Returns:
        An ``argparse.Namespace`` with ``episodes``, ``lr``, ``gamma``,
        ``hidden``, ``use_baseline``, ``entropy_coef``, ``save_dir``,
        ``report_every``, ``seed`` and ``cpu``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--episodes', type=int, default=50000)
    p.add_argument('--lr', type=float, default=3e-4)
    p.add_argument('--gamma', type=float, default=0.99)
    p.add_argument('--hidden', type=int, default=256)
    p.add_argument('--use_baseline', action='store_true', help='Use value baseline to reduce variance')
    p.add_argument('--entropy_coef', type=float, default=0.0015)  # slightly larger to encourage exploration early
    p.add_argument('--save_dir', type=str, default='./checkpoints_reinforce')
    p.add_argument('--report_every', type=int, default=10)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--cpu', action='store_true')
    return p.parse_args(argv)



def train(args):
    """Train a policy on ``FlightEnv`` with REINFORCE, one update per episode.

    Each episode is rolled out by sampling actions from the current policy;
    the discounted returns-to-go are then used for a single policy-gradient
    step (plus an optional value-baseline step). Progress is printed every
    episode, the best-returning policy is checkpointed, and final weights are
    written to ``args.save_dir`` when training ends.

    Args:
        args: Namespace providing ``episodes``, ``lr``, ``gamma``, ``hidden``,
            ``use_baseline``, ``entropy_coef``, ``save_dir``,
            ``report_every``, ``seed`` and ``cpu``.
    """
    device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    env = FlightEnv()
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n

    policy = PolicyNet(obs_dim, n_actions, hidden=args.hidden).to(device)
    policy_opt = optim.Adam(policy.parameters(), lr=args.lr)

    value, value_opt = None, None
    if args.use_baseline:
        value = ValueNet(obs_dim).to(device)
        value_opt = optim.Adam(value.parameters(), lr=args.lr)

    os.makedirs(args.save_dir, exist_ok=True)

    # Per-runway-condition tallies, keyed by the condition code.
    success_counts = {0.0: 0, 0.5: 0, 1.0: 0}
    total_counts = {0.0: 0, 0.5: 0, 1.0: 0}
    outcomes = defaultdict(int)

    best_return = -1e9

    def success_rates():
        """Return {runway code: successes / episodes}, 0.0 where unseen."""
        return {
            k: (success_counts[k] / total_counts[k] if total_counts[k] > 0 else 0.0)
            for k in success_counts
        }

    def finish_episode(episode_states, episode_actions, episode_rewards):
        """Run one gradient step from a finished episode.

        Returns ``(policy_loss_total, value_loss)``; ``value_loss`` is
        ``None`` when no baseline is used.
        """
        # Discounted return-to-go, accumulated back-to-front and then
        # reversed (append + reverse avoids quadratic front-inserts).
        R = 0.0
        returns = []
        for r in reversed(episode_rewards):
            R = r + args.gamma * R
            returns.append(R)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32, device=device)
        # Normalize returns for stability. The sample std of a single value
        # is NaN, which would propagate into the weights, so one-step
        # episodes keep their raw return instead.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        states_t = torch.tensor(np.array(episode_states), dtype=torch.float32, device=device)
        actions_t = torch.tensor(episode_actions, dtype=torch.int64, device=device)

        logits = policy(states_t)
        logp = torch.log_softmax(logits, dim=1)
        selected_logp = logp.gather(1, actions_t.unsqueeze(1)).squeeze(1)

        if args.use_baseline:
            # Baseline is detached so the policy loss does not train it.
            values = value(states_t).detach()
            adv = returns - (values - values.mean())
        else:
            adv = returns

        policy_loss = - (selected_logp * adv).mean()
        probs = torch.softmax(logits, dim=1)
        entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()
        # Subtracting the entropy term rewards more exploratory policies.
        loss = policy_loss - args.entropy_coef * entropy

        policy_opt.zero_grad()
        loss.backward()
        policy_opt.step()

        if args.use_baseline:
            value_opt.zero_grad()
            v_loss = nn.MSELoss()(value(states_t), returns)
            v_loss.backward()
            value_opt.step()
            return float(loss.item()), float(v_loss.item())
        return float(loss.item()), None

    # -------------------------------------------------------
    # Training Loop
    # -------------------------------------------------------
    for ep in range(1, args.episodes + 1):
        obs, _ = env.reset()
        ep_states, ep_actions, ep_rewards = [], [], []
        ep_return = 0.0
        done = False
        runway = env.runway_condition
        outcome = "in-flight"

        print(f"Ep {ep}/{args.episodes} ...", end="", flush=True)

        while not done:
            state = obs.astype(np.float32)
            st_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            with torch.no_grad():
                logits = policy(st_t)
                probs = torch.softmax(logits, dim=1).cpu().numpy().squeeze(0)
            # sample action using the policy distribution
            action = np.random.choice(n_actions, p=probs)

            next_obs, reward, terminated, truncated, info = env.step(action)
            done = bool(terminated or truncated)

            ep_states.append(state)
            ep_actions.append(action)
            ep_rewards.append(reward)

            obs = next_obs
            ep_return += reward
            outcome = info.get("outcome", outcome)

        # Checkpoint BEFORE the gradient step: the weights on disk must be
        # the ones that actually produced this best return, not the
        # already-updated ones.
        if ep_return > best_return:
            best_return = ep_return
            torch.save(policy.state_dict(), os.path.join(args.save_dir, "reinforce_best_policy.pt"))
            if args.use_baseline:
                torch.save(value.state_dict(), os.path.join(args.save_dir, "reinforce_best_value.pt"))

        pl_loss, v_loss = finish_episode(ep_states, ep_actions, ep_rewards)

        total_counts[runway] += 1
        if info.get("success", False):
            success_counts[runway] += 1
            outcomes["successful landing"] += 1
        else:
            outcomes[outcome] += 1

        loss_str = f" | PLoss: {pl_loss:.3f}"
        if v_loss is not None:
            loss_str += f" | VLoss: {v_loss:.3f}"

        print(f" Return: {ep_return:.2f} | Runway: {runway} | Outcome: {outcome}{loss_str}")

        # -------------------------------------------------------
        # Outcome Summary Every 2000 Episodes
        # -------------------------------------------------------
        if ep % 2000 == 0:
            print(f"\n--- Outcome Summary up to Episode {ep} ---")
            print(f"successful landing       : {outcomes['successful landing']}")
            print(f"failed landing           : {outcomes['failed landing']}")
            print(f"crash before runway      : {outcomes['crash before runway']}")
            print(f"stall midair             : {outcomes['stall midair']}")
            print(f"timeout                  : {outcomes['timeout']}")

            succ_rates = success_rates()
            print(f"Runway success probabilities: {succ_rates}")
            print("------------------------------------------\n")

        # Existing periodic report (kept)
        if ep % args.report_every == 0:
            succ_rates = success_rates()
            print(f"  --- Report @ Ep {ep}: BestReturn {best_return:.2f} | SuccessRates: {succ_rates}")
            print("  Outcomes so far:", dict(outcomes))

    # Final save
    torch.save(policy.state_dict(), os.path.join(args.save_dir, "reinforce_final_policy.pt"))
    if args.use_baseline:
        torch.save(value.state_dict(), os.path.join(args.save_dir, "reinforce_final_value.pt"))

    final_rates = success_rates()
    print("Training finished. Final success rates by runway:", final_rates)
    print("Best episode return:", best_return)



if __name__ == "__main__":
    import sys

    # Inside a Jupyter/IPython kernel, sys.argv holds the kernel's own launch
    # arguments, which argparse would reject, so fall back to defaults there.
    # When run as a normal script, leave sys.argv alone so the command-line
    # flags defined in parse_args() actually take effect.
    if "ipykernel" in sys.modules:
        sys.argv = [sys.argv[0]]
    args = parse_args()
    print("Starting REINFORCE training with args:", args)
    train(args)
