In [3]:
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from itertools import count

import gymnasium as gym

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli

# =========================
# Tunable hyperparameters
# =========================
GAMMA = 0.99

# Memory / model capacity
ACTOR_HIDDEN = 256
CRITIC_HIDDEN = 256
LSTM_LAYERS = 1
LSTM_DROPOUT = 0.1

# Optimization stability
LR_ACTOR = 2e-4
LR_CRITIC = 2e-4
CLIP_NORM = 0.5

# Exploration / collapse prevention
ENT_COEF = 0.01
ADV_NORM = True
USE_HUBER_VALUE_LOSS = True

# =========================
# Catch & Climb (NEW)
# =========================
# baseline is an EMA of episode_reward; collapse/spike is judged vs previous baseline
BASELINE_BETA = 0.90          # EMA smoothing; bigger -> slower baseline
EPS = 1e-6

# collapse/spike intensity normalization
NORM_BY = "abs_baseline"      # or "maxlen" etc. we'll use abs baseline for stability

# Anchor (anti-collapse): keep best policy snapshot; on collapse, penalize KL(current || anchor)
AUX_KL_COEF = 0.5             # base KL weight
AUX_KL_MAX = 5.0              # cap when collapse is huge

# Spike (self-imitation): on spike, do extra log-prob maximize on that trajectory
SIL_COEF = 0.5                # base imitation weight
SIL_MAX = 5.0                 # cap

# Extra gradient steps when an event happens (collapse or spike)
EXTRA_UPDATES_ON_EVENT = 2     # do additional updates using same trajectory
EXTRA_UPDATES_CAP = 6          # safety cap if you later scale dynamically

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device =", device)


class PolicyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)
        self.lstm = nn.LSTM(
            64, ACTOR_HIDDEN,
            num_layers=LSTM_LAYERS,
            dropout=(LSTM_DROPOUT if LSTM_LAYERS >= 2 else 0.0),
            batch_first=True
        )
        self.fc2 = nn.Linear(ACTOR_HIDDEN, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, hidden):
        x = self.relu(self.fc1(x))
        x, hidden = self.lstm(x, hidden)
        x = self.relu(x)
        x = self.sigmoid(self.fc2(x))  # (B, T, 1)
        return x, hidden

    def select_action(self, state, hidden):
        with torch.no_grad():
            prob, hidden = self.forward(state, hidden)
            b = Bernoulli(prob)
            action = b.sample()
        return int(action.item()), hidden


class ValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)
        self.lstm = nn.LSTM(
            64, CRITIC_HIDDEN,
            num_layers=LSTM_LAYERS,
            dropout=(LSTM_DROPOUT if LSTM_LAYERS >= 2 else 0.0),
            batch_first=True
        )
        self.fc2 = nn.Linear(CRITIC_HIDDEN, 1)
        self.relu = nn.ReLU()

    def forward(self, x, hidden):
        x = self.relu(self.fc1(x))
        x, hidden = self.lstm(x, hidden)
        x = self.relu(x)
        x = self.fc2(x)
        return x, hidden


def obs_to_partial(obs):
    return np.array([obs[0], obs[2]], dtype=np.float32)


def bernoulli_kl(p, q, eps=1e-6):
    """
    KL(Ber(p) || Ber(q)) = p log(p/q) + (1-p) log((1-p)/(1-q))
    p,q: tensors in (0,1)
    """
    p = torch.clamp(p, eps, 1.0 - eps)
    q = torch.clamp(q, eps, 1.0 - eps)
    return p * torch.log(p / q) + (1.0 - p) * torch.log((1.0 - p) / (1.0 - q))


@torch.no_grad()
def forward_policy_probs(policy_net, states_tensor):
    a_hx = torch.zeros((LSTM_LAYERS, 1, ACTOR_HIDDEN), device=states_tensor.device)
    a_cx = torch.zeros((LSTM_LAYERS, 1, ACTOR_HIDDEN), device=states_tensor.device)
    prob, _ = policy_net(states_tensor, (a_hx, a_cx))  # (1,T,1)
    return prob.squeeze(0)  # (T,1)


def compute_climb_signals(ep_reward, baseline_prev):
    """
    collapse: ep_reward < baseline_prev
    spike   : ep_reward > baseline_prev

    intensity is normalized so scale is stable.
    """
    denom = abs(baseline_prev) + 1.0  # +1 to avoid tiny denom when baseline near 0
    collapse_raw = max(0.0, baseline_prev - ep_reward)
    spike_raw = max(0.0, ep_reward - baseline_prev)

    collapse_int = collapse_raw / (denom + EPS)
    spike_int = spike_raw / (denom + EPS)

    is_collapse = 1.0 if ep_reward < baseline_prev else 0.0
    is_spike = 1.0 if ep_reward > baseline_prev else 0.0
    return is_collapse, is_spike, collapse_int, spike_int


