# In [None]:
# train_sac_discrete_full.py
"""
Soft Actor–Critic (Discrete) training script for Flight Landing.
- SAME structure as original DQN version
- Per-episode runways, success probability
- Summaries every 50 episodes
- SAC: categorical actor, twin critics, soft target updates, auto-entropy tuning
"""

import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces

# =====================================================
# 1️⃣ ENVIRONMENT (unchanged)
# =====================================================
class FlightEnv(gym.Env):
    """Simplified flight-landing environment with a reachable start and random runway conditions.

    Observation (float32, 5 values), each divided by a fixed constant:
        0. altitude / 5000        (altitude is clamped to >= 0 each step; no upper clamp)
        1. speed / 300            (speed is clipped to [0, 300] each step)
        2. distance / 10000       (can dip slightly below 0 on the landing step)
        3. (angle + 30) / 60      (angle is clipped to [-30, 30] each step)
        4. runway condition       (one of 0.0, 0.5, 1.0)

    Actions (Discrete(5)): 0 throttle up, 1 throttle down, 2 pitch up,
    3 pitch down, 4 no control input.

    Args:
        start_alt: Altitude at the start of every episode.
        start_dist: Distance to the runway at the start of every episode.
    """

    # Episodes are ended with outcome "timeout" after this many steps.
    MAX_STEPS = 600

    def __init__(self, start_alt=400.0, start_dist=800.0):
        super().__init__()
        # Bounds describe the *scaled* values produced by _get_obs(), not raw units.
        self.observation_space = spaces.Box(
            low=np.array([0.0, 0.0, -np.inf, 0.0, 0.0], dtype=np.float32),
            high=np.array([np.inf, 1.0, np.inf, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32,
        )
        self.action_space = spaces.Discrete(5)
        self.start_alt = start_alt
        self.start_dist = start_dist
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(obs, {})``.

        Speed is drawn uniformly from 160 ± 10, the angle from [-2, 2] and the
        runway condition from {0.0, 0.5, 1.0}, all via the global ``np.random``.
        NOTE(review): ``seed`` only reaches ``super().reset`` (which seeds
        ``self.np_random``); this method never uses ``self.np_random``, so
        reproducibility depends on seeding ``np.random`` globally.
        """
        super().reset(seed=seed)
        self.altitude = float(self.start_alt)
        self.speed = float(160.0 + np.random.uniform(-10, 10))
        self.distance = float(self.start_dist)
        self.prev_distance = self.distance
        self.angle = float(np.random.uniform(-2, 2))
        self.runway_condition = float(np.random.choice([0.0, 0.5, 1.0]))
        self.steps = 0
        return self._get_obs(), {}

    def step(self, action):
        """Advance the simulation by one control step.

        Returns ``(obs, reward, terminated, truncated, info)``. Timeouts are
        reported through ``terminated``; ``truncated`` is always False.
        ``info`` carries ``success``, ``outcome`` and ``runway_condition``.
        """
        self.steps += 1

        # --- Action effects (action 4 matches no branch: no control input) ---
        if action == 0:  # throttle up
            self.speed += 6.0
        elif action == 1:  # throttle down
            self.speed -= 6.0
        elif action == 2:  # pitch up
            self.altitude += 35.0
            self.angle += 1.5
        elif action == 3:  # pitch down
            self.altitude -= 35.0
            self.angle -= 1.5

        # --- Dynamics: close at least 1 unit of distance, constant sink of 8 ---
        self.distance -= max(self.speed * 0.3, 1.0)
        self.altitude -= 8.0

        # --- Runway-condition drag (applied every step, airborne or not) ---
        if self.runway_condition == 0.0:
            drag_factor = 0.6
        elif self.runway_condition == 0.5:
            drag_factor = 0.4
        else:
            drag_factor = 0.25

        self.speed -= drag_factor
        self.angle = np.clip(self.angle, -30, 30)
        self.altitude = max(self.altitude, 0.0)
        self.speed = np.clip(self.speed, 0.0, 300.0)

        # --- Reward shaping ---
        reward = 0.0
        reward += (self.prev_distance - self.distance) * 0.02  # progress this step
        self.prev_distance = self.distance

        reward -= 0.03                              # per-step cost
        reward -= 0.005 * abs(self.altitude - 100)  # penalties for deviating from
        reward -= 0.005 * abs(self.speed - 150)     # reference altitude/speed/angle
        reward -= 0.01 * abs(self.angle)

        if self.distance < 400:
            reward += 0.8
        if 0 < self.altitude < 100 and 100 < self.speed < 200:
            reward += 1.5

        # Fraction of the initial distance covered so far (paid every step).
        reward += (self.start_dist - self.distance) / max(1.0, self.start_dist)

        done = False
        success = False
        outcome = "in-flight"

        # Landing logic
        if self.distance <= 0:
            if 0 <= self.altitude <= 50 and 100 <= self.speed <= 200 and abs(self.angle) < 10:
                reward += 200.0
                success = True
                outcome = "successful landing"
            else:
                reward -= 40.0
                outcome = "failed landing"
            done = True

        # Crash / stall
        if self.altitude <= 0 and self.distance > 0:
            reward -= 40.0
            done = True
            outcome = "crash before runway"
        # NOTE(review): a failed landing with speed <= 20 and altitude > 100 also
        # matches this check, taking a second -40 and the "stall midair" label.
        if self.speed <= 20 and self.altitude > 100:
            reward -= 40.0
            done = True
            outcome = "stall midair"
        # Timeout only applies when nothing else ended the episode on this step,
        # so a landing/crash on the final step keeps its real outcome label.
        if self.steps >= self.MAX_STEPS and not done:
            done = True
            outcome = "timeout"

        info = {
            "success": success,
            "outcome": outcome,
            "runway_condition": self.runway_condition
        }
        return self._get_obs(), float(reward), bool(done), False, info

    def _get_obs(self):
        """Return the scaled 5-element float32 observation described on the class."""
        return np.array([
            self.altitude / 5000.0,
            self.speed / 300.0,
            self.distance / 10000.0,
            (self.angle + 30.0) / 60.0,
            self.runway_condition
        ], dtype=np.float32)


# =====================================================
# 2️⃣ REPLAY BUFFER (unchanged)
# =====================================================
class ReplayBuffer:
    """Fixed-capacity circular store of transitions, sampled uniformly with replacement.

    All storage is preallocated as NumPy arrays. Once ``capacity`` transitions
    have been pushed, each new push overwrites the oldest slot.
    """

    def __init__(self, capacity: int, obs_shape):
        self.capacity = int(capacity)
        self.obs_shape = tuple(obs_shape)
        self.ptr = 0   # slot the next push writes into
        self.size = 0  # number of valid transitions currently held
        obs_block = (self.capacity, *self.obs_shape)
        self.states = np.zeros(obs_block, dtype=np.float32)
        self.next_states = np.zeros(obs_block, dtype=np.float32)
        self.actions = np.zeros(self.capacity, dtype=np.int64)
        self.rewards = np.zeros(self.capacity, dtype=np.float32)
        self.dones = np.zeros(self.capacity, dtype=np.float32)

    def push(self, state, action, reward, next_state, done):
        """Write one transition at the cursor and advance it (wrapping around)."""
        slot = self.ptr
        self.states[slot] = state
        self.next_states[slot] = next_state
        self.actions[slot] = int(action)
        self.rewards[slot] = float(reward)
        self.dones[slot] = float(bool(done))
        self.ptr = (slot + 1) % self.capacity
        if self.size < self.capacity:
            self.size += 1

    def sample(self, batch_size: int):
        """Draw ``batch_size`` stored transitions uniformly at random (with replacement).

        Returns a dict of NumPy arrays keyed ``states``, ``actions``,
        ``rewards``, ``next_states`` and ``dones``.
        """
        picks = np.random.randint(0, self.size, size=batch_size)
        return {
            "states": self.states[picks],
            "actions": self.actions[picks],
            "rewards": self.rewards[picks],
            "next_states": self.next_states[picks],
            "dones": self.dones[picks],
        }

    def __len__(self):
        return self.size


# =====================================================
# 3️⃣ SAC AGENT (Discrete)
# =====================================================
class MLP(nn.Module):
    """Feed-forward network: a Linear+ReLU pair per hidden width, then a linear output layer.

    Args:
        inp: Input feature size.
        out: Output size (no activation is applied to it).
        hidden: Widths of the hidden layers, in order.
    """
    def __init__(self, inp, out, hidden=(256, 256)):
        super().__init__()
        widths = [inp, *hidden]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        modules.append(nn.Linear(widths[-1], out))
        # Attribute name and module order determine the state_dict keys.
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        out = self.net(x)
        return out


class Actor(nn.Module):
    """Policy network for discrete SAC: maps observations to unnormalised action logits.

    Turning the logits into probabilities (e.g. softmax) is left to the caller.
    """
    def __init__(self, obs_dim, n_actions):
        super().__init__()
        # The attribute name ``net`` is part of the saved state_dict keys.
        self.net = MLP(obs_dim, n_actions)

    def forward(self, obs):
        logits = self.net(obs)
        return logits


class Critic(nn.Module):
    """Q-network: maps observations to one Q-value estimate per discrete action."""
    def __init__(self, obs_dim, n_actions):
        super().__init__()
        # The attribute name ``net`` is part of the saved state_dict keys.
        self.net = MLP(obs_dim, n_actions)

    def forward(self, obs):
        q_values = self.net(obs)
        return q_values


class SACDiscrete:
    """Soft Actor-Critic agent for a discrete action space.

    Holds a categorical actor, twin Q critics with Polyak-averaged target
    copies, and a learned entropy temperature ``alpha`` (stored as
    ``log_alpha`` so it stays positive).

    Args:
        obs_dim: Length of the flat observation vector.
        n_actions: Number of discrete actions.
        lr: Adam learning rate for actor, critics and temperature.
        gamma: Discount factor.
        tau: Polyak coefficient for the target critics.
        alpha: Initial entropy temperature.
        device: Torch device (e.g. ``'cpu'`` or ``'cuda'``).
    """

    def __init__(self, obs_dim, n_actions, lr=3e-4, gamma=0.99,
                 tau=0.005, alpha=0.1, device='cpu'):

        self.device = torch.device(device)
        self.gamma = gamma
        self.tau = tau
        self.n_actions = n_actions

        # Actor
        self.actor = Actor(obs_dim, n_actions).to(self.device)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=lr)

        # Twin critics (targets and actor use their element-wise minimum)
        self.q1 = Critic(obs_dim, n_actions).to(self.device)
        self.q2 = Critic(obs_dim, n_actions).to(self.device)
        self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)

        # Target critics start as exact copies of the online critics
        self.q1_t = Critic(obs_dim, n_actions).to(self.device)
        self.q2_t = Critic(obs_dim, n_actions).to(self.device)
        self.q1_t.load_state_dict(self.q1.state_dict())
        self.q2_t.load_state_dict(self.q2.state_dict())

        # Entropy temperature, learned in log space. Explicit float32: np.log
        # returns float64, which would otherwise make log_alpha a double tensor.
        self.log_alpha = torch.tensor(float(np.log(alpha)), dtype=torch.float32,
                                      device=self.device, requires_grad=True)
        self.alpha_opt = optim.Adam([self.log_alpha], lr=lr)
        # Target entropy: 98% of log(n_actions), the entropy of a uniform policy.
        self.target_entropy = -np.log(1.0 / n_actions) * 0.98

        self.mse = nn.MSELoss()

    @property
    def alpha(self):
        """Current entropy temperature, ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    @torch.no_grad()
    def act(self, obs, deterministic=False):
        """Choose an action for a single (unbatched) observation.

        Samples from the categorical policy; with ``deterministic=True`` the
        most probable action is returned instead (useful for evaluation).
        Runs without autograd since no gradient is needed for acting.
        """
        t = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
        logits = self.actor(t)
        if deterministic:
            return int(logits.argmax(dim=-1).item())
        probs = torch.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        return int(dist.sample().item())

    @staticmethod
    def _temperature_loss(log_alpha, entropy, target_entropy):
        """Temperature objective whose gradient w.r.t. ``log_alpha`` is ``entropy - target``.

        Gradient descent therefore lowers alpha while the policy is more random
        than the target and raises it while the policy is too deterministic.
        """
        return log_alpha * (entropy - target_entropy).detach()

    @staticmethod
    def _soft_update(online, target, tau):
        """Polyak-average in place: ``target <- (1 - tau) * target + tau * online``."""
        with torch.no_grad():
            for p, tp in zip(online.parameters(), target.parameters()):
                tp.mul_(1.0 - tau).add_(p, alpha=tau)

    def update(self, batch):
        """Perform one SAC gradient step on a sampled batch.

        Order: both critics, then the actor, then the temperature, then a
        Polyak update of the target critics.

        Args:
            batch: Mapping with arrays under ``states``, ``next_states``,
                ``actions`` (integer indices), ``rewards`` and ``dones``
                (1.0 marks a terminal transition), e.g. from ``ReplayBuffer.sample``.
        """
        def to_tensor(key, dtype=torch.float32):
            return torch.as_tensor(batch[key], dtype=dtype, device=self.device)

        s = to_tensor("states")
        ns = to_tensor("next_states")
        a = to_tensor("actions", torch.int64)
        r = to_tensor("rewards")
        d = to_tensor("dones")

        # ---------- Soft Bellman target ----------
        with torch.no_grad():
            next_probs = torch.softmax(self.actor(ns), dim=-1)
            next_logp = torch.log(next_probs + 1e-12)  # epsilon guards log(0)
            q_next = torch.min(self.q1_t(ns), self.q2_t(ns))
            # Exact expectation over the discrete next action (no sampling needed).
            v_next = (next_probs * (q_next - self.alpha * next_logp)).sum(dim=1)
            target_q = r + (1 - d) * self.gamma * v_next

        # ---------- Critic updates ----------
        idx = a.unsqueeze(1)
        loss_q1 = self.mse(self.q1(s).gather(1, idx).squeeze(1), target_q)
        loss_q2 = self.mse(self.q2(s).gather(1, idx).squeeze(1), target_q)
        for loss, net, opt in ((loss_q1, self.q1, self.q1_opt),
                               (loss_q2, self.q2, self.q2_opt)):
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 10.0)
            opt.step()

        # ---------- Actor update ----------
        probs = torch.softmax(self.actor(s), dim=-1)
        logp = torch.log(probs + 1e-12)
        with torch.no_grad():
            # Critics are fixed for the policy step; keeping them out of the
            # graph avoids back-propagating into q1/q2 here.
            q_min = torch.min(self.q1(s), self.q2(s))
        alpha = self.alpha.detach()
        actor_loss = (probs * (alpha * logp - q_min)).sum(dim=1).mean()

        self.actor_opt.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), 10.0)
        self.actor_opt.step()

        # ---------- Temperature update ----------
        with torch.no_grad():
            # Mean entropy of the (pre-step) policy on this batch.
            entropy = -(probs * logp).sum(dim=1).mean()
        alpha_loss = self._temperature_loss(self.log_alpha, entropy, self.target_entropy)

        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        self.alpha_opt.step()

        # ---------- Soft target update ----------
        self._soft_update(self.q1, self.q1_t, self.tau)
        self._soft_update(self.q2, self.q2_t, self.tau)


