In [None]:
import numpy as np
import torch
import gymnasium as gym
import time
import matplotlib.pyplot as plt

In [None]:
!pip install swig
!pip install gymnasium[classic_control]
!pip install gymnasium[mujoco]

In [None]:
import sys, torch
# Print one line per fact about the interpreter and the PyTorch/CUDA stack.
# Each row is (label, value, ...) and is forwarded unchanged to print().
for label, *values in (
    ("Python:", sys.version),
    ("Exe:", sys.executable),
    ("Torch:", torch.__version__),
    ("Torch CUDA runtime:", torch.version.cuda),
    ("cuDNN available:", torch.backends.cudnn.is_available()),
    ("CUDA visible devices:", torch.cuda.device_count(), torch.cuda.is_available()),
):
    print(label, *values)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

In [None]:
#AI assisted, code taken from source: https://www.geeksforgeeks.org/deep-learning/reinforcement-learning-using-pytorch/
#How do I get a default NN model from pytorch that I can use as a value function estimator for my RL algorithm code
#extension to Continuous action space done with AI assistance
import torch
import torch.nn as nn
from torch.distributions import Categorical, Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform, AffineTransform
class actor_critic_cts(nn.Module):
  """Actor-critic network with separate policy and value MLPs.

  The policy is either categorical (``continuous=False``) or a diagonal
  Gaussian squashed by tanh and rescaled to ``[action_low, action_high]``
  (``continuous=True``). The value network maps a state to a scalar.

  Args:
    state_dim: size of the observation vector.
    action_dim: number of discrete actions, or dimensionality of the
      continuous action vector.
    hidden: width of both hidden layers in each MLP.
    continuous: selects the squashed-Gaussian policy when True.
    action_low, action_high: per-dimension action bounds; only read when
      ``continuous`` is True.
    init_logstd: initial value of the state-independent log standard deviation.
  """

  def __init__(self, state_dim, action_dim, hidden = 128, continuous = False, action_low = None, action_high = None, init_logstd = 0.0):
    super().__init__()
    self.continuous = continuous

    # Policy trunk: two hidden ReLU layers; the output head is attached below.
    trunk = [
      nn.Linear(state_dim, hidden), nn.ReLU(),
      nn.Linear(hidden, hidden), nn.ReLU(),
    ]
    self.pi_net = nn.Sequential(*trunk)

    if self.continuous:
      # Gaussian mean head plus a learned, state-independent log-std.
      self.mu_head = nn.Linear(hidden, action_dim)
      self.log_std = nn.Parameter(torch.full((action_dim,), init_logstd))
      lo = torch.as_tensor(action_low, dtype=torch.float32)
      hi = torch.as_tensor(action_high, dtype=torch.float32)
      # Midpoint and half-width of the action box, used to rescale tanh output.
      self.register_buffer("action_loc", (hi + lo) / 2.0)
      self.register_buffer("action_scale", (hi - lo) / 2.0)
    else:
      # One logit per discrete action.
      self.pi_head = nn.Linear(hidden, action_dim)

    # Independent value network (shares no layers with the policy).
    critic = [
      nn.Linear(state_dim, hidden), nn.ReLU(),
      nn.Linear(hidden, hidden), nn.ReLU(),
      nn.Linear(hidden, 1),
    ]
    self.v_net = nn.Sequential(*critic)

  def policy(self, state):
    """Return the action distribution for ``state``."""
    features = self.pi_net(state)
    if self.continuous:
      mean = self.mu_head(features)
      sigma = torch.exp(self.log_std).expand_as(mean)
      # Independent(..., 1) makes log_prob sum over the action dimension.
      gaussian = Independent(Normal(mean, sigma), 1)
      squash = [
        TanhTransform(cache_size=1),
        AffineTransform(loc=self.action_loc, scale=self.action_scale),
      ]
      return TransformedDistribution(gaussian, squash)
    return Categorical(logits=self.pi_head(features))

  def value(self, state):
    """Return the state-value estimate with the trailing size-1 axis removed."""
    estimate = self.v_net(state)
    return estimate.squeeze(-1)

  def forward(self, state):
    """Return ``(action distribution, value estimate)`` for ``state``."""
    return self.policy(state), self.value(state)

In [None]:
#AI assisted, code taken from source: https://www.geeksforgeeks.org/deep-learning/reinforcement-learning-using-pytorch/
#How do I get a default NN model from pytorch that I can use as a value function estimator for my RL algorithm code
import torch
import torch.nn as nn
from torch.distributions import Categorical
class actor_critic_dsc(nn.Module):
  """Actor-critic network for discrete action spaces.

  Two independent MLPs with two hidden ReLU layers each: ``pi_net`` outputs
  one logit per action and ``v_net`` outputs a scalar state value.

  Args:
    state_dim: size of the observation vector.
    action_dim: number of discrete actions.
    hidden: width of the hidden layers.
  """

  @staticmethod
  def _mlp(in_dim, hidden, out_dim):
    """Build Linear-ReLU-Linear-ReLU-Linear with the given sizes."""
    return nn.Sequential(
      nn.Linear(in_dim, hidden),
      nn.ReLU(),
      nn.Linear(hidden, hidden),
      nn.ReLU(),
      nn.Linear(hidden, out_dim)
    )

  def __init__(self, state_dim, action_dim, hidden = 128):
    super().__init__()
    # Policy first, critic second: keeps the parameter initialisation order.
    self.pi_net = self._mlp(state_dim, hidden, action_dim)
    self.v_net = self._mlp(state_dim, hidden, 1)

  def policy(self, state):
    """Return the categorical action distribution for ``state``."""
    return Categorical(logits=self.pi_net(state))

  def value(self, state):
    """Return the state-value estimate with the trailing size-1 axis removed."""
    estimate = self.v_net(state)
    return estimate.squeeze(-1)

  def forward(self, state):
    """Return ``(action distribution, value estimate)`` for ``state``."""
    return self.policy(state), self.value(state)

In [None]:
def reset_state(env):
  """Reset ``env`` and return only the observation, discarding the info item."""
  observation, _info = env.reset()
  return observation

In [None]:
def GAE(r_hist, v_hist, next_val, traj_ends, y = 0.99, l = 0.95, terms = 150):
  """Generalized Advantage Estimation for one trajectory segment.

  Computes, for every step t,

      A_t = sum_{j=0}^{K-1} (y*l)^j * delta_{t+j},
      delta_t = r_t + y * (1 - end_t) * V(s_{t+1}) - V(s_t),

  where K = min(T - t, terms) and the sum stops after the first step whose
  ``traj_ends`` flag is 1.

  The previous implementation evaluated ``(1-l) * sum_i l^(i-1) * A^(i)`` over
  the same finite window. Those weights sum to ``1 - l^K`` rather than 1, so
  delta_{t+j} was weighted by ``y^j * (l^j - l^K)``; the final step of every
  trajectory was scaled by ``(1 - l)``. The form above is the standard
  finite-horizon estimator and does not have that deflation.

  Args:
    r_hist: per-step rewards, length T.
    v_hist: per-step value estimates V(s_t), length T.
    next_val: value estimate of the state following the last step; ignored
      when the last ``traj_ends`` flag is 1.
    traj_ends: per-step flags, 1.0 where the step ended in a terminal state
      (no bootstrapping past it), else 0.0.
    y: discount factor.
    l: GAE lambda.
    terms: maximum number of TD residuals summed per step.

  Returns:
    float32 CPU tensor of shape (T,) with the advantage estimates
    (an empty tensor when T == 0).
  """
  r_hist = torch.as_tensor(r_hist, dtype=torch.float32, device="cpu")
  v_hist = torch.as_tensor(v_hist, dtype=torch.float32, device="cpu")
  traj_ends = torch.as_tensor(traj_ends, dtype=torch.float32, device="cpu")
  T = r_hist.shape[0]
  if T == 0:
    # Nothing to estimate; avoids indexing traj_ends[-1] on empty input.
    return torch.zeros(0, dtype=torch.float32)

  # V(s_{t+1}) for every step: shift v_hist by one and fill the last slot with
  # the bootstrap value (zero if the segment ended in a terminal state).
  v_tplus1 = torch.empty(T, dtype=torch.float32)
  if T > 1:
    v_tplus1[:-1] = v_hist[1:]
  if traj_ends[-1] == 1.0:
    v_tplus1[-1] = 0.0
  else:
    v_tplus1[-1] = next_val

  # Work on plain Python floats inside the double loop; per-element tensor
  # arithmetic is far slower and gives nothing here.
  delta = (r_hist + y*(1-traj_ends)*v_tplus1 - v_hist).tolist()
  keep = (1.0 - traj_ends).tolist()   # 0.0 at terminal steps, 1.0 elsewhere
  decay = y*l

  advantages = [0.0]*T
  for t in range(T):
    total = 0.0
    weight = 1.0
    for j in range(t, min(T, t + terms)):
      total += weight*delta[j]
      # A terminal flag zeroes the weight so later residuals are excluded.
      weight *= decay*keep[j]
      if weight == 0.0:
        break
    advantages[t] = total
  return torch.tensor(advantages, dtype=torch.float32)