if __name__ == "__main__":
    env = gym.make("CartPole-v1")

    policy = PolicyNetwork().to(device)
    value = ValueNetwork().to(device)

    optim = torch.optim.Adam(policy.parameters(), lr=LR_ACTOR)
    value_optim = torch.optim.Adam(value.parameters(), lr=LR_CRITIC)

    writer = SummaryWriter("./lstm_logs_catch_climb")

    # Anchor snapshot (best policy so far)
    anchor_policy = PolicyNetwork().to(device)
    anchor_policy.load_state_dict(policy.state_dict())
    anchor_policy.eval()

    best_reward = -1e9
    baseline = 0.0  # EMA baseline

    for epoch in count():
        obs, info = env.reset(seed=None)
        state = obs_to_partial(obs)
        episode_reward = 0.0

        # actor hidden init
        a_hx = torch.zeros((LSTM_LAYERS, 1, ACTOR_HIDDEN), device=device)
        a_cx = torch.zeros((LSTM_LAYERS, 1, ACTOR_HIDDEN), device=device)

        rewards, actions, states = [], [], []

        for t in range(500):
            states.append(state.copy())

            state_t = torch.tensor(state, dtype=torch.float32, device=device).view(1, 1, 2)
            action, (a_hx, a_cx) = policy.select_action(state_t, (a_hx, a_cx))
            actions.append(action)

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            next_state = obs_to_partial(next_obs)
            episode_reward += float(reward)

            rewards.append(float(reward))
            state = next_state

            if done:
                break

        # -------------------------
        # returns
        # -------------------------
        returns = np.zeros(len(rewards), dtype=np.float32)
        R = 0.0
        for i in reversed(range(len(rewards))):
            R = GAMMA * R + rewards[i]
            returns[i] = R

        # normalize returns (stability)
        mean, std = returns.mean(), returns.std()
        std = std if std > 1e-8 else 1.0
        returns_norm = (returns - mean) / std

        # tensors
        states_tensor = torch.tensor(np.array(states), dtype=torch.float32, device=device).unsqueeze(0)  # (1,T,2)
        actions_tensor = torch.tensor(np.array(actions), dtype=torch.float32, device=device).view(-1, 1)  # (T,1)
        returns_tensor = torch.tensor(returns_norm, dtype=torch.float32, device=device).view(-1, 1)       # (T,1)

        # -------------------------
        # Catch & Climb signals
        # -------------------------
        baseline_prev = float(baseline)
        is_collapse, is_spike, collapse_int, spike_int = compute_climb_signals(
            ep_reward=episode_reward,
            baseline_prev=baseline_prev
        )

        # update baseline AFTER computing signal
        baseline = BASELINE_BETA * baseline + (1.0 - BASELINE_BETA) * float(episode_reward)

        # compute dynamic aux weights
        kl_w = min(AUX_KL_MAX, AUX_KL_COEF * (1.0 + 5.0 * collapse_int)) if is_collapse else 0.0
        sil_w = min(SIL_MAX,    SIL_COEF    * (1.0 + 5.0 * spike_int))   if is_spike else 0.0

        # decide extra updates
        extra_updates = 0
        if is_collapse or is_spike:
            extra_updates = int(min(EXTRA_UPDATES_CAP, EXTRA_UPDATES_ON_EVENT))

        # -------------------------
        # critic baseline & advantage
        # -------------------------
        with torch.no_grad():
            c_hx = torch.zeros((LSTM_LAYERS, 1, CRITIC_HIDDEN), device=device)
            c_cx = torch.zeros((LSTM_LAYERS, 1, CRITIC_HIDDEN), device=device)
            v, _ = value(states_tensor, (c_hx, c_cx))
            v = v.squeeze(0)  # (T,1)

            advantage = returns_tensor - v
            if ADV_NORM:
                advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        # -------------------------
        # prepare anchor probs once
        # -------------------------
        with torch.no_grad():
            anchor_prob = forward_policy_probs(anchor_policy, states_tensor)  # (T,1)

        # -------------------------
        # update actor (and extra updates if event)
        # -------------------------
        def actor_step():
            a_hx0 = torch.zeros((LSTM_LAYERS, 1, ACTOR_HIDDEN), device=device)
            a_cx0 = torch.zeros((LSTM_LAYERS, 1, ACTOR_HIDDEN), device=device)
            prob, _ = policy(states_tensor, (a_hx0, a_cx0))
            prob = prob.squeeze(0)  # (T,1)

            dist = Bernoulli(prob)
            log_prob = dist.log_prob(actions_tensor)   # (T,1)
            entropy = dist.entropy().mean()

            # base REINFORCE
            base_loss = -(log_prob * advantage.detach()).mean() - ENT_COEF * entropy

            # (1) anti-collapse: anchor KL clamp
            # enforce: don't drift away from best-known policy when you start collapsing
            kl_loss = 0.0
            if kl_w > 0.0:
                kl = bernoulli_kl(prob, anchor_prob).mean()
                kl_loss = float(kl_w) * kl

            # (2) spike = self-imitation: aggressively reinforce the "lucky" trajectory
            # simplest: extra negative log-likelihood of taken actions (BC) weighted by positive advantage
            sil_loss = 0.0
            if sil_w > 0.0:
                pos_adv = torch.clamp(advantage.detach(), min=0.0)
                # if all adv <= 0, it becomes 0 (safe)
                sil_loss = float(sil_w) * (-(log_prob * pos_adv).mean())

            loss = base_loss + kl_loss + sil_loss
            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), CLIP_NORM)
            optim.step()
            return loss, base_loss, kl_loss, sil_loss, entropy

        actor_loss, base_loss, kl_loss, sil_loss, entropy = actor_step()
        for _ in range(extra_updates):
            actor_loss, base_loss, kl_loss, sil_loss, entropy = actor_step()

        # -------------------------
        # critic update
        # -------------------------
        c_hx = torch.zeros((LSTM_LAYERS, 1, CRITIC_HIDDEN), device=device)
        c_cx = torch.zeros((LSTM_LAYERS, 1, CRITIC_HIDDEN), device=device)
        v_pred, _ = value(states_tensor, (c_hx, c_cx))
        v_pred = v_pred.squeeze(0)

        if USE_HUBER_VALUE_LOSS:
            value_loss = F.smooth_l1_loss(v_pred, returns_tensor)
        else:
            value_loss = F.mse_loss(v_pred, returns_tensor)

        value_optim.zero_grad()
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(value.parameters(), CLIP_NORM)
        value_optim.step()

        # -------------------------
        # Anchor update (best policy snapshot)
        # -------------------------
        if episode_reward > best_reward:
            best_reward = float(episode_reward)
            anchor_policy.load_state_dict(policy.state_dict())
            anchor_policy.eval()

        # -------------------------
        # Logging
        # -------------------------
        writer.add_scalar("episode_reward", episode_reward, epoch)
        writer.add_scalar("baseline/ema", baseline, epoch)
        writer.add_scalar("baseline/prev", baseline_prev, epoch)
        writer.add_scalar("catch/is_collapse", is_collapse, epoch)
        writer.add_scalar("catch/is_spike", is_spike, epoch)
        writer.add_scalar("catch/collapse_int", collapse_int, epoch)
        writer.add_scalar("catch/spike_int", spike_int, epoch)
        writer.add_scalar("catch/kl_weight", kl_w, epoch)
        writer.add_scalar("catch/sil_weight", sil_w, epoch)
        writer.add_scalar("catch/extra_updates", extra_updates, epoch)

        writer.add_scalar("loss/actor_total", float(actor_loss.item()), epoch)
        writer.add_scalar("loss/actor_base", float(base_loss.item()), epoch)
        writer.add_scalar("loss/actor_kl", float(kl_loss) if isinstance(kl_loss, float) else float(kl_loss.item()), epoch)
        writer.add_scalar("loss/actor_sil", float(sil_loss) if isinstance(sil_loss, float) else float(sil_loss.item()), epoch)
        writer.add_scalar("loss/value", float(value_loss.item()), epoch)
        writer.add_scalar("stats/entropy", float(entropy.item()), epoch)
        writer.add_scalar("stats/adv_mean", float(advantage.mean().item()), epoch)
        writer.add_scalar("stats/adv_std", float(advantage.std().item()), epoch)
        writer.add_scalar("best_reward", best_reward, epoch)

        if epoch % 10 == 0:
            tag = "COLLAPSE" if is_collapse else ("SPIKE" if is_spike else "normal")
            print(
                f"Epoch {epoch:05d} | ep_reward {episode_reward:.1f} | "
                f"baseline_prev {baseline_prev:.2f} -> ema {baseline:.2f} | "
                f"{tag} | extra_upd={extra_updates}"
            )
            torch.save(policy.state_dict(), "lstm-policy.pt")