# =====================================================
# 4️⃣ TRAINING (same structure as DQN version)
# =====================================================
def parse_args(argv=None):
    """Parse the training command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``; pass ``[]`` to get pure defaults.

    Returns:
        argparse.Namespace with ``episodes``, ``lr``, ``buffer_size``,
        ``batch_size``, ``save_dir``, ``cpu`` and ``seed``.
    """
    p = argparse.ArgumentParser(description="Train a discrete SAC agent on FlightEnv.")
    p.add_argument('--episodes', type=int, default=30000, help='number of training episodes')
    p.add_argument('--lr', type=float, default=3e-4, help='Adam learning rate')
    p.add_argument('--buffer_size', type=int, default=30000, help='replay buffer capacity')
    p.add_argument('--batch_size', type=int, default=128, help='minibatch size per update')
    p.add_argument('--save_dir', type=str, default='./checkpoints', help='checkpoint directory')
    p.add_argument('--cpu', action='store_true', help='force CPU even if CUDA is available')
    p.add_argument('--seed', type=int, default=42, help='random seed')
    return p.parse_args(argv)


def train(args):
    """Train a discrete SAC agent on ``FlightEnv`` and report per-runway success rates.

    Each episode rolls out the stochastic policy, stores every transition in
    the replay buffer, and performs one ``agent.update`` per environment step
    once the buffer holds more than ``args.batch_size`` transitions. Actor and
    critic weights are saved to ``args.save_dir`` whenever an episode beats the
    best return so far; an outcome summary is printed every 5000 episodes.

    Args:
        args: Namespace with ``episodes``, ``lr``, ``buffer_size``,
            ``batch_size``, ``save_dir``, ``cpu`` and ``seed`` (see ``parse_args``).
    """
    device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)  # FlightEnv.reset draws from the global NumPy RNG
    random.seed(args.seed)

    env = FlightEnv()
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n

    agent = SACDiscrete(obs_dim, n_actions, lr=args.lr, device=device)
    buffer = ReplayBuffer(args.buffer_size, (obs_dim,))

    os.makedirs(args.save_dir, exist_ok=True)

    outcomes = {"successful landing": 0, "failed landing": 0,
                "crash before runway": 0, "stall midair": 0, "timeout": 0}
    runway_counts = {0.0: 0, 0.5: 0, 1.0: 0}
    runway_success = {0.0: 0, 0.5: 0, 1.0: 0}

    def success_probs():
        """Empirical success rate per runway condition (0 for conditions not yet seen)."""
        return {
            k: (runway_success[k] / runway_counts[k] if runway_counts[k] > 0 else 0)
            for k in runway_counts
        }

    def save_best():
        """Write the current actor and critic weights as the 'best' checkpoint."""
        for filename, net in (("best_actor.pt", agent.actor),
                              ("best_q1.pt", agent.q1),
                              ("best_q2.pt", agent.q2)):
            torch.save(net.state_dict(), os.path.join(args.save_dir, filename))

    # Defined before the loop so the final report also works when args.episodes == 0.
    runway_probs = success_probs()
    best_return = -1e9

    for ep in range(1, args.episodes + 1):
        obs, _ = env.reset()
        ep_return = 0.0
        done = False
        outcome = None
        runway = env.runway_condition

        while not done:
            action = agent.act(obs)
            nobs, rew, done, _, info = env.step(action)

            buffer.push(obs, action, rew, nobs, done)
            obs = nobs
            ep_return += rew
            outcome = info["outcome"]

            # One gradient step per environment step once enough data exists.
            if len(buffer) > args.batch_size:
                agent.update(buffer.sample(args.batch_size))

        # Aggregate statistics
        outcomes[outcome] += 1
        runway_counts[runway] += 1
        if info["success"]:
            runway_success[runway] += 1
        runway_probs = success_probs()

        print(f"Ep {ep}/{args.episodes} | Runway: {runway} | Return: {ep_return:.2f} "
              f"| Outcome: {outcome} | Success Prob: {runway_probs[runway]:.2f}")

        if ep_return > best_return:
            best_return = ep_return
            save_best()

        if ep % 5000 == 0:
            print("\n--- Outcome Summary up to Episode", ep, "---")
            for k, v in outcomes.items():
                print(f"{k:<25}: {v}")
            print("Runway success probabilities:", runway_probs)
            print("------------------------------------------\n")

    print("✅ Training done.")
    print("Best episode return:", best_return)
    print("Final outcomes summary:", outcomes)
    print("Final runway success probabilities:", runway_probs)


# =====================================================
# 5️⃣ RUN
# =====================================================
if __name__ == "__main__":
    import sys
    sys.argv = [sys.argv[0]]
    args = parse_args()
    print("Starting SAC-Discrete training with args:", args)
    train(args)