In [None]:
#C A R T P O L E
# Environment, hyper-parameters, network, optimisers and logging buffers for
# the CartPole run (discrete actions).
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
total_timesteps = 4000 # environment steps collected in each epoch
traj_max_timesteps = 1000 # cap on the length of a single trajectory
num_epochs = 40
num_GAE_terms = 150 # intended GAE truncation horizon (same as GAE's default `terms`)
stepsize_w = 1e-3 # critic learning rate
stepsize_theta = 3e-3 # actor learning rate
run_start_time = time.time()
best_mean_return = -float("inf")
# The network is placed on the CPU here, independently of the notebook-level `device`.
net = actor_critic_dsc(state_dim, action_dim, hidden=128).to("cpu")

w = list(net.v_net.parameters()) # critic parameters
theta = list(net.pi_net.parameters()) # actor parameters (pi_net includes the logits layer)
e_clip = 0.2 # clip range applied to the probability ratio in the policy loss

# NOTE: the string below refers to an earlier SGD/VPG variant; this cell uses
# Adam for both optimisers and a clipped-ratio policy loss.
"""
SGD(lr=0.5) on both actor and critic is extremely high for VPG with neural nets
and categorical policies, and plain SGD here is brittle.
Switching to Adam with small lrs is a vanilla change (not clipping/normalization)
and typically the difference between crawling and clean convergence.
"""
actor_update = torch.optim.Adam(theta, lr=stepsize_theta)
critic_update = torch.optim.Adam(w, lr=stepsize_w)

all_traj_returns = []          # return (sum of rewards) of every trajectory, in collection order
all_traj_epochs = []         # epoch index for each trajectory
epoch_spans = []             # (first traj index, one-past-last traj index, mean return) per epoch

# Device the network actually lives on. Every tensor fed to `net` is created
# there: the network is moved to "cpu" at construction while the notebook-level
# `device` can be "cuda", and mixing the two raises a device-mismatch error.
net_device = next(net.parameters()).device

for k in range(num_epochs):
    # Per-epoch rollout buffers, flattened over all trajectories of the epoch.
    epoch_states, epoch_actions, epoch_vals, epoch_adv, epoch_returns = [], [], [], [], []
    epoch_logp = []
    traj_returns = []
    epoch_t = 0
    epoch_start_idx = len(all_traj_returns)  # index of this epoch's first trajectory

    # ---- collect trajectories until the epoch's step budget is used up ----
    while epoch_t < total_timesteps:
        s = reset_state(env)
        traj_states, traj_actions, traj_vals, traj_rews, traj_ends = [], [], [], [], []
        done = False
        traj_t = 0
        while (not done) and (traj_t < traj_max_timesteps):
            s = torch.as_tensor(s, dtype=torch.float32, device=net_device).unsqueeze(0)
            with torch.no_grad():
                dist, V = net(s)
                a = dist.sample()
                logp_a = dist.log_prob(a)
            s_next, r, terminated, truncated, _ = env.step(a.item())
            # The environment may end the episode itself (terminated) or cut it
            # at its time limit (truncated).
            done = terminated or truncated
            traj_rews.append(r)
            traj_states.append(s.squeeze(0))
            traj_vals.append(V.item())
            traj_actions.append(a.item())
            epoch_logp.append(logp_a.item())
            # Only a true terminal state masks the bootstrap; a time-limit
            # truncation still bootstraps from V(s_T).
            traj_ends.append(1.0 if terminated else 0.0)
            s = s_next
            traj_t += 1
            epoch_t += 1
            if epoch_t >= total_timesteps:  # epoch budget reached: stop collecting
                break

        if len(traj_rews) == 0:  # defensive: nothing collected, nothing to process
            continue

        rews = torch.tensor(traj_rews, dtype=torch.float32, device=net_device)
        ends = torch.tensor(traj_ends, dtype=torch.float32, device=net_device)

        # Bootstrap value for the state after the last collected step.
        if ends[-1] == 1.0:
            next_v = 0.0
        else:
            s_tplus1 = torch.as_tensor(s, dtype=torch.float32, device=net_device).unsqueeze(0)
            with torch.no_grad():  # the bootstrap is a constant target, no graph needed
                next_v = net.value(s_tplus1).item()

        # Advantages and value targets (advantage + baseline) for this trajectory.
        A = GAE(traj_rews, traj_vals, next_v, traj_ends, terms=num_GAE_terms)
        ret = A + torch.tensor(traj_vals, dtype=torch.float32)

        epoch_states.extend(traj_states)
        epoch_actions.extend(traj_actions)
        epoch_vals.extend(traj_vals)
        epoch_adv.extend(A.tolist())
        epoch_returns.extend(ret.tolist())

        traj_return = rews.sum().item()  # undiscounted sum of rewards
        traj_returns.append(traj_return)
        all_traj_returns.append(traj_return)
        all_traj_epochs.append(k)

    if len(epoch_states) == 0:
        print("empty epoch; continuing")
        continue

    # traj_returns is non-empty here because epoch_states is non-empty.
    mean_trajreward = float(np.mean(traj_returns))
    std_trajreward  = float(np.std(traj_returns))
    min_trajreward  = float(np.min(traj_returns))
    max_trajreward  = float(np.max(traj_returns))

    # ---- assemble the epoch batch on the network's device ----
    state_t = torch.stack(epoch_states).to(torch.float32).to(net_device)
    action_t = torch.tensor(epoch_actions, dtype=torch.int64, device=net_device)
    adv_t = torch.tensor(epoch_adv, dtype=torch.float32, device=net_device)
    ret_t = torch.tensor(epoch_returns, dtype=torch.float32, device=net_device)
    logp_old_t = torch.tensor(epoch_logp, dtype=torch.float32, device=net_device)

    # ---- critic: one MSE step towards the value targets ----
    # net.value already drops the trailing size-1 axis; squeezing again would
    # collapse a one-sample batch to a scalar and mismatch ret_t's shape.
    v_step = net.value(state_t)
    value_loss = torch.nn.functional.mse_loss(v_step, ret_t.detach())
    critic_update.zero_grad()
    value_loss.backward()
    critic_update.step()

    # ---- actor: one step on the clipped-ratio surrogate ----
    dist = net.policy(state_t)
    logp_t = dist.log_prob(action_t)
    r_th = torch.exp(logp_t - logp_old_t)
    L_epsilon = -torch.min(r_th*adv_t.detach(),
                           torch.clamp(r_th, 1.0 - e_clip, 1.0 + e_clip)*adv_t.detach()).mean()
    actor_update.zero_grad()
    L_epsilon.backward()
    actor_update.step()

    pi_loss_val = float(L_epsilon.item())
    v_loss_val  = float(value_loss.item())

    # ---- checkpoint whenever the epoch's mean return improves ----
    elapsed = time.time() - run_start_time
    if mean_trajreward > best_mean_return:
        best_mean_return = mean_trajreward
        torch.save(
            {
                "model": net.state_dict(),
                "actor_opt": actor_update.state_dict(),
                "critic_opt": critic_update.state_dict(),
                "epoch": k,
                "best_mean_return": best_mean_return,
            },
            "actor_critic_best1.pt",
        )

    print(
    f"epoch {k+1:03d}/{num_epochs} | steps(epoch) ~{len(epoch_returns):5d} "
    f"| return μ {mean_trajreward:6.2f} ± {std_trajreward:5.2f} (min {min_trajreward:5.1f}, max {max_trajreward:5.1f}) "
    f"| pi_loss {pi_loss_val:7.4f} | v_loss {v_loss_val:7.4f} "
    f"| best μ {best_mean_return:6.2f} | {elapsed:6.1f}s")

    # Record which global trajectory indices belong to this epoch, for plotting.
    epoch_end_idx = len(all_traj_returns)   # one past the last traj index of this epoch
    epoch_mean_for_plot = float(np.mean(all_traj_returns[epoch_start_idx:epoch_end_idx])) if epoch_end_idx > epoch_start_idx else np.nan
    epoch_spans.append((epoch_start_idx, epoch_end_idx, epoch_mean_for_plot))