device = cpu
Epoch 00000 | ep_reward 42.0 | baseline_prev 0.00 -> ema 4.20 | SPIKE | extra_upd=2
Epoch 00010 | ep_reward 15.0 | baseline_prev 14.76 -> ema 14.78 | SPIKE | extra_upd=2
Epoch 00020 | ep_reward 24.0 | baseline_prev 22.66 -> ema 22.79 | SPIKE | extra_upd=2
Epoch 00030 | ep_reward 17.0 | baseline_prev 19.17 -> ema 18.95 | COLLAPSE | extra_upd=2
Epoch 00040 | ep_reward 19.0 | baseline_prev 21.55 -> ema 21.30 | COLLAPSE | extra_upd=2
Epoch 00050 | ep_reward 11.0 | baseline_prev 25.65 -> ema 24.19 | COLLAPSE | extra_upd=2
Epoch 00060 | ep_reward 22.0 | baseline_prev 21.73 -> ema 21.75 | SPIKE | extra_upd=2
Epoch 00070 | ep_reward 20.0 | baseline_prev 21.52 -> ema 21.37 | COLLAPSE | extra_upd=2
Epoch 00080 | ep_reward 24.0 | baseline_prev 21.72 -> ema 21.95 | SPIKE | extra_upd=2
Epoch 00090 | ep_reward 10.0 | baseline_prev 25.10 -> ema 23.59 | COLLAPSE | extra_upd=2
Epoch 00100 | ep_reward 10.0 | baseline_prev 19.97 -> ema 18.97 | COLLAPSE | extra_upd=2
Epoch 00110 | ep_reward 3

KeyboardInterrupt: 