# Learning curve: every trajectory's return in collection order, overlaid with
# the mean return of each epoch.
plt.figure(figsize=(12,5))

# 1) per-trajectory returns (blue); the x-axis is the global trajectory index
x = np.arange(len(all_traj_returns))
traj_line, = plt.plot(x, all_traj_returns, linewidth=1.0, label="trajectory return")

# 2) per-epoch mean as a red step line: two points (start, end) per epoch span
xs, ys = [], []
for (s, e, m) in epoch_spans:
    if not (e > s and np.isfinite(m)):
        continue
    xs.extend((s, e))
    ys.extend((m, m))
if xs:
    epoch_step = plt.step(xs, ys, where="post", color="red", linewidth=2.5, label="epoch mean return")
else:
    epoch_step = None

# (optional) faint epoch boundaries — no legend for these
for (s, e, _) in epoch_spans:
    plt.axvline(e - 0.5, alpha=0.1, linewidth=1)

# 3) mark best epoch mean with a black star, centred on that epoch's span
best_scatter = None
valid = [(i, s, e, m) for i, (s, e, m) in enumerate(epoch_spans) if e > s and np.isfinite(m)]
if valid:
    i_best, s_best, e_best, m_best = max(valid, key=lambda span: span[3])
    x_best = 0.5 * (s_best + e_best - 1)
    best_scatter = plt.scatter([x_best], [m_best], marker='*', s=140, color='black', zorder=6, label='best epoch mean')
    plt.annotate(
        f"{m_best:.1f}",
        xy=(x_best, m_best),
        xytext=(8, 8),
        textcoords="offset points",
        fontsize=9,
        arrowprops=dict(arrowstyle="->", lw=1),
    )

plt.xlabel("Trajectory index (across all epochs)")
plt.ylabel("Return (sum of rewards)")
plt.title("Per-trajectory returns (blue) with per-epoch mean (red step)")
plt.legend()  # picks up the labels set above
plt.tight_layout()
plt.show()

In [None]:
#I N V E R T E D  P E N D U L U M
# Environment, hyper-parameters, network, optimisers and logging buffers for
# the InvertedPendulum run (continuous actions).

env = gym.make("InvertedPendulum-v5")
state_dim = env.observation_space.shape[0]
action_space = env.action_space
total_timesteps = 4000 # environment steps collected in each epoch
traj_max_timesteps = 1001 # cap on the length of a single trajectory
num_epochs = 40
num_GAE_terms = 150 # intended GAE truncation horizon (same as GAE's default `terms`)
stepsize_w = 1e-3 # critic learning rate
stepsize_theta = 1e-3 # actor learning rate
run_start_time = time.time()
best_mean_return = -float("inf")
# Squashed-Gaussian policy bounded by the environment's action box; kept on the CPU.
net = actor_critic_cts(state_dim, action_dim = action_space.shape[0], hidden=128, continuous = True, action_low=action_space.low, action_high=action_space.high).to("cpu")

w = list(net.v_net.parameters()) # critic parameters
e_clip = 0.2 # clip range applied to the probability ratio in the policy loss

# NOTE: the string below refers to an earlier SGD/VPG variant; this cell uses
# Adam for both optimisers and a clipped-ratio policy loss.
"""
SGD(lr=0.5) on both actor and critic is extremely high for VPG with neural nets
and categorical policies, and plain SGD here is brittle.
Switching to Adam with small lrs is a vanilla change (not clipping/normalization)
and typically the difference between crawling and clean convergence.
"""
# Actor parameters = policy trunk + the head that matches the policy type
# (the continuous policy also trains the log standard deviation).
actor_params = list(net.pi_net.parameters())
if net.continuous:
    actor_params += list(net.mu_head.parameters()) + [net.log_std]
else:
    actor_params += list(net.pi_head.parameters())
actor_update = torch.optim.Adam(actor_params, lr=stepsize_theta)

critic_update = torch.optim.Adam(w, lr=stepsize_w)

all_traj_returns = []          # return (sum of rewards) of every trajectory, in collection order
all_traj_epochs = []         # epoch index for each trajectory
epoch_spans = []             # (first traj index, one-past-last traj index, mean return) per epoch

for k in range(num_epochs):
    # Per-epoch rollout buffers, flattened over all trajectories of the epoch.
    epoch_states, epoch_actions, epoch_vals, epoch_adv, epoch_returns = [], [], [], [], []
    epoch_logp = []
    traj_returns = []
    epoch_t = 0
    # Index of this epoch's first trajectory. Taken once per epoch: setting it
    # inside the trajectory loop made every epoch span (and its plotted mean)
    # cover only the epoch's last trajectory.
    epoch_start_idx = len(all_traj_returns)

    # ---- collect trajectories until the epoch's step budget is used up ----
    while epoch_t < total_timesteps:
        s = reset_state(env)
        traj_states, traj_actions, traj_vals, traj_rews, traj_ends = [], [], [], [], []
        done = False
        traj_t = 0
        while (not done) and (traj_t < traj_max_timesteps):
            s = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                dist, V = net(s)
                a = dist.sample()
                logp_a = dist.log_prob(a)
            # The environment takes a float32 numpy action vector without the batch axis.
            a_env = a.squeeze(0).cpu().numpy().astype(np.float32)
            s_next, r, terminated, truncated, _ = env.step(a_env)
            # The environment may end the episode itself (terminated) or cut it
            # at its time limit (truncated).
            done = terminated or truncated
            traj_rews.append(r)
            traj_states.append(s.squeeze(0))
            traj_vals.append(V.item())
            traj_actions.append(a.squeeze(0))
            epoch_logp.append(logp_a.squeeze(0))
            # Only a true terminal state masks the bootstrap; a time-limit
            # truncation still bootstraps from V(s_T).
            traj_ends.append(1.0 if terminated else 0.0)
            s = s_next
            traj_t += 1
            epoch_t += 1
            if epoch_t >= total_timesteps:  # epoch budget reached: stop collecting
                break

        if len(traj_rews) == 0:  # defensive: nothing collected, nothing to process
            continue

        rews = torch.tensor(traj_rews, dtype=torch.float32)
        ends = torch.tensor(traj_ends, dtype=torch.float32)

        # Bootstrap value for the state after the last collected step.
        if ends[-1] == 1.0:
            next_v = 0.0
        else:
            s_tplus1 = torch.as_tensor(s, dtype=torch.float32, device="cpu").unsqueeze(0)
            with torch.no_grad():  # the bootstrap is a constant target, no graph needed
                next_v = net.value(s_tplus1).item()

        # Advantages and value targets (advantage + baseline) for this trajectory.
        A = GAE(traj_rews, traj_vals, next_v, traj_ends, terms=num_GAE_terms)
        ret = A + torch.tensor(traj_vals, dtype=torch.float32)

        epoch_states.extend(traj_states)
        epoch_actions.extend(traj_actions)
        epoch_vals.extend(traj_vals)
        epoch_adv.extend(A.tolist())
        epoch_returns.extend(ret.tolist())

        traj_return = rews.sum().item()  # undiscounted sum of rewards
        traj_returns.append(traj_return)
        all_traj_returns.append(traj_return)
        all_traj_epochs.append(k)

    if len(epoch_states) == 0:
        print("empty epoch; continuing")
        continue

    # traj_returns is non-empty here because epoch_states is non-empty.
    mean_trajreward = float(np.mean(traj_returns))
    std_trajreward  = float(np.std(traj_returns))
    min_trajreward  = float(np.min(traj_returns))
    max_trajreward  = float(np.max(traj_returns))

    # ---- assemble the epoch batch ----
    state_t = torch.stack(epoch_states).to(torch.float32)
    action_t = torch.stack(epoch_actions).to(torch.float32)
    adv_t = torch.tensor(epoch_adv, dtype=torch.float32)
    ret_t = torch.tensor(epoch_returns, dtype=torch.float32)
    logp_old_t = torch.tensor(epoch_logp, dtype=torch.float32)

    # ---- critic: one MSE step towards the value targets ----
    # net.value already drops the trailing size-1 axis; squeezing again would
    # collapse a one-sample batch to a scalar and mismatch ret_t's shape.
    v_step = net.value(state_t)
    value_loss = torch.nn.functional.mse_loss(v_step, ret_t.detach())
    critic_update.zero_grad()
    value_loss.backward()
    critic_update.step()

    # Clamp actions into the open interval of the action box so the inverse
    # tanh is well-defined when log_prob is recomputed.
    eps = 1e-6
    low  = (net.action_loc - net.action_scale + eps).unsqueeze(0)
    high = (net.action_loc + net.action_scale - eps).unsqueeze(0)
    action_t_clamped = torch.max(torch.min(action_t, high), low)

    # ---- actor: one step on the clipped-ratio surrogate ----
    dist = net.policy(state_t)
    logp_t = dist.log_prob(action_t_clamped)
    r_th = torch.exp(logp_t - logp_old_t)
    L_epsilon = -torch.min(r_th*adv_t.detach(),
                           torch.clamp(r_th, 1.0 - e_clip, 1.0 + e_clip)*adv_t.detach()).mean()
    actor_update.zero_grad()
    L_epsilon.backward()
    actor_update.step()

    pi_loss_val = float(L_epsilon.item())
    v_loss_val  = float(value_loss.item())

    # ---- checkpoint whenever the epoch's mean return improves ----
    elapsed = time.time() - run_start_time
    if mean_trajreward > best_mean_return:
        best_mean_return = mean_trajreward
        torch.save(
            {
                "model": net.state_dict(),
                "actor_opt": actor_update.state_dict(),
                "critic_opt": critic_update.state_dict(),
                "epoch": k,
                "best_mean_return": best_mean_return,
            },
            "actor_critic_best2.pt",
        )

    print(
    f"epoch {k+1:03d}/{num_epochs} | steps(epoch) ~{len(epoch_returns):5d} "
    f"| return μ {mean_trajreward:6.2f} ± {std_trajreward:5.2f} (min {min_trajreward:5.1f}, max {max_trajreward:5.1f}) "
    f"| pi_loss {pi_loss_val:7.4f} | v_loss {v_loss_val:7.4f} "
    f"| best μ {best_mean_return:6.2f} | {elapsed:6.1f}s")

    # Record which global trajectory indices belong to this epoch, for plotting.
    epoch_end_idx = len(all_traj_returns)   # one past the last traj index of this epoch
    epoch_mean_for_plot = float(np.mean(all_traj_returns[epoch_start_idx:epoch_end_idx])) if epoch_end_idx > epoch_start_idx else np.nan
    epoch_spans.append((epoch_start_idx, epoch_end_idx, epoch_mean_for_plot))

# Learning curve: every trajectory's return in collection order, overlaid with
# the mean return of each epoch.
plt.figure(figsize=(12,5))

# 1) per-trajectory returns (blue); the x-axis is the global trajectory index
x = np.arange(len(all_traj_returns))
traj_line, = plt.plot(x, all_traj_returns, linewidth=1.0, label="trajectory return")

# 2) per-epoch mean as a red step line: two points (start, end) per epoch span
xs, ys = [], []
for (s, e, m) in epoch_spans:
    if not (e > s and np.isfinite(m)):
        continue
    xs.extend((s, e))
    ys.extend((m, m))
if xs:
    epoch_step = plt.step(xs, ys, where="post", color="red", linewidth=2.5, label="epoch mean return")
else:
    epoch_step = None

# (optional) faint epoch boundaries — no legend for these
for (s, e, _) in epoch_spans:
    plt.axvline(e - 0.5, alpha=0.1, linewidth=1)

# 3) mark best epoch mean with a black star, centred on that epoch's span
best_scatter = None
valid = [(i, s, e, m) for i, (s, e, m) in enumerate(epoch_spans) if e > s and np.isfinite(m)]
if valid:
    i_best, s_best, e_best, m_best = max(valid, key=lambda span: span[3])
    x_best = 0.5 * (s_best + e_best - 1)
    best_scatter = plt.scatter([x_best], [m_best], marker='*', s=140, color='black', zorder=6, label='best epoch mean')
    plt.annotate(
        f"{m_best:.1f}",
        xy=(x_best, m_best),
        xytext=(8, 8),
        textcoords="offset points",
        fontsize=9,
        arrowprops=dict(arrowstyle="->", lw=1),
    )

plt.xlabel("Trajectory index (across all epochs)")
plt.ylabel("Return (sum of rewards)")
plt.title("Per-trajectory returns (blue) with per-epoch mean (red step)")
plt.legend()  # picks up the labels set above
plt.tight_layout()
plt.show()

In [None]:
# Environment, hyper-parameters, network, optimisers and logging buffers for
# the HalfCheetah run (continuous actions).
env = gym.make("HalfCheetah-v5")
state_dim = env.observation_space.shape[0]
action_space = env.action_space
total_timesteps = 4000 # environment steps collected in each epoch
traj_max_timesteps = 1001 # cap on the length of a single trajectory
num_epochs = 40
num_GAE_terms = 150 # intended GAE truncation horizon (same as GAE's default `terms`)
stepsize_w = 1e-3 # critic learning rate
stepsize_theta = 3e-4 # actor learning rate
run_start_time = time.time()
best_mean_return = -float("inf")
# Squashed-Gaussian policy bounded by the environment's action box; the log-std
# starts at -1.0 here instead of the class default. Kept on the CPU.
net = actor_critic_cts(state_dim, action_dim = action_space.shape[0], hidden=128, continuous = True, action_low=action_space.low, action_high=action_space.high, init_logstd=-1.0).to("cpu")

w = list(net.v_net.parameters()) # critic parameters
e_clip = 0.2 # clip range applied to the probability ratio in the policy loss

# NOTE: the string below refers to an earlier SGD/VPG variant; this cell uses
# Adam for both optimisers and a clipped-ratio policy loss.
"""
SGD(lr=0.5) on both actor and critic is extremely high for VPG with neural nets
and categorical policies, and plain SGD here is brittle.
Switching to Adam with small lrs is a vanilla change (not clipping/normalization)
and typically the difference between crawling and clean convergence.
"""
# Actor parameters = policy trunk + the head that matches the policy type
# (the continuous policy also trains the log standard deviation).
actor_params = list(net.pi_net.parameters())
if net.continuous:
    actor_params += list(net.mu_head.parameters()) + [net.log_std]
else:
    actor_params += list(net.pi_head.parameters())
actor_update = torch.optim.Adam(actor_params, lr=stepsize_theta)

critic_update = torch.optim.Adam(w, lr=stepsize_w)

all_traj_returns = []          # return (sum of rewards) of every trajectory, in collection order
all_traj_epochs = []         # epoch index for each trajectory
epoch_spans = []             # (first traj index, one-past-last traj index, mean return) per epoch

for k in range(num_epochs):
    # Per-epoch rollout buffers, flattened over all trajectories of the epoch.
    epoch_states, epoch_actions, epoch_vals, epoch_adv, epoch_returns = [], [], [], [], []
    epoch_logp = []
    traj_returns = []
    epoch_t = 0
    # Index of this epoch's first trajectory. Taken once per epoch: setting it
    # inside the trajectory loop made every epoch span (and its plotted mean)
    # cover only the epoch's last trajectory.
    epoch_start_idx = len(all_traj_returns)

    # ---- collect trajectories until the epoch's step budget is used up ----
    while epoch_t < total_timesteps:
        s = reset_state(env)
        traj_states, traj_actions, traj_vals, traj_rews, traj_ends = [], [], [], [], []
        done = False
        traj_t = 0
        while (not done) and (traj_t < traj_max_timesteps):
            s = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                dist, V = net(s)
                a = dist.sample()
                logp_a = dist.log_prob(a)
            # The environment takes a float32 numpy action vector without the batch axis.
            a_env = a.squeeze(0).cpu().numpy().astype(np.float32)
            s_next, r, terminated, truncated, _ = env.step(a_env)
            # The environment may end the episode itself (terminated) or cut it
            # at its time limit (truncated).
            done = terminated or truncated
            traj_rews.append(r)
            traj_states.append(s.squeeze(0))
            traj_vals.append(V.item())
            traj_actions.append(a.squeeze(0))
            epoch_logp.append(logp_a.squeeze(0))
            # Only a true terminal state masks the bootstrap; a time-limit
            # truncation still bootstraps from V(s_T).
            traj_ends.append(1.0 if terminated else 0.0)
            s = s_next
            traj_t += 1
            epoch_t += 1
            if epoch_t >= total_timesteps:  # epoch budget reached: stop collecting
                break

        if len(traj_rews) == 0:  # defensive: nothing collected, nothing to process
            continue

        rews = torch.tensor(traj_rews, dtype=torch.float32)
        ends = torch.tensor(traj_ends, dtype=torch.float32)

        # Bootstrap value for the state after the last collected step.
        if ends[-1] == 1.0:
            next_v = 0.0
        else:
            s_tplus1 = torch.as_tensor(s, dtype=torch.float32, device="cpu").unsqueeze(0)
            with torch.no_grad():  # the bootstrap is a constant target, no graph needed
                next_v = net.value(s_tplus1).item()

        # Advantages and value targets (advantage + baseline) for this trajectory.
        A = GAE(traj_rews, traj_vals, next_v, traj_ends, terms=num_GAE_terms)
        ret = A + torch.tensor(traj_vals, dtype=torch.float32)

        epoch_states.extend(traj_states)
        epoch_actions.extend(traj_actions)
        epoch_vals.extend(traj_vals)
        epoch_adv.extend(A.tolist())
        epoch_returns.extend(ret.tolist())

        traj_return = rews.sum().item()  # undiscounted sum of rewards
        traj_returns.append(traj_return)
        all_traj_returns.append(traj_return)
        all_traj_epochs.append(k)

    if len(epoch_states) == 0:
        print("empty epoch; continuing")
        continue

    # traj_returns is non-empty here because epoch_states is non-empty.
    mean_trajreward = float(np.mean(traj_returns))
    std_trajreward  = float(np.std(traj_returns))
    min_trajreward  = float(np.min(traj_returns))
    max_trajreward  = float(np.max(traj_returns))

    # ---- assemble the epoch batch ----
    state_t = torch.stack(epoch_states).to(torch.float32)
    action_t = torch.stack(epoch_actions).to(torch.float32)
    adv_t = torch.tensor(epoch_adv, dtype=torch.float32)
    ret_t = torch.tensor(epoch_returns, dtype=torch.float32)
    logp_old_t = torch.tensor(epoch_logp, dtype=torch.float32)

    # ---- critic: one MSE step towards the value targets ----
    # net.value already drops the trailing size-1 axis; squeezing again would
    # collapse a one-sample batch to a scalar and mismatch ret_t's shape.
    v_step = net.value(state_t)
    value_loss = torch.nn.functional.mse_loss(v_step, ret_t.detach())
    critic_update.zero_grad()
    value_loss.backward()
    critic_update.step()

    # Clamp actions into the open interval of the action box so the inverse
    # tanh is well-defined when log_prob is recomputed.
    eps = 1e-6
    low  = (net.action_loc - net.action_scale + eps).unsqueeze(0)
    high = (net.action_loc + net.action_scale - eps).unsqueeze(0)
    action_t_clamped = torch.max(torch.min(action_t, high), low)

    # ---- actor: one step on the clipped-ratio surrogate ----
    dist = net.policy(state_t)
    logp_t = dist.log_prob(action_t_clamped)
    r_th = torch.exp(logp_t - logp_old_t)
    L_epsilon = -torch.min(r_th*adv_t.detach(),
                           torch.clamp(r_th, 1.0 - e_clip, 1.0 + e_clip)*adv_t.detach()).mean()
    actor_update.zero_grad()
    L_epsilon.backward()
    actor_update.step()

    pi_loss_val = float(L_epsilon.item())
    v_loss_val  = float(value_loss.item())

    # ---- checkpoint whenever the epoch's mean return improves ----
    elapsed = time.time() - run_start_time
    if mean_trajreward > best_mean_return:
        best_mean_return = mean_trajreward
        torch.save(
            {
                "model": net.state_dict(),
                "actor_opt": actor_update.state_dict(),
                "critic_opt": critic_update.state_dict(),
                "epoch": k,
                "best_mean_return": best_mean_return,
            },
            "actor_critic_best3.pt",
        )

    print(
    f"epoch {k+1:03d}/{num_epochs} | steps(epoch) ~{len(epoch_returns):5d} "
    f"| return μ {mean_trajreward:6.2f} ± {std_trajreward:5.2f} (min {min_trajreward:5.1f}, max {max_trajreward:5.1f}) "
    f"| pi_loss {pi_loss_val:7.4f} | v_loss {v_loss_val:7.4f} "
    f"| best μ {best_mean_return:6.2f} | {elapsed:6.1f}s")

    # Record which global trajectory indices belong to this epoch, for plotting.
    epoch_end_idx = len(all_traj_returns)   # one past the last traj index of this epoch
    epoch_mean_for_plot = float(np.mean(all_traj_returns[epoch_start_idx:epoch_end_idx])) if epoch_end_idx > epoch_start_idx else np.nan
    epoch_spans.append((epoch_start_idx, epoch_end_idx, epoch_mean_for_plot))

# Learning curve: every trajectory's return in collection order, overlaid with
# the mean return of each epoch.
plt.figure(figsize=(12,5))

# 1) per-trajectory returns (blue); the x-axis is the global trajectory index
x = np.arange(len(all_traj_returns))
traj_line, = plt.plot(x, all_traj_returns, linewidth=1.0, label="trajectory return")

# 2) per-epoch mean as a red step line: two points (start, end) per epoch span
xs, ys = [], []
for (s, e, m) in epoch_spans:
    if not (e > s and np.isfinite(m)):
        continue
    xs.extend((s, e))
    ys.extend((m, m))
if xs:
    epoch_step = plt.step(xs, ys, where="post", color="red", linewidth=2.5, label="epoch mean return")
else:
    epoch_step = None

# (optional) faint epoch boundaries — no legend for these
for (s, e, _) in epoch_spans:
    plt.axvline(e - 0.5, alpha=0.1, linewidth=1)

# 3) mark best epoch mean with a black star, centred on that epoch's span
best_scatter = None
valid = [(i, s, e, m) for i, (s, e, m) in enumerate(epoch_spans) if e > s and np.isfinite(m)]
if valid:
    i_best, s_best, e_best, m_best = max(valid, key=lambda span: span[3])
    x_best = 0.5 * (s_best + e_best - 1)
    best_scatter = plt.scatter([x_best], [m_best], marker='*', s=140, color='black', zorder=6, label='best epoch mean')
    plt.annotate(
        f"{m_best:.1f}",
        xy=(x_best, m_best),
        xytext=(8, 8),
        textcoords="offset points",
        fontsize=9,
        arrowprops=dict(arrowstyle="->", lw=1),
    )

plt.xlabel("Trajectory index (across all epochs)")
plt.ylabel("Return (sum of rewards)")
plt.title("Per-trajectory returns (blue) with per-epoch mean (red step)")
plt.legend()  # picks up the labels set above
plt.tight_layout()
plt.show()

In [None]:
# Environment, hyper-parameters, network, optimisers and logging buffers for
# the Swimmer run (continuous actions).
env = gym.make("Swimmer-v5")
state_dim = env.observation_space.shape[0]
action_space = env.action_space
total_timesteps = 4000 # environment steps collected in each epoch
traj_max_timesteps = 1001 # cap on the length of a single trajectory
num_epochs = 40
num_GAE_terms = 150 # intended GAE truncation horizon (same as GAE's default `terms`)
stepsize_w = 1e-3 # critic learning rate
stepsize_theta = 3e-4 # actor learning rate
run_start_time = time.time()
best_mean_return = -float("inf")
# Squashed-Gaussian policy bounded by the environment's action box; kept on the CPU.
net = actor_critic_cts(state_dim, action_dim = action_space.shape[0], hidden=128, continuous = True, action_low=action_space.low, action_high=action_space.high).to("cpu")

w = list(net.v_net.parameters()) # critic parameters
e_clip = 0.2 # presumably the ratio clip range, as in the earlier cells; its use lies beyond this excerpt

# NOTE: the string below refers to an earlier SGD/VPG variant; this cell uses
# Adam for both optimisers.
"""
SGD(lr=0.5) on both actor and critic is extremely high for VPG with neural nets
and categorical policies, and plain SGD here is brittle.
Switching to Adam with small lrs is a vanilla change (not clipping/normalization)
and typically the difference between crawling and clean convergence.
"""
# Actor parameters = policy trunk + the head that matches the policy type
# (the continuous policy also trains the log standard deviation).
actor_params = list(net.pi_net.parameters())
if net.continuous:
    actor_params += list(net.mu_head.parameters()) + [net.log_std]
else:
    actor_params += list(net.pi_head.parameters())
actor_update = torch.optim.Adam(actor_params, lr=stepsize_theta)
theta = actor_params # alias for the actor parameter list
critic_update = torch.optim.Adam(w, lr=stepsize_w)

all_traj_returns = []          # presumably per-trajectory returns, as in the earlier cells; filled beyond this excerpt
all_traj_epochs = []         # epoch index for each trajectory
epoch_spans = []

for k in range(num_epochs):
    epoch_states, epoch_actions, epoch_vals, epoch_adv, epoch_returns = [], [], [], [], [] # data for each trajectory roll-out
    epoch_t = 0;
    epoch_logp = []
    traj_returns = []
    while epoch_t < total_timesteps:
      s = reset_state(env)
      traj_states, traj_actions, traj_vals, traj_rews, traj_ends = [], [], [], [], [] # data for each trajectory roll-out
      done = False
      traj_t = 0
      epoch_start_idx = len(all_traj_returns)
      while (not done) and (traj_t<traj_max_timesteps):
        s = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
          dist, V = net(s)
          a = dist.sample()
          logp_a = dist.log_prob(a).sum(-1)
        a_env = a.squeeze(0).cpu().numpy().astype(np.float32)  #AI
        s_next, r, terminated, truncated, _ = env.step(a_env)
        done = terminated or truncated #use this instead of a defined for loop, because in gymnasium the trajectory might terminate in the environment when some success/failure conditions are met
        traj_rews.append(r)
        traj_states.append(s.squeeze(0))
        traj_vals.append(V.item())
        traj_actions.append(a.squeeze(0))
        epoch_logp.append(logp_a.squeeze(0))
        """terminated = true environment terminal (pole fell) → mask it.
            truncated = time limit → do not mask; bootstrap with V_sT
        """
        traj_ends.append(1.0 if terminated else 0.0)
        s = s_next
        traj_t += 1
        epoch_t += 1
      # AI IDENTIFIED FIXES
        if epoch_t >= total_timesteps:      # FIX: stop collecting if epoch budget reached
            break
      # ---- prevent ends[-1] crash if trajectory collected 0 steps ----
      if len(traj_rews) == 0:               # FIX: skip empty trajectories
        continue
      rews  = torch.tensor(traj_rews,  dtype=torch.float32)
      vals  = torch.tensor(traj_vals,  dtype=torch.float32)
      ends = torch.tensor(traj_ends, dtype=torch.float32)

      if ends[-1] == 1.0:  #implementation and use of dones to handle the ends of trajectory by truncation or termination was implemented with ChatGPT 5 assistance
        next_v = 0.0
      else:
        s_tplus1 = torch.as_tensor(s, dtype=torch.float32, device="cpu").unsqueeze(0) #obtain the state after the last action taken before trajectory ended
        with torch.no_grad(): #AI provided. Used so that our value function is a constant, and estimated strictly through NN output, and does not give along the graph
                              #that generated the next_v
                              #using .item() does ensure in part that the graph is not given along,
                              #but no_grad() allows us to have entirely skip the computation of the generation graph
            next_v = net.value(s_tplus1).item()

      A = GAE(traj_rews, traj_vals, next_v, traj_ends)
      ret = A + torch.tensor(traj_vals, dtype=torch.float32)

      epoch_states.extend(traj_states)
      epoch_actions.extend(traj_actions)
      epoch_vals.extend(traj_vals)
      epoch_adv.extend(A.tolist())
      epoch_returns.extend(ret.tolist())

      traj_len = len(traj_rews)
      # after a trajectory finishes
      traj_return = rews.sum().item()             # <-- sum, not mean
      traj_returns.append(traj_return)

      all_traj_returns.append(traj_return)        # rename your list
      all_traj_epochs.append(k)

    if len(epoch_states) == 0:
      print("empty epoch; continuing")
      continue

    mean_trajreward = float(np.mean(traj_returns))  if len(traj_returns) > 0 else 0.0
    std_trajreward  = float(np.std(traj_returns))  if len(traj_returns) > 0 else 0.0
    min_trajreward  = float(np.min(traj_returns))  if len(traj_returns) > 0 else 0.0
    max_trajreward  = float(np.max(traj_returns))  if len(traj_returns) > 0 else 0.0

    state_t = torch.stack(epoch_states).to(torch.float32)
    action_t = torch.stack(epoch_actions).to(torch.float32)
    adv_t = torch.tensor(epoch_adv, dtype=torch.float32)
    ret_t = torch.tensor(epoch_returns, dtype=torch.float32)
    logp_old_t = torch.tensor(epoch_logp, dtype=torch.float32)

    v_step = net.value(state_t).squeeze(-1)                                  # V(s_t) (with grad)
    value_loss = torch.nn.functional.mse_loss(v_step, ret_t.detach())
    critic_update.zero_grad(); value_loss.backward(); critic_update.step()

    # Clamp actions into open interval so atanh is well-defined when re-computing log_prob
    ###AIstart
    eps = 1e-6
    low  = (net.action_loc - net.action_scale + eps).unsqueeze(0)
    high = (net.action_loc + net.action_scale - eps).unsqueeze(0)
    action_t_clamped = torch.max(torch.min(action_t, high), low)
    ###AIend

    dist = net.policy(state_t)
    logp_t = dist.log_prob(action_t_clamped)
    r_th = torch.exp(logp_t - logp_old_t)
    L_epsilon = -torch.min(r_th*adv_t.detach(),
                           torch.clamp(r_th, 1.0 - e_clip, 1.0 + e_clip)*adv_t.detach()).mean()
    actor_update.zero_grad(); L_epsilon.backward(); actor_update.step()

    pi_loss_val = float(L_epsilon.item())
    v_loss_val  = float(value_loss.item())

    elapsed = time.time() - run_start_time
    if mean_trajreward > best_mean_return:
        best_mean_return = mean_trajreward
        torch.save(
            {
                "model": net.state_dict(),
                "actor_opt": actor_update.state_dict(),
                "critic_opt": critic_update.state_dict(),
                "epoch": k,
                "best_mean_return": best_mean_return,
            },
            "actor_critic_best4.pt",
        )

    print(
    f"epoch {k+1:03d}/{num_epochs} | steps(epoch) ~{len(epoch_returns):5d} "
    f"| return μ {mean_trajreward:6.2f} ± {std_trajreward:5.2f} (min {min_trajreward:5.1f}, max {max_trajreward:5.1f}) "
    f"| pi_loss {pi_loss_val:7.4f} | v_loss {v_loss_val:7.4f} "
    f"| best μ {best_mean_return:6.2f} | {elapsed:6.1f}s")
    epoch_end_idx = len(all_traj_returns)   # one past the last traj index of this epoch
    epoch_mean_for_plot = float(np.mean(all_traj_returns[epoch_start_idx:epoch_end_idx])) if epoch_end_idx > epoch_start_idx else np.nan
    epoch_spans.append((epoch_start_idx, epoch_end_idx, epoch_mean_for_plot))

# Figure: return of every trajectory in the run, indexed by its global position
# in all_traj_returns, with one horizontal red segment per recorded epoch span.
x = np.arange(len(all_traj_returns))

fig, ax = plt.subplots(figsize=(12, 5))

# 1) raw per-trajectory returns (first colour of the default cycle)
traj_line, = ax.plot(x, all_traj_returns, linewidth=1.0, label="trajectory return")

# 2) spans that are non-empty and carry a finite mean; reused for the step line
#    and for locating the best epoch below
usable_spans = [
    (start, end, mean)
    for (start, end, mean) in epoch_spans
    if end > start and np.isfinite(mean)
]
# each span contributes its two edges at the same height -> one flat segment
xs = [edge for (start, end, _) in usable_spans for edge in (start, end)]
ys = [level for (_, _, mean) in usable_spans for level in (mean, mean)]
epoch_step = (
    ax.step(xs, ys, where="post", color="red", linewidth=2.5, label="epoch mean return")
    if xs
    else None
)

# faint vertical guide at the right edge of every span (unlabelled, so absent from the legend)
for span in epoch_spans:
    ax.axvline(span[1] - 0.5, alpha=0.1, linewidth=1)

# 3) black star + value annotation on the span with the highest mean
best_scatter = None
if usable_spans:
    s_best, e_best, m_best = max(usable_spans, key=lambda span: span[2])
    x_best = (s_best + e_best - 1) / 2  # midpoint of the span's trajectory indices
    best_scatter = ax.scatter([x_best], [m_best], marker='*', s=140, color='black',
                              zorder=6, label='best epoch mean')
    ax.annotate(f"{m_best:.1f}", xy=(x_best, m_best), xytext=(8, 8),
                textcoords="offset points", fontsize=9,
                arrowprops=dict(arrowstyle="->", lw=1))

ax.set_xlabel("Trajectory index (across all epochs)")
ax.set_ylabel("Return (sum of rewards)")
ax.set_title("Per-trajectory returns (blue) with per-epoch mean (red step)")
ax.legend()  # built from the label= arguments above
fig.tight_layout()
plt.show()

In [None]:
# Clipped-surrogate actor-critic training on InvertedDoublePendulum-v5 (continuous actions).
# Each epoch collects `total_timesteps` environment steps, then takes one critic
# step (MSE on advantage + value targets) and one actor step (clipped ratio loss).
env = gym.make("InvertedDoublePendulum-v5")
state_dim = env.observation_space.shape[0]
action_space = env.action_space
total_timesteps = 4000      # environment steps collected per epoch
traj_max_timesteps = 1000   # per-trajectory step cap enforced by the rollout loop below
num_epochs = 60
num_GAE_terms = 150         # not referenced in this cell; presumably read by GAE() as a global -- TODO confirm
stepsize_w = 1e-3           # critic learning rate
stepsize_theta = 3e-4       # actor learning rate
run_start_time = time.time()
best_mean_return = -float("inf")
net = actor_critic_cts(state_dim, action_dim = action_space.shape[0], hidden=128, continuous = True, action_low=action_space.low, action_high=action_space.high).to("cpu")

w = list(net.v_net.parameters())  # critic parameters
e_clip = 0.2                      # clipping range for the probability ratio

# Actor and critic are both optimised with Adam at small learning rates
# (the earlier note in this notebook found plain SGD with a large step size brittle).
actor_params = list(net.pi_net.parameters())
if net.continuous:
    actor_params += list(net.mu_head.parameters()) + [net.log_std]
else:
    actor_params += list(net.pi_head.parameters())
actor_update = torch.optim.Adam(actor_params, lr=stepsize_theta)

critic_update = torch.optim.Adam(w, lr=stepsize_w)

all_traj_returns = []   # undiscounted return (sum of rewards) of every trajectory, in collection order
all_traj_epochs = []    # epoch index for each entry of all_traj_returns
epoch_spans = []        # (first_traj_idx, one_past_last_traj_idx, mean_return) per epoch, for plotting

for k in range(num_epochs):
    # Flattened per-epoch buffers: one entry per environment step.
    epoch_states, epoch_actions, epoch_vals, epoch_adv, epoch_returns = [], [], [], [], []
    epoch_logp = []
    traj_returns = []   # returns of the trajectories collected in this epoch
    epoch_t = 0
    # FIX: record where this epoch's trajectories start ONCE per epoch. Previously this
    # was reset at the start of every trajectory, so each epoch span (and hence the
    # plotted "epoch mean") covered only the final trajectory of the epoch.
    epoch_start_idx = len(all_traj_returns)

    while epoch_t < total_timesteps:
        s = reset_state(env)
        traj_states, traj_actions, traj_vals, traj_rews, traj_ends = [], [], [], [], []
        done = False
        traj_t = 0
        while (not done) and (traj_t < traj_max_timesteps):
            s = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                dist, V = net(s)
                a = dist.sample()
                logp_a = dist.log_prob(a).sum(-1)
            a_env = a.squeeze(0).cpu().numpy().astype(np.float32)  # (AI-assisted)
            s_next, r, terminated, truncated, _ = env.step(a_env)
            # Gymnasium can end an episode on its own (terminated) or by time limit (truncated).
            done = terminated or truncated
            traj_rews.append(r)
            traj_states.append(s.squeeze(0))
            traj_vals.append(V.item())
            traj_actions.append(a.squeeze(0))
            epoch_logp.append(logp_a.squeeze(0))
            # Only a true environment termination (e.g. the pole fell) is flagged as an
            # end; a time-limit truncation is left at 0.0 so the tail is bootstrapped
            # with V(s_T) below.
            traj_ends.append(1.0 if terminated else 0.0)
            s = s_next
            traj_t += 1
            epoch_t += 1
            if epoch_t >= total_timesteps:   # epoch step budget reached: cut the trajectory here
                break

        # Defensive guard so traj_ends[-1] below cannot index an empty list.
        if len(traj_rews) == 0:
            continue
        rews = torch.tensor(traj_rews, dtype=torch.float32)
        ends = torch.tensor(traj_ends, dtype=torch.float32)

        # Bootstrap value for the state after the last collected step
        # (termination/truncation handling was implemented with ChatGPT assistance).
        if ends[-1] == 1.0:
            next_v = 0.0
        else:
            s_tplus1 = torch.as_tensor(s, dtype=torch.float32, device="cpu").unsqueeze(0)
            # no_grad: the bootstrap value is a constant target, so no graph is built for it.
            with torch.no_grad():
                next_v = net.value(s_tplus1).item()

        A = GAE(traj_rews, traj_vals, next_v, traj_ends)
        ret = A + torch.tensor(traj_vals, dtype=torch.float32)  # critic regression target: advantage + V(s_t)

        epoch_states.extend(traj_states)
        epoch_actions.extend(traj_actions)
        epoch_vals.extend(traj_vals)
        epoch_adv.extend(A.tolist())
        epoch_returns.extend(ret.tolist())

        traj_return = rews.sum().item()   # undiscounted sum of rewards (not a mean)
        traj_returns.append(traj_return)
        all_traj_returns.append(traj_return)
        all_traj_epochs.append(k)

    if len(epoch_states) == 0:
        print("empty epoch; continuing")
        continue

    mean_trajreward = float(np.mean(traj_returns)) if len(traj_returns) > 0 else 0.0
    std_trajreward  = float(np.std(traj_returns))  if len(traj_returns) > 0 else 0.0
    min_trajreward  = float(np.min(traj_returns))  if len(traj_returns) > 0 else 0.0
    max_trajreward  = float(np.max(traj_returns))  if len(traj_returns) > 0 else 0.0

    state_t = torch.stack(epoch_states).to(torch.float32)
    action_t = torch.stack(epoch_actions).to(torch.float32)
    adv_t = torch.tensor(epoch_adv, dtype=torch.float32)
    ret_t = torch.tensor(epoch_returns, dtype=torch.float32)
    logp_old_t = torch.tensor(epoch_logp, dtype=torch.float32)

    # ---- critic: one MSE step towards the advantage + value targets ----
    v_step = net.value(state_t).squeeze(-1)   # V(s_t), with grad
    value_loss = torch.nn.functional.mse_loss(v_step, ret_t.detach())
    critic_update.zero_grad(); value_loss.backward(); critic_update.step()

    # ---- actor: one step on the clipped surrogate ----
    # Clamp stored actions strictly inside the action bounds before re-evaluating
    # log_prob (AI-assisted; the original note says this keeps atanh well-defined).
    eps = 1e-6
    low  = (net.action_loc - net.action_scale + eps).unsqueeze(0)
    high = (net.action_loc + net.action_scale - eps).unsqueeze(0)
    action_t_clamped = torch.max(torch.min(action_t, high), low)

    dist = net.policy(state_t)
    logp_t = dist.log_prob(action_t_clamped)
    # The rollout reduced log_prob with .sum(-1); mirror that here in case the
    # distribution returns per-dimension values. No-op when log_prob is already
    # one value per sample.
    if logp_t.dim() > 1:
        logp_t = logp_t.sum(-1)
    r_th = torch.exp(logp_t - logp_old_t)
    # NOTE(review): only one actor step is taken per batch, so r_th is evaluated with
    # the same parameters that produced logp_old_t and the clip may never bind --
    # confirm whether several passes per epoch were intended.
    L_epsilon = -torch.min(r_th*adv_t.detach(),
                           torch.clamp(r_th, 1.0 - e_clip, 1.0 + e_clip)*adv_t.detach()).mean()
    actor_update.zero_grad(); L_epsilon.backward(); actor_update.step()

    pi_loss_val = float(L_epsilon.item())
    v_loss_val  = float(value_loss.item())

    elapsed = time.time() - run_start_time
    # Checkpoint whenever the epoch's mean return improves on the best so far.
    if mean_trajreward > best_mean_return:
        best_mean_return = mean_trajreward
        torch.save(
            {
                "model": net.state_dict(),
                "actor_opt": actor_update.state_dict(),
                "critic_opt": critic_update.state_dict(),
                "epoch": k,
                "best_mean_return": best_mean_return,
            },
            "actor_critic_best5.pt",
        )

    print(
    f"epoch {k+1:03d}/{num_epochs} | steps(epoch) ~{len(epoch_returns):5d} "
    f"| return μ {mean_trajreward:6.2f} ± {std_trajreward:5.2f} (min {min_trajreward:5.1f}, max {max_trajreward:5.1f}) "
    f"| pi_loss {pi_loss_val:7.4f} | v_loss {v_loss_val:7.4f} "
    f"| best μ {best_mean_return:6.2f} | {elapsed:6.1f}s")
    epoch_end_idx = len(all_traj_returns)   # one past the last trajectory index of this epoch
    epoch_mean_for_plot = float(np.mean(all_traj_returns[epoch_start_idx:epoch_end_idx])) if epoch_end_idx > epoch_start_idx else np.nan
    epoch_spans.append((epoch_start_idx, epoch_end_idx, epoch_mean_for_plot))

# Figure: return of every trajectory in the run, indexed by its global position
# in all_traj_returns, with one horizontal red segment per recorded epoch span.
x = np.arange(len(all_traj_returns))

fig, ax = plt.subplots(figsize=(12, 5))

# 1) raw per-trajectory returns (first colour of the default cycle)
traj_line, = ax.plot(x, all_traj_returns, linewidth=1.0, label="trajectory return")

# 2) spans that are non-empty and carry a finite mean; reused for the step line
#    and for locating the best epoch below
usable_spans = [
    (start, end, mean)
    for (start, end, mean) in epoch_spans
    if end > start and np.isfinite(mean)
]
# each span contributes its two edges at the same height -> one flat segment
xs = [edge for (start, end, _) in usable_spans for edge in (start, end)]
ys = [level for (_, _, mean) in usable_spans for level in (mean, mean)]
epoch_step = (
    ax.step(xs, ys, where="post", color="red", linewidth=2.5, label="epoch mean return")
    if xs
    else None
)

# faint vertical guide at the right edge of every span (unlabelled, so absent from the legend)
for span in epoch_spans:
    ax.axvline(span[1] - 0.5, alpha=0.1, linewidth=1)

# 3) black star + value annotation on the span with the highest mean
best_scatter = None
if usable_spans:
    s_best, e_best, m_best = max(usable_spans, key=lambda span: span[2])
    x_best = (s_best + e_best - 1) / 2  # midpoint of the span's trajectory indices
    best_scatter = ax.scatter([x_best], [m_best], marker='*', s=140, color='black',
                              zorder=6, label='best epoch mean')
    ax.annotate(f"{m_best:.1f}", xy=(x_best, m_best), xytext=(8, 8),
                textcoords="offset points", fontsize=9,
                arrowprops=dict(arrowstyle="->", lw=1))

ax.set_xlabel("Trajectory index (across all epochs)")
ax.set_ylabel("Return (sum of rewards)")
ax.set_title("Per-trajectory returns (blue) with per-epoch mean (red step)")
ax.legend()  # built from the label= arguments above
fig.tight_layout()
plt.show()