In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import matplotlib.pyplot as plt

### Policy Gradient Theorem Derivation

The goal is to maximize the expected discounted return of the policy $\pi_\theta$:
$$
J(\theta) = \mathbb{E}_{\tau \sim p_\theta}\left[\sum_{t=0}^{\infty} \gamma^t\,r(s_t, a_t)\right]
$$

We can write:
$$
J(\theta) = \int p_\theta(\tau)\left[\sum_{t=0}^{\infty} \gamma^t\,r(s_t,a_t)\right]\,d\tau.
$$

Taking the gradient with respect to $\theta$:
$$
\nabla_\theta J(\theta) = \nabla_\theta \int p_\theta(\tau)\left[\sum_{t=0}^{\infty}\gamma^t\,r(s_t,a_t)\right]d\tau.
$$
Assuming we can interchange the gradient and the integral:
$$
\nabla_\theta J(\theta) = \int \nabla_\theta p_\theta(\tau)\left[\sum_{t=0}^{\infty}\gamma^t\,r(s_t,a_t)\right]d\tau.
$$

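Applying the log-derivative (likelihood-ratio) trick, $\nabla_\theta p_\theta(\tau) = p_\theta(\tau)\,\nabla_\theta \log p_\theta(\tau)$, gives: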
$$
\nabla_\theta J(\theta) = \int p_\theta(\tau)\,\nabla_\theta \log p_\theta(\tau) \left[\sum_{t=0}^{\infty}\gamma^t\,r(s_t,a_t)\right]d\tau.
$$
In expectation notation:
$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim p_\theta}\left[\nabla_\theta \log p_\theta(\tau) \,\sum_{t=0}^{\infty}\gamma^t\,r(s_t,a_t)\right].
$$


Since the trajectory probability factorizes as
$$
p_\theta(\tau) = p(s_0)\prod_{t=0}^{\infty} \left[\pi_\theta(a_t \mid s_t) \,P(s_{t+1} \mid s_t,a_t)\right],
$$
its logarithm is
$$
\log p_\theta(\tau) = \log p(s_0) + \sum_{t=0}^{\infty}\Bigl[\log \pi_\theta(a_t \mid s_t) + \log P(s_{t+1} \mid s_t,a_t)\Bigr].
$$
Only the terms $\log \pi_\theta(a_t \mid s_t)$ depend on $\theta$. Therefore,
$$
\nabla_\theta \log p_\theta(\tau) = \sum_{t=0}^{\infty}\nabla_\theta \log \pi_\theta(a_t \mid s_t).
$$
Substitute back into our gradient expression:
$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim p_\theta}\left[\left(\sum_{t=0}^{\infty}\nabla_\theta \log \pi_\theta(a_t \mid s_t)\right)\left(\sum_{t=0}^{\infty}\gamma^t\,r(s_t,a_t)\right)\right].
$$

For each time step t, decompose the total return as:
$$
\sum_{m=0}^{\infty}\gamma^m\,r(s_m,a_m)
=\underbrace{\sum_{m=0}^{t-1}\gamma^m\,r(s_m,a_m)}_{\text{past (independent of \(a_t\))}}
+\underbrace{\sum_{m=t}^{\infty}\gamma^m\,r(s_m,a_m)}_{\text{future (dependent on \(a_t\))}}.
$$
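Since the past rewards do not depend on $a_t$, and $\mathbb{E}_{a_t \sim \pi_\theta(\cdot \mid s_t)}\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\right] = 0$, the product of the past portion with $\nabla_\theta \log \pi_\theta(a_t\mid s_t)$ vanishes in expectation. Hence, we have: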
$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim p_\theta}\left[\sum_{t=0}^{\infty}\nabla_\theta \log \pi_\theta(a_t \mid s_t)\left(\sum_{m=t}^{\infty}\gamma^m\,r(s_m,a_m)\right)\right].
$$

Since the Q function is:
$$
Q^\pi(s_t,a_t) = \mathbb{E}\left[\sum_{k=0}^{\infty}\gamma^k\,r(s_{t+k},a_{t+k}) \,\Bigm|\, s_t,a_t\right].
$$
Factoring out $\gamma^t$, the partial sum $\sum_{m=t}^{\infty}\gamma^m\,r(s_m,a_m) = \gamma^t \sum_{k=0}^{\infty}\gamma^k\,r(s_{t+k},a_{t+k})$ is $\gamma^t$ times an unbiased sample of $Q^\pi(s_t,a_t)$. Therefore,
$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim p_\theta}\left[\sum_{t=0}^{\infty}\gamma^t\,\nabla_\theta \log \pi_\theta(a_t \mid s_t)\,Q^\pi(s_t,a_t)\right].
$$


Each trajectory $\tau$ consists of a sequence of state–action pairs $(s_0,a_0), (s_1,a_1), \ldots$. The sum over time steps, with step $t$ weighted by $\gamma^t$, is equivalent (up to the constant factor $\frac{1}{1-\gamma}$, which can be absorbed into the step size) to taking an expectation with respect to the **discounted occupancy measure** $\mu_\pi(s,a)$, which represents the (normalized) distribution of state–action pairs visited by $\pi_\theta$. Thus, we can write:
$$
\nabla_\theta J(\theta) = \mathbb{E}_{(s,a) \sim \mu_\pi}\left[\nabla_\theta \log \pi_\theta(a \mid s)\,Q^\pi(s,a)\right].
$$

---

**Final Policy Gradient Theorem**

$$
\boxed{
\nabla_\theta J(\theta) = \mathbb{E}_{(s,a) \sim \mu_\pi}\left[\nabla_\theta \log \pi_\theta(a \mid s)\,Q^\pi(s,a)\right].
}
$$
A common variance-reduced version uses the advantage function $A^\pi(s,a)=Q^\pi(s,a)-V^\pi(s)$:
$$
\boxed{
\nabla_\theta J(\theta) = \mathbb{E}_{(s,a) \sim \mu_\pi}\left[\nabla_\theta \log \pi_\theta(a \mid s)\,A^\pi(s,a)\right].
}
$$


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fix the PyTorch and NumPy RNG seeds so repeated runs draw the same random numbers.
# NOTE(review): the environment itself is not seeded in this cell.
torch.manual_seed(42)
np.random.seed(42)

# ---------------------
# Environment: CartPole-v1
# ---------------------
env_name = "CartPole-v1"
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]  # 4 for CartPole
action_dim = env.action_space.n  # number of discrete actions

def train_online(agent, env, episodes=500, method='REINFORCE'):
    """Train ``agent`` by interacting with ``env`` and return the reward of every episode.

    Update schedule, selected by ``method``:
      * 'REINFORCE', 'NPG', 'TRPO': one ``agent.update`` call at the end of each episode.
      * 'A2C': ``agent.update_td`` every 5 environment steps and at episode end.
      * 'PPO': ``agent.update_td`` on the transitions gathered over every 5 episodes.

    Each transition is the tuple
    ``(state, action, log_prob, reward, value, next_state, episode_over)``;
    ``value`` is ``None`` for the methods whose ``select_action`` returns two items.

    Args:
        agent: Object exposing ``select_action`` plus the update method used by ``method``.
        env: Environment whose ``reset()`` returns ``(state, info)`` and whose ``step()``
            returns ``(next_state, reward, terminated, truncated, info)``.
        episodes: Number of episodes to run.
        method: One of 'REINFORCE', 'A2C', 'PPO', 'NPG', 'TRPO'.

    Returns:
        List with the undiscounted total reward of each episode.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method in ('REINFORCE', 'NPG', 'TRPO'):
        update_freq = 'episode'
    elif method == 'A2C':
        update_freq = 'n_steps'
    elif method == 'PPO':
        update_freq = 'n_episodes'
    else:
        raise ValueError(f"Unknown method: {method!r}")

    n_steps = 5      # A2C: update every 5 steps
    n_episodes = 5   # PPO: update after 5 episodes
    uses_critic = method in ('A2C', 'PPO')

    rewards_per_episode = []
    # Cross-episode buffer; only filled for the buffered schedules so that the
    # per-episode methods do not keep every transition (and its autograd graph) alive.
    buffer = []
    steps_since_update = 0
    episodes_since_update = 0

    for ep in range(episodes):
        state, _ = env.reset()
        episode_traj = []
        total_reward = 0
        episode_over = False

        while not episode_over:
            if uses_critic:
                action, log_prob, value = agent.select_action(state)
            else:
                action, log_prob = agent.select_action(state)
                value = None
            next_state, reward, terminated, truncated, _ = env.step(action)
            # An episode ends on termination OR truncation (time limit); looping
            # on `terminated` alone keeps stepping a finished environment.
            episode_over = terminated or truncated
            transition = (state, action, log_prob, reward, value, next_state, episode_over)

            if update_freq == 'episode':
                episode_traj.append(transition)
            else:
                buffer.append(transition)
            total_reward += reward
            state = next_state

            if update_freq == 'n_steps':
                steps_since_update += 1
                if steps_since_update >= n_steps or episode_over:
                    agent.update_td(buffer)  # Use TD targets
                    buffer = []
                    steps_since_update = 0

        rewards_per_episode.append(total_reward)

        if update_freq == 'episode':
            rewards = [t[3] for t in episode_traj]
            log_probs = [t[2] for t in episode_traj]
            if method == 'REINFORCE':
                agent.update(rewards, log_probs)
            elif method == 'NPG':
                # The current episode doubles as the data set for the Fisher estimate.
                agent.update(rewards, log_probs, [episode_traj])
            else:  # TRPO
                agent.update([episode_traj], [log_probs])
        elif update_freq == 'n_episodes':
            episodes_since_update += 1
            if episodes_since_update >= n_episodes:
                agent.update_td(buffer)  # Use TD targets
                buffer = []
                episodes_since_update = 0

        if (ep+1) % 50 == 0:
            print(f"Online Episode {ep+1}, Total Reward: {total_reward}")

    return rewards_per_episode

def collect_trajectories(agent, env, num_episodes=50, method='REINFORCE'):
    """Roll out ``agent`` in ``env`` without updating it and return the episodes.

    Args:
        agent: Object whose ``select_action(state)`` returns ``(action, log_prob)``
            for 'REINFORCE'/'NPG'/'TRPO' and ``(action, log_prob, value)`` for
            'A2C'/'PPO'.
        env: Environment whose ``reset()`` returns ``(state, info)`` and whose ``step()``
            returns ``(next_state, reward, terminated, truncated, info)``.
        num_episodes: Number of episodes to collect.
        method: One of 'REINFORCE', 'A2C', 'PPO', 'NPG', 'TRPO'.

    Returns:
        List of episodes; each episode is a list of
        ``(state, action, log_prob, reward, value, next_state, episode_over)``
        tuples, with ``value`` set to ``None`` for the critic-free methods.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method in ('REINFORCE', 'NPG', 'TRPO'):
        uses_critic = False
    elif method in ('A2C', 'PPO'):
        uses_critic = True
    else:
        raise ValueError(f"Unknown method: {method!r}")

    trajectories = []
    for _ in range(num_episodes):
        state, _ = env.reset()
        traj = []
        episode_over = False
        while not episode_over:
            if uses_critic:
                action, log_prob, value = agent.select_action(state)
            else:
                action, log_prob = agent.select_action(state)
                value = None
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Stop on truncation as well as termination so a time-limited
            # environment is not stepped after its episode has ended.
            episode_over = terminated or truncated
            traj.append((state, action, log_prob, reward, value, next_state, episode_over))
            state = next_state
        trajectories.append(traj)
    return trajectories

def train_offline(agent, offline_data, epochs=10, method='REINFORCE'):
    """Run ``epochs`` passes of agent updates over pre-collected trajectories.

    Every trajectory is a list of
    ``(state, action, log_prob, reward, value, next_state, done)`` tuples. For each
    trajectory the update entry point matching ``method`` is called:
      * 'REINFORCE': ``agent.update(rewards, log_probs)``
      * 'A2C':       ``agent.update(traj)``
      * 'NPG':       ``agent.update(rewards, log_probs, offline_data)``
      * 'TRPO':      ``agent.update([traj], [log_probs])``
      * 'PPO':       ``agent.update([traj])``

    Args:
        agent: Agent providing the ``update`` signature listed above.
        offline_data: Non-empty list of trajectories.
        epochs: Number of passes over ``offline_data``.
        method: One of 'REINFORCE', 'A2C', 'NPG', 'TRPO', 'PPO'.

    Returns:
        One value per epoch: the average undiscounted reward per trajectory in the
        data set. The value reflects the stored rewards, not new rollouts.

    Raises:
        ValueError: If ``method`` is unknown or ``offline_data`` is empty.
    """
    if method not in ('REINFORCE', 'A2C', 'NPG', 'TRPO', 'PPO'):
        raise ValueError(f"Unknown method: {method!r}")
    if not offline_data:
        raise ValueError("offline_data must contain at least one trajectory")

    rewards_list = []
    for epoch in range(epochs):
        total_reward = 0
        for traj in offline_data:
            rewards = [entry[3] for entry in traj]
            # Accumulated for every method (the NPG branch used to skip this and
            # therefore always reported an average of 0).
            total_reward += sum(rewards)
            if method == 'A2C':
                agent.update(traj)
            elif method == 'PPO':
                agent.update([traj])
            else:
                log_probs = [entry[2] for entry in traj]
                if method == 'REINFORCE':
                    agent.update(rewards, log_probs)
                elif method == 'NPG':
                    # The entire offline data set is handed over for the Fisher estimate.
                    agent.update(rewards, log_probs, offline_data)
                else:  # TRPO
                    agent.update([traj], [log_probs])
        rewards_list.append(total_reward / len(offline_data))
        print(f"Offline Epoch {epoch+1}, Average Total Reward per Trajectory: {rewards_list[-1]:.2f}")
    return rewards_list

---

### REINFORCE


**Objective:**

The expected return is given by

$$
J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T-1}\gamma^t\,r(s_t,a_t)\right],
$$

where $\tau = (s_0,a_0,\dots,s_{T-1},a_{T-1})$ is a complete episode.

**Policy Gradient Theorem (using Monte Carlo return):**

The gradient is

$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t\mid s_t) \, G_t\right],
$$

with the return

$$
G_t = \sum_{k=t}^{T-1}\gamma^{k-t}\,r(s_k,a_k).
$$

**Update Rule:**

For each episode, update

$$
\theta \leftarrow \theta + \alpha \sum_{t=0}^{T-1} G_t \, \nabla_\theta \log \pi_\theta(a_t\mid s_t).
$$

In [5]:

# Policy network for REINFORCE (outputs action probabilities)
class PolicyNetwork(nn.Module):
    """One-hidden-layer MLP mapping a state to a probability vector over actions.

    Args:
        state_dim: Size of the input state vector.
        action_dim: Number of discrete actions (size of the output distribution).
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc_policy = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return softmax action probabilities over the last dimension of ``x``."""
        # torch.relu applies the activation; `torch.nn.ReLU(tensor)` would instead
        # construct a module object and the next layer would fail on it.
        hidden = torch.relu(self.fc1(x))
        logits = self.fc_policy(hidden)
        return torch.softmax(logits, dim=-1)

class REINFORCEAgent:
    """Monte Carlo policy-gradient (REINFORCE) agent with a softmax policy network.

    Args:
        state_dim: Size of the state vector.
        action_dim: Number of discrete actions.
        lr: Adam learning rate for the policy network.
        gamma: Discount factor used for the returns.
    """

    def __init__(self, state_dim, action_dim, lr=1e-2, gamma=0.99):
        self.policy_net = PolicyNetwork(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.gamma = gamma

    def select_action(self, state):
        """Sample an action for ``state``.

        Returns:
            ``(action, log_prob)`` where ``action`` is a Python int and ``log_prob``
            is a tensor that stays attached to the policy's autograd graph.
        """
        net_device = next(self.policy_net.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=net_device).unsqueeze(0)
        dist = Categorical(self.policy_net(state))
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def update(self, rewards, log_probs):
        """Take one gradient step on ``-sum_t G_t * log pi(a_t | s_t)`` for one episode.

        Args:
            rewards: Per-step rewards of the episode.
            log_probs: Per-step log-probability tensors returned by ``select_action``.

        Raises:
            ValueError: If ``rewards`` and ``log_probs`` differ in length.
        """
        if len(rewards) != len(log_probs):
            raise ValueError("rewards and log_probs must have the same length")
        if not log_probs:
            return  # nothing to learn from an empty episode

        # Rewards-to-go: G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
        returns = []
        running = 0.0
        for r in reversed(rewards):
            running = r + self.gamma * running
            returns.append(running)
        returns.reverse()

        log_prob_t = torch.cat([lp.reshape(-1) for lp in log_probs])
        returns = torch.tensor(returns, dtype=torch.float, device=log_prob_t.device)
        # Optional normalization for stability. A single return has an undefined
        # (NaN) sample std, which would turn every weight into NaN, so it is skipped.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # Minimizing the negated objective performs gradient ascent on expected return.
        loss = -(log_prob_t * returns).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [None]:
# Train a fresh REINFORCE agent online for 500 episodes.
agent = REINFORCEAgent(state_dim, action_dim) 
online_rewards = train_online(agent, env, episodes=500, method="REINFORCE")
# Plot the reward list returned by train_online (one entry per episode).
plt.figure(figsize=(12,5))
plt.plot(online_rewards, label="Online Rewards")
plt.xlabel("Episode / Epoch")
plt.ylabel("Reward")
plt.legend()
plt.show()

### Actor Critic (A2C)

**Architecture:**

We have an actor (policy) and a critic (value function). The critic estimates

$$
V^\pi(s) \approx \mathbb{E}\left[G_t \mid s_t = s\right].
$$

**Advantage Estimate:**

A common choice is the one-step temporal difference (TD) error:

$$
A_t = r(s_t,a_t) + \gamma V^\pi(s_{t+1}) - V^\pi(s_t).
$$

**Loss Functions:**

- **Actor Loss:**

$$
L_{\text{actor}}(\theta) = -\mathbb{E}\left[\log \pi_\theta(a_t\mid s_t)\,A_t\right].
$$

- **Critic Loss:**

$$
L_{\text{critic}}(\phi) = \frac{1}{2}\mathbb{E}\left[\left(r(s_t,a_t) + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)\right)^2\right].
$$

**Update Rules:**

- **Actor Update:**

$$
\theta \leftarrow \theta + \alpha\,\nabla_\theta \log \pi_\theta(a_t\mid s_t) \, A_t.
$$

- **Critic Update:**

$$
\phi \leftarrow \phi - \beta\,\nabla_\phi L_{\text{critic}}(\phi).
$$


In [None]:
# Actor-Critic network for A2C/PPO (shared encoder, two heads: policy & value)
class ActorCriticNetwork(nn.Module):
    """Shared hidden layer feeding a softmax policy head and a scalar value head.

    Args:
        state_dim: Size of the input state vector.
        action_dim: Number of discrete actions.
        hidden_dim: Width of the shared hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(ActorCriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc_policy = nn.Linear(hidden_dim, action_dim)
        self.fc_value = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(policy, value)``: action probabilities and the value-head output."""
        # torch.relu applies the activation; calling the nn.ReLU class on a tensor
        # would build a module object instead of transforming the features.
        hidden = torch.relu(self.fc1(x))
        policy = torch.softmax(self.fc_policy(hidden), dim=-1)
        value = self.fc_value(hidden)
        return policy, value

class A2CAgent:
    """Advantage actor-critic agent built on a shared-encoder actor-critic network.

    The actor optimizer owns the shared encoder (``fc1``) and the policy head; the
    critic optimizer owns the value head. Both losses are back-propagated in a single
    pass, so the encoder receives gradients from the actor and the critic.

    Args:
        state_dim: Size of the state vector.
        action_dim: Number of discrete actions.
        actor_lr: Adam learning rate for the encoder and the policy head.
        critic_lr: Adam learning rate for the value head.
        gamma: Discount factor.
    """

    def __init__(self, state_dim, action_dim, actor_lr=1e-3, critic_lr=1e-3, gamma=0.99):
        self.ac_net = ActorCriticNetwork(state_dim, action_dim).to(device)
        # The encoder must belong to an optimizer, otherwise it stays at its random
        # initialization for the whole run.
        actor_params = list(self.ac_net.fc1.parameters()) + list(self.ac_net.fc_policy.parameters())
        self.actor_optimizer = optim.Adam(actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.ac_net.fc_value.parameters(), lr=critic_lr)
        self.gamma = gamma

    def _device(self):
        """Return the device the network parameters live on."""
        return next(self.ac_net.parameters()).device

    @staticmethod
    def _normalize(x):
        """Standardize ``x``; a single element is returned as-is (its std is NaN)."""
        if x.numel() < 2:
            return x
        return (x - x.mean()) / (x.std() + 1e-8)

    def _apply(self, actor_loss, critic_loss):
        """Back-propagate both losses once and step both optimizers."""
        # One backward pass over the summed loss: the two losses share the encoder
        # graph, so two separate backward() calls would hit already-freed buffers.
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        (actor_loss + critic_loss).backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()

    def select_action(self, state):
        """Sample an action; return ``(action, log_prob, value)``.

        ``log_prob`` and ``value`` are tensors still attached to the autograd graph.
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=self._device()).unsqueeze(0)
        policy, value = self.ac_net(state)
        dist = Categorical(policy)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value

    def update(self, trajectory):
        """Monte Carlo update from one full episode.

        Args:
            trajectory: List of ``(state, action, log_prob, reward, value, next_state,
                done)`` tuples. The stored ``log_prob``/``value`` tensors carry the
                autograd graph, so a trajectory can be consumed only once.
        """
        if not trajectory:
            return
        returns = []
        running = 0.0
        for step in reversed(trajectory):
            running = step[3] + self.gamma * running
            returns.append(running)
        returns.reverse()
        returns = self._normalize(torch.tensor(returns, dtype=torch.float, device=self._device()))

        log_probs = torch.cat([step[2].reshape(-1) for step in trajectory])
        values = torch.cat([step[4].reshape(-1) for step in trajectory])

        # The advantage is a constant for the actor; the critic regresses onto the returns.
        advantages = returns - values.detach()
        actor_loss = -(log_probs * advantages).sum()
        critic_loss = nn.functional.mse_loss(values, returns, reduction='sum')
        self._apply(actor_loss, critic_loss)

    def update_td(self, transitions):
        """One-step TD update from a batch of consecutive transitions.

        Args:
            transitions: List of ``(state, action, log_prob, reward, value, next_state,
                done)`` tuples; ``done`` suppresses bootstrapping from ``next_state``.
        """
        if not transitions:
            return
        dev = self._device()
        # Flatten to shape [N]: stacking the per-step [1]-shaped log-probs gives
        # [N, 1], which would broadcast against the [N] advantages into [N, N].
        log_probs = torch.cat([t[2].reshape(-1) for t in transitions])
        values = torch.cat([t[4].reshape(-1) for t in transitions])
        rewards = torch.tensor([t[3] for t in transitions], dtype=torch.float, device=dev)
        dones = torch.tensor([float(t[6]) for t in transitions], dtype=torch.float, device=dev)
        next_states = torch.as_tensor(
            np.asarray([t[5] for t in transitions], dtype=np.float32), device=dev)

        # TD targets: r + gamma * V(s'), with V(s') zeroed where the episode ended.
        with torch.no_grad():
            next_values = self.ac_net(next_states)[1].reshape(-1) * (1.0 - dones)
        td_targets = rewards + self.gamma * next_values

        # Advantage = TD error, normalized when there is more than one sample.
        advantages = self._normalize(td_targets - values.detach())

        policy_loss = -(log_probs * advantages).mean()
        value_loss = nn.functional.mse_loss(values, td_targets)
        self._apply(policy_loss, value_loss)


### Natural Policy Gradient (NPG)

**Key Idea:**

The standard gradient is preconditioned by the inverse Fisher information matrix $F(\theta)$ to obtain the natural gradient:

$$
\Delta \theta = F(\theta)^{-1} \, \nabla_\theta J(\theta),
$$

where

$$
F(\theta) = \mathbb{E}_{s \sim d^\pi,\, a \sim \pi_\theta}\left[\nabla_\theta \log \pi_\theta(a\mid s) \, \nabla_\theta \log \pi_\theta(a\mid s)^\top\right].
$$

**Update Rule:**

$$
\theta \leftarrow \theta + \alpha \, \Delta \theta = \theta + \alpha \, F(\theta)^{-1} \, \nabla_\theta J(\theta).
$$

In [None]:
class NPGAgent:
    """Natural policy gradient agent: steps along ``(F + damping * I)^{-1} g``.

    ``g`` is the REINFORCE policy gradient and ``F`` an empirical Fisher matrix.

    Args:
        state_dim: Size of the state vector.
        action_dim: Number of discrete actions.
        lr: Step size applied to the natural gradient.
        gamma: Discount factor used for the returns.
        damping: Multiple of the identity added to the Fisher estimate so that the
            linear system stays well-posed when few samples are available.
    """

    def __init__(self, state_dim, action_dim, lr=1e-2, gamma=0.99, damping=1e-1):
        self.policy_net = PolicyNetwork(state_dim, action_dim).to(device)
        # Kept for interface compatibility; parameters are updated manually in update().
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.lr = lr
        self.gamma = gamma
        self.damping = damping

    def _device(self):
        """Return the device the policy parameters live on."""
        return next(self.policy_net.parameters()).device

    def select_action(self, state):
        """Sample an action; return ``(action, log_prob)`` with ``log_prob`` a graph-attached tensor."""
        state = torch.as_tensor(state, dtype=torch.float32, device=self._device()).unsqueeze(0)
        dist = Categorical(self.policy_net(state))
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def compute_gradient(self, rewards, log_probs):
        """Return the flattened gradient of the loss ``-sum_t G_t * log pi(a_t | s_t)``.

        The result points in the direction of DECREASING expected return; negate it
        to obtain the policy gradient (ascent direction).
        """
        returns = []
        running = 0.0
        for r in reversed(rewards):
            running = r + self.gamma * running
            returns.append(running)
        returns.reverse()

        log_prob_t = torch.cat([lp.reshape(-1) for lp in log_probs])
        returns = torch.tensor(returns, dtype=torch.float, device=log_prob_t.device)
        # A single return has a NaN sample std, so normalization needs >= 2 steps.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        loss = -(log_prob_t * returns).sum()
        grads = torch.autograd.grad(loss, list(self.policy_net.parameters()))
        return torch.cat([g.reshape(-1) for g in grads]).detach()

    def estimate_fisher(self, data, num_samples=10):
        """Estimate the Fisher matrix as the mean outer product of score vectors.

        Args:
            data: List of trajectories of ``(state, action, ...)`` tuples.
            num_samples: Number of (state, action) pairs drawn, with replacement,
                uniformly over all time steps in ``data``.

        Returns:
            Tensor of shape ``[D, D]`` with ``D`` the number of policy parameters.

        Raises:
            ValueError: If ``data`` holds no transitions.
        """
        steps = [step for traj in data for step in traj]
        if not steps:
            raise ValueError("data must contain at least one transition")
        params = list(self.policy_net.parameters())
        dev = self._device()

        scores = []
        # Drawing from every time step (not only the first of each trajectory) keeps
        # the samples from collapsing onto a single state.
        for idx in np.random.randint(len(steps), size=num_samples):
            s, a = steps[idx][0], steps[idx][1]
            s = torch.as_tensor(s, dtype=torch.float32, device=dev).unsqueeze(0)
            dist = Categorical(self.policy_net(s))
            log_prob = dist.log_prob(torch.as_tensor(a, device=dev))
            grads = torch.autograd.grad(log_prob.sum(), params)
            scores.append(torch.cat([g.reshape(-1) for g in grads]))
        scores = torch.stack(scores)  # [num_samples, D]
        return (scores.t() @ scores / num_samples).detach()

    def update(self, rewards, log_probs, data):
        """Apply one damped natural-gradient ascent step.

        Args:
            rewards: Per-step rewards of the episode.
            log_probs: Matching graph-attached log-probabilities from ``select_action``.
            data: Trajectories used for the Fisher estimate.
        """
        if not log_probs:
            return
        # compute_gradient differentiates the negated objective; flip the sign so the
        # step increases expected return.
        g = -self.compute_gradient(rewards, log_probs)
        fisher = self.estimate_fisher(data)
        # With far fewer samples than parameters the raw estimate is singular, so
        # damping is what makes the solve well-defined.
        eye = torch.eye(fisher.shape[0], dtype=fisher.dtype, device=fisher.device)
        natural_grad = torch.linalg.solve(fisher + self.damping * eye, g)

        offset = 0
        with torch.no_grad():
            for p in self.policy_net.parameters():
                numel = p.numel()
                p.add_(self.lr * natural_grad[offset:offset + numel].view_as(p))
                offset += numel



### Trust Region Policy Optimization (TRPO)

The goal of reinforcement learning is to maximize the expected return:

$$
J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]=\eta(\pi_{\theta})
$$
Consider the expected return under the new policy $\pi_{\theta'}$:

$$
J(\theta') = \mathbb{E}_{\tau \sim \pi_{\theta'}}[R(\tau)]
$$

Using importance sampling with respect to the old policy $\pi_{\theta_{\text{old}}}$, we rewrite:

$$
J(\theta') = \mathbb{E}_{\tau \sim \pi_{\theta_{\text{old}}}}\left[\frac{\pi_{\theta'}(\tau)}{\pi_{\theta_{\text{old}}}(\tau)}R(\tau)\right]
$$

However, directly computing $\frac{\pi_{\theta'}(\tau)}{\pi_{\theta_{\text{old}}}(\tau)}$ over entire trajectories is challenging. Therefore, TRPO approximates this by using single-step importance sampling and the advantage function $A^{\pi_{\theta_{\text{old}}}}(s,a)$:

$$
J(\theta') \approx \eta(\pi_{\theta_{\text{old}}}) + \sum_s d^{\pi_{\theta_{\text{old}}}}(s)\sum_a \pi_{\theta'}(a|s)A^{\pi_{\theta_{\text{old}}}}(s,a)
$$

Here, we define the surrogate objective clearly as:

$$
L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta'}) = \eta(\pi_{\theta_{\text{old}}}) + \sum_s d^{\pi_{\theta_{\text{old}}}}(s)\sum_a \pi_{\theta'}(a|s)A^{\pi_{\theta_{\text{old}}}}(s,a)
$$

Recall the **Conservative Policy Iteration (CPI)**-style lower bound on the true return of the new policy, where $\epsilon = \max_{s,a}\lvert A^{\pi_{\theta_{\text{old}}}}(s,a)\rvert$:
$$
\eta(\pi_{\theta'}) \geq \eta(\pi_{\theta_{\text{old}}}) + \sum_{s} d^{\pi_{\theta_{\text{old}}}}(s)\sum_{a}\pi_{\theta'}(a|s)A^{\pi_{\theta_{\text{old}}}}(s,a) - \frac{2\gamma\epsilon}{(1-\gamma)^2} D_{\text{KL}}^{\text{max}}(\pi_{\theta_{\text{old}}}\|\pi_{\theta'})
$$
The CPI bound above is theoretically insightful but practically restrictive due to the term $D_{\text{KL}}^{\text{max}}(\pi_{\theta_{\text{old}}}\|\pi_{\theta'})$, which requires bounding the divergence at **every state**. 

In practice, we replace the "max KL divergence" with an **average KL divergence** to simplify computation:

$$
D_{\text{KL}}^{\text{max}}(\pi_{\theta_{\text{old}}}\|\pi_{\theta'}) \approx \mathbb{E}_{s \sim d^{\pi_{\theta_{\text{old}}}}}[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(\cdot|s)\|\pi_{\theta'}(\cdot|s))]
$$

This leads to a more manageable (now heuristic, no longer strictly guaranteed) lower bound on the return of the new policy:

$$
\eta(\pi_{\theta'}) \gtrsim \eta(\pi_{\theta_{\text{old}}}) + \sum_{s} d^{\pi_{\theta_{\text{old}}}}(s)\sum_{a}\pi_{\theta'}(a|s)A^{\pi_{\theta_{\text{old}}}}(s,a) - \frac{2\gamma\epsilon}{(1-\gamma)^2}\cdot\mathbb{E}_{s \sim d^{\pi_{\theta_{\text{old}}}}}[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(\cdot|s)\|\pi_{\theta'}(\cdot|s))]
$$
To ensure policy improvement, we aim to maximize the lower bound above. Equivalently, we can pose this as a constrained optimization problem. Specifically, we seek a new policy $\pi'$ that maximizes (since $\eta(\pi_{\theta_{\text{old}}})$ is constant):

$$
L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta'}) = \sum_{s} d^{\pi_{\theta_{\text{old}}}}(s)\sum_{a}\pi_{\theta'}(a|s)A^{\pi_{\theta_{\text{old}}}}(s,a)=\mathbb{E}_{s,a \sim \pi_{\theta_{\text{old}}}} \left[ \frac{\pi_{\theta'}(a|s)}{\pi_{\theta_{old}}(a|s)}A^{\pi_{\theta_{\text{old}}}}(s,a)\right]
$$

subject to a constraint on the KL divergence between the old policy and the new policy:

$$
\mathbb{E}_{s \sim d^{\pi_{\theta_{\text{old}}}}}[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(\cdot|s)\|\pi_{\theta'}(\cdot|s))] \leq \delta
$$

Here, $\delta$ is a hyperparameter chosen to control how aggressively the policy can change in each update.

Thus, the optimization becomes explicitly:

$$
\max_{\pi_{\theta'}} L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta'}) \quad\text{s.t.}\quad \mathbb{E}_{s \sim d^{\pi_{\theta_{\text{old}}}}}[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(\cdot|s)\|\pi_{\theta'}(\cdot|s))] \leq \delta
$$

We approximate the surrogate objective around $\theta_{\text{old}}$ with a first-order (linear) Taylor expansion:

$$
L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta'}) \approx L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta_{\text{old}}}) + \nabla_{\theta} L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta})_{\theta =\theta_{old}}^\top (\theta' - \theta_{\text{old}})
$$

For the KL divergence constraint, we apply a second-order (quadratic) Taylor expansion around $\theta_{\text{old}}$:

$$
\mathbb{E}_{s \sim d^{\pi_{\theta_{\text{old}}}}}\left[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(\cdot|s)||\pi_{\theta'}(\cdot|s))\right] \approx \frac{1}{2}(\theta' - \theta_{\text{old}})^\top F(\theta_{\text{old}})(\theta' - \theta_{\text{old}})
$$

where the Fisher Information Matrix (FIM), $F(\theta_{\text{old}})$, is defined as:

$$
F(\theta_{\text{old}}) = \mathbb{E}_{s,a \sim \pi_{\theta_{\text{old}}}}\left[\nabla_{\theta}\log\pi_{\theta}(a|s)\nabla_{\theta}\log\pi_{\theta}(a|s)^\top\right]\Big|_{\theta=\theta_{\text{old}}}
$$

Combining the above approximations yields a simpler constrained optimization problem:

$$
\max_{\theta'} \nabla_{\theta} L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta})_{\theta =\theta_{old}}^\top(\theta' - \theta_{\text{old}}) \quad \text{subject to} \quad \frac{1}{2}(\theta' - \theta_{\text{old}})^\top F(\theta_{\text{old}})(\theta' - \theta_{\text{old}}) \leq \delta
$$

Using the method of Lagrange multipliers, we derive the optimal policy update step explicitly as:

$$
\theta_{\text{new}} = \theta_{\text{old}} + \sqrt{\frac{2\delta}{g^\top F^{-1} g}} F^{-1} g
$$

where:

- $g = \nabla_{\theta} L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta})\big|_{\theta =\theta_{\text{old}}}$ is the policy gradient evaluated at $\theta_{\text{old}}$, obtained with the log-derivative trick.
- $F$ is the Fisher Information Matrix evaluated at $\theta_{\text{old}}$.

In practice, directly computing the inverse $F^{-1}$ is computationally expensive. Thus, TRPO uses the conjugate gradient method to approximate the product $F^{-1} g$ efficiently without explicitly computing the inverse.

---
**TRPO Algorithm**

**Given:** initial policy parameters $\theta_0$, KL-divergence constraint parameter $\delta$

**for** iteration $k=0,1,2,\dots$ **do**:

1. **Collect trajectories** by executing policy $\pi_{\theta_k}(a|s)$.

2. **Compute advantages** $A^{\pi_{\theta_k}}(s,a)$ using collected data.

3. **Compute policy gradient:**
$$
g = \mathbb{E}_{s,a \sim \pi_{\theta_k}}\left[\nabla_{\theta'}\log\pi_{\theta'}(a|s)\big|_{\theta'=\theta_k} A^{\pi_{\theta_k}}(s,a)\right]
$$

4. **Estimate Fisher Information Matrix (FIM)**:
$$
F(\theta_k) = \mathbb{E}_{s,a \sim \pi_{\theta_k}}\left[\nabla_{\theta_k}\log\pi_{\theta_k}(a|s)\nabla_{\theta_k}\log\pi_{\theta_k}(a|s)^\top\right]
$$

5. **Compute policy update direction** by approximately solving:
$$
F(\theta_k)x = g
$$
using the **conjugate gradient method**.

6. **Update policy parameters**:
$$
\theta_{k+1} = \theta_k + \sqrt{\frac{2\delta}{g^\top x}}\,x
$$

**end for**  


In [None]:
#Simplified TRPO
class TRPOAgent:
    """Simplified trust-region agent: gradient step plus backtracking on a KL limit.

    This is not full TRPO (no conjugate gradient / natural gradient). It takes a
    plain gradient step on the importance-weighted surrogate and halves the step
    until the mean KL to the pre-update policy is within ``max_kl`` and the
    surrogate loss has not increased.

    Args:
        state_dim: Size of the state vector.
        action_dim: Number of discrete actions.
        gamma: Discount factor used for the returns.
        max_kl: Upper limit on the mean KL(old || new) over the visited states.
        step_size: Initial step size tried by the backtracking search.
    """

    def __init__(self, state_dim, action_dim, gamma=0.99, max_kl=1e-2, step_size=1e-2):
        self.policy_net = PolicyNetwork(state_dim, action_dim).to(device)
        self.gamma = gamma
        self.max_kl = max_kl
        self.step_size = step_size

    def _device(self):
        """Return the device the policy parameters live on."""
        return next(self.policy_net.parameters()).device

    def select_action(self, state):
        """Sample an action; return ``(action, log_prob)``."""
        state = torch.as_tensor(state, dtype=torch.float32, device=self._device()).unsqueeze(0)
        dist = Categorical(self.policy_net(state))
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def _batch(self, traj):
        """Turn one episode into ``(states, actions, normalized rewards-to-go)`` tensors."""
        dev = self._device()
        states = torch.as_tensor(np.asarray([t[0] for t in traj], dtype=np.float32), device=dev)
        actions = torch.as_tensor([t[1] for t in traj], dtype=torch.long, device=dev)
        returns = []
        running = 0.0
        for t in reversed(traj):
            running = t[3] + self.gamma * running
            returns.append(running)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float, device=dev)
        # A single return has a NaN sample std, so normalization needs >= 2 steps.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        return states, actions, returns

    def compute_loss_and_kl(self, trajectories, old_log_probs, old_probs=None):
        """Return ``(surrogate loss, mean KL(old || current))`` over all transitions.

        Args:
            trajectories: List of episodes, each a list of
                ``(state, action, log_prob, reward, ...)`` tuples.
            old_log_probs: Per-episode lists with the behaviour policy's
                log-probability of each taken action.
            old_probs: Optional per-episode tensors with the behaviour policy's full
                action distributions. When omitted, the current policy (detached) is
                used as the reference, which makes the KL zero.

        Raises:
            ValueError: If there are no transitions at all.
        """
        dev = self._device()
        losses, kls = [], []
        for i, (traj, old_lp) in enumerate(zip(trajectories, old_log_probs)):
            if not traj:
                continue
            states, actions, returns = self._batch(traj)
            probs = self.policy_net(states)
            dist = Categorical(probs)
            new_logp = dist.log_prob(actions)
            # Old log-probs are constants; left attached they would cancel the
            # gradient of the ratio at the current parameters.
            old_logp = torch.cat(
                [torch.as_tensor(lp, dtype=torch.float32).reshape(-1) for lp in old_lp]
            ).detach().to(dev)
            ratio = torch.exp(new_logp - old_logp)
            # Per-step reward-to-go serves as the estimate of Q(s_t, a_t).
            losses.append(-ratio * returns)
            reference = probs.detach() if old_probs is None else old_probs[i]
            kls.append(torch.distributions.kl_divergence(Categorical(reference), dist))
        if not losses:
            raise ValueError("trajectories contain no transitions")
        return torch.cat(losses).mean(), torch.cat(kls).mean()

    def update(self, trajectories, old_log_probs):
        """Take a gradient step on the surrogate, backtracking to respect ``max_kl``.

        The parameters are left unchanged when no tried step size satisfies both
        the KL limit and a non-increasing surrogate loss.
        """
        if not any(len(traj) for traj in trajectories):
            return
        params = list(self.policy_net.parameters())
        # Snapshot the pre-update action distributions: the KL constraint is
        # measured against these after each trial step.
        with torch.no_grad():
            old_probs = [
                self.policy_net(self._batch(traj)[0]) if traj else None
                for traj in trajectories
            ]
        loss, _ = self.compute_loss_and_kl(trajectories, old_log_probs, old_probs)
        grads = torch.autograd.grad(loss, params)
        start = [p.detach().clone() for p in params]
        base_loss = loss.item()

        step = self.step_size
        for _ in range(10):
            with torch.no_grad():
                for p, p0, g in zip(params, start, grads):
                    # Minus sign: `loss` is the negative surrogate objective.
                    p.copy_(p0 - step * g)
                new_loss, kl = self.compute_loss_and_kl(trajectories, old_log_probs, old_probs)
            if kl.item() <= self.max_kl and new_loss.item() <= base_loss:
                return
            step *= 0.5

        # No acceptable step found: restore the original parameters.
        with torch.no_grad():
            for p, p0 in zip(params, start):
                p.copy_(p0)
        

class TRPOAgent:
    """Simplified TRPO-style agent with a learned critic and TD(0) advantages.

    The policy is updated with a plain gradient step on the importance-weighted
    surrogate loss, followed by a KL check against the pre-update policy; the
    step is rolled back if the check fails. A full TRPO implementation would
    instead solve ``Fx = g`` with conjugate gradient and run a backtracking
    line search.

    Each trajectory step is a 7-tuple that is indexed as: ``[0]`` state,
    ``[1]`` action, ``[3]`` reward, ``[5]`` next state, ``[6]`` done flag.
    Entries ``[2]`` and ``[4]`` are not read by this class.
    """

    def __init__(self, state_dim, action_dim, gamma=0.99, max_kl=1e-2, step_size=1e-2):
        """Build the policy network, the critic and its optimizer.

        Args:
            state_dim: Size of a single state vector.
            action_dim: Number of discrete actions.
            gamma: Discount factor used in the TD target.
            max_kl: Upper bound on the mean KL(old || new) for an accepted step.
            step_size: Learning rate of the plain gradient step on the policy.
        """
        self.policy_net = PolicyNetwork(state_dim, action_dim).to(device)
        # Critic used to form TD targets and advantages.
        self.value_net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        ).to(device)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=1e-3)
        self.gamma = gamma
        self.max_kl = max_kl
        self.step_size = step_size

    def select_action(self, state):
        """Sample an action for a single state.

        Returns:
            ``(action, log_prob, value)`` where ``action`` is a Python int,
            ``log_prob`` is the log-probability tensor of the sampled action
            and ``value`` is the critic output for the state.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        dist = Categorical(self.policy_net(state))
        action = dist.sample()
        value = self.value_net(state)
        return action.item(), dist.log_prob(action), value

    @staticmethod
    def _float_batch(rows):
        """Stack per-step arrays into one float32 tensor on ``device``."""
        # Going through a single ndarray avoids the slow element-wise
        # conversion torch performs on a list of arrays.
        return torch.as_tensor(np.asarray(rows), dtype=torch.float32).to(device)

    @staticmethod
    def _flat_detached(items):
        """Concatenate scalar-like items into a detached 1-D float tensor."""
        pieces = [
            torch.as_tensor(item, dtype=torch.float32).detach().reshape(-1)
            for item in items
        ]
        return torch.cat(pieces).to(device)

    def compute_loss_and_kl_td(self, trajectories, old_log_probs, old_probs=None):
        """Compute the surrogate policy loss, mean KL and critic loss.

        Args:
            trajectories: Iterable of trajectories (lists of step tuples).
            old_log_probs: For each trajectory, the log-probabilities of the
                taken actions recorded when the data was collected.
            old_probs: Optional list, aligned with ``trajectories``, of action
                probability tensors of the data-collecting policy. When an
                entry (or the whole argument) is ``None`` the current policy
                output, detached, is used as the reference, which is exact as
                long as the policy has not been changed since collection.

        Returns:
            ``(loss_mean, kl_mean, value_loss_mean)``: the policy loss and KL
            are averaged over all steps, the critic loss over trajectories.

        Raises:
            ValueError: If there are no steps to evaluate, or a trajectory and
                its ``old_log_probs`` entry differ in length.
        """
        if old_probs is None:
            old_probs = [None] * len(trajectories)

        policy_losses = []
        kls = []
        value_losses = []

        for traj, old_lp, reference_probs in zip(trajectories, old_log_probs, old_probs):
            if len(traj) == 0:
                continue

            states = self._float_batch([t[0] for t in traj])
            actions = torch.LongTensor([t[1] for t in traj]).to(device)
            rewards = torch.FloatTensor([t[3] for t in traj]).to(device)
            next_states = self._float_batch([t[5] for t in traj])
            dones = torch.FloatTensor([float(t[6]) for t in traj]).to(device)

            # squeeze(-1) keeps a 1-D tensor even for one-step trajectories.
            current_values = self.value_net(states).squeeze(-1)

            with torch.no_grad():
                # Terminal transitions must not bootstrap from the next state.
                next_values = self.value_net(next_states).squeeze(-1) * (1 - dones)
                td_targets = rewards + self.gamma * next_values
                advantages = td_targets - current_values
                # std() of a single element is NaN, so only normalize when
                # there is more than one sample.
                if advantages.numel() > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            value_losses.append(nn.functional.mse_loss(current_values, td_targets))

            probs = self.policy_net(states)
            dist = Categorical(probs)
            new_logp = dist.log_prob(actions)

            # The stored log-probs are constants of the old policy: detach them
            # so the ratio's gradient flows through the new policy only.
            old_logp = self._flat_detached(old_lp)
            if old_logp.shape != new_logp.shape:
                raise ValueError(
                    "old_log_probs entry does not match its trajectory length"
                )

            ratio = torch.exp(new_logp - old_logp)
            policy_losses.append(-ratio * advantages)

            if reference_probs is None:
                reference_probs = probs.detach()
            kls.append(
                torch.distributions.kl_divergence(Categorical(reference_probs), dist)
            )

        if not policy_losses:
            raise ValueError("no trajectory steps to compute a loss from")

        loss_mean = torch.cat(policy_losses).mean()
        kl_mean = torch.cat(kls).mean()
        value_loss_mean = torch.stack(value_losses).mean()

        return loss_mean, kl_mean, value_loss_mean

    def update(self, trajectories, old_log_probs):
        """Update the critic, then take one KL-checked policy step.

        The policy step is a plain gradient-descent step of size
        ``self.step_size`` on the surrogate loss. After the step, the mean
        KL(old || new) is measured on the same data and the parameters are
        restored if it is not strictly below ``self.max_kl``.
        """
        # 1. Critic update on TD targets.
        self.value_optimizer.zero_grad()
        _, _, value_loss = self.compute_loss_and_kl_td(trajectories, old_log_probs)
        value_loss.backward()
        self.value_optimizer.step()

        # 2. Snapshot the pre-update action distributions as the KL reference.
        with torch.no_grad():
            old_probs = [
                self.policy_net(self._float_batch([t[0] for t in traj]))
                if len(traj) > 0 else None
                for traj in trajectories
            ]

        # 3. Gradient of the surrogate loss (advantages use the updated critic).
        loss, _, _ = self.compute_loss_and_kl_td(trajectories, old_log_probs, old_probs)
        self.policy_net.zero_grad()
        loss.backward()

        # 4. Tentative step, then accept or roll back based on the KL bound.
        params = [p for p in self.policy_net.parameters() if p.grad is not None]
        with torch.no_grad():
            backup = [p.clone() for p in params]
            for p in params:
                p -= self.step_size * p.grad

            _, kl, _ = self.compute_loss_and_kl_td(trajectories, old_log_probs, old_probs)
            # ``not (kl < max_kl)`` also rejects a NaN KL.
            if not kl < self.max_kl:
                for p, saved in zip(params, backup):
                    p.copy_(saved)

### Proximal Policy Optimization (PPO)

Trust Region Policy Optimization (TRPO) solves the following constrained optimization problem:

$$
\max_{\theta'} L_{\pi_{\theta_{\text{old}}}}(\pi_{\theta'}) = \mathbb{E}_{s,a\sim\pi_{\theta_{\text{old}}}}\left[\frac{\pi_{\theta'}(a|s)}{\pi_{\theta_{\text{old}}}(a|s)}A^{\pi_{\theta_{\text{old}}}}(s,a)\right]
$$

subject to the KL-divergence constraint:

$$
\mathbb{E}_{s\sim d^{\pi_{\theta_{\text{old}}}}}[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(\cdot|s)||\pi_{\theta'}(\cdot|s))]\leq \delta
$$

TRPO is computationally expensive due to the constraint on KL divergence. PPO simplifies TRPO by using an easier-to-compute penalty mechanism that achieves similar stability without explicitly enforcing KL constraints.

Specifically, PPO introduces the following "surrogate objective" (without explicit constraints):

$$
L^{\text{PPO}}(\theta) = \mathbb{E}_{s,a\sim\pi_{\theta_{\text{old}}}}\left[
r(\theta)A^{\pi_{\theta_{\text{old}}}}(s,a) - \beta D_{\text{KL}}\bigl(\pi_{\theta_{\text{old}}}(\cdot|s)||\pi_{\theta}(\cdot|s)\bigr)
\right]
$$
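Here $r(\theta) = \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{\text{old}}}(a|s)}$ is the probability ratio between the new and the old policy. While the explicit KL penalty term is easy to compute, tuning the coefficient $\beta$ can be tricky. PPO further simplifies this by introducing a "clipped" surrogate objective: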

$$
L^{\text{CLIP}}(\theta) = \mathbb{E}_{s,a\sim\pi_{\theta_{\text{old}}}}\left[
\min\left(
r(\theta)A^{\pi_{\theta_{\text{old}}}}(s,a),\,\text{clip}(r(\theta),1-\epsilon,1+\epsilon)A^{\pi_{\theta_{\text{old}}}}(s,a)
\right)
\right]
$$
This removes the need for explicitly adjusting a penalty coefficient $\beta$, thus making PPO easier to implement and tune.

---

#### How PPO Clipping Acts as an Implicit Penalty

Although PPO doesn’t explicitly add a penalty term like KL divergence, the clipping function implicitly penalizes overly large policy changes. Here's how:

(1) When Advantage $A^{\pi_{\theta_k}}(s,a)>0$ (action better than average):

- If $r(\theta)$ is too large (> $1+\epsilon$), the objective clips and doesn't reward further increases.
- If $r(\theta)$ is too small (<$1-\epsilon$), it naturally reduces the objective since we're multiplying a smaller ratio with a positive advantage.

(2) When Advantage $A^{\pi_{\theta_k}}(s,a)<0$ (action worse than average):

- If $r(\theta)$ is too small (<$1-\epsilon$), the ratio is clipped at $1-\epsilon$. Because the advantage is negative, the clipped term $(1-\epsilon)A$ is more negative than the unclipped term $r(\theta)A$, so the $\min$ selects the clipped term. The objective then no longer depends on $\theta$ (its gradient is zero), which removes any further "benefit" from aggressively reducing the probability of bad actions while still allowing a controlled reduction down to $1-\epsilon$.

- If $r(\theta)$ becomes large (> $1+\epsilon$) for negative advantages, PPO explicitly does not clip upwards. Here, the large ratio multiplied by the negative advantage naturally yields a very negative objective value, strongly discouraging that undesirable policy update.


Thus, clipping effectively serves as an implicit "penalty" mechanism, guiding the policy to update cautiously.

---
**PPO Algorithm**

**Given:** Initial policy parameters $\theta_0$, clipping parameter $\epsilon$ (e.g., 0.2), learning rate $\eta$

**for** iteration $k = 0, 1, 2, \dots$ **do**:

1. Collect trajectories using policy $\pi_{\theta_k}(a|s)$.

2. Compute advantage estimates $A^{\pi_{\theta_k}}(s,a)$.

3. **for** each gradient update epoch **do**:

   - Compute probability ratios:
     $$
     r(\theta) = \frac{\pi_{\theta}(a|s)}{\pi_{\theta_k}(a|s)}
     $$

   - Maximize the clipped PPO objective:
     $$
     L^{\text{CLIP}}(\theta) = \mathbb{E}_{s,a\sim \pi_{\theta_k}}\left[
     \min\left(r(\theta)A^{\pi_{\theta_k}}(s,a),\,\text{clip}(r(\theta),1-\epsilon,1+\epsilon)A^{\pi_{\theta_k}}(s,a)\right)
     \right]
     $$

   by performing a gradient ascent step:
   $$
   \theta \leftarrow \theta + \eta \nabla_{\theta} L^{\text{CLIP}}(\theta)
   $$

4. Update the old policy:
   $$
   \theta_{k+1}\leftarrow \theta
   $$

**end for**

In [None]:
class PPOAgent:
    """PPO agent with a shared actor-critic network and the clipped objective.

    Each transition is a 7-tuple
    ``(state, action, log_prob, reward, value, next_state, done)``.
    Two update paths are offered: ``update`` uses Monte Carlo returns over
    whole trajectories, ``update_td`` uses one-step TD targets over a flat
    list of transitions. Both optimize the same loss:
    ``policy_loss + value_coef * value_loss - entropy_coef * entropy``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=3e-4, gamma=0.99,
                 clip_epsilon=0.2, update_epochs=4, value_coef=0.5, entropy_coef=0.01):
        """Create the network and optimizer.

        Args:
            state_dim: Size of a single state vector.
            action_dim: Number of discrete actions.
            hidden_dim: Hidden width passed to ``ActorCriticNetwork``.
            lr: Adam learning rate.
            gamma: Discount factor.
            clip_epsilon: Half-width of the ratio clipping interval.
            update_epochs: Gradient steps taken per batch of collected data.
            value_coef: Weight of the critic loss in the total loss.
            entropy_coef: Weight of the entropy bonus in the total loss.
        """
        self.ac_net = ActorCriticNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.optimizer = optim.Adam(self.ac_net.parameters(), lr=lr)
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.update_epochs = update_epochs
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef

    def select_action(self, state):
        """Sample an action for one state.

        Returns:
            ``(action, log_prob, value)``: the action as a Python int, the
            log-probability tensor of that action, and the critic output.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        policy, value = self.ac_net(state)
        dist = Categorical(policy)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value

    @staticmethod
    def _float_batch(rows):
        """Stack per-step arrays into one float32 tensor on ``device``."""
        return torch.as_tensor(np.asarray(rows), dtype=torch.float32).to(device)

    @staticmethod
    def _flat_detached(items):
        """Concatenate scalar-like items into a detached 1-D float tensor.

        Rollout tensors are stored with extra singleton dimensions and with
        their autograd graph attached. Flattening prevents accidental
        ``[N, 1]`` vs ``[N]`` broadcasting, and detaching stops gradients from
        flowing into (and re-using) the rollout graph across update epochs.
        """
        pieces = [
            torch.as_tensor(item, dtype=torch.float32).detach().reshape(-1)
            for item in items
        ]
        return torch.cat(pieces).to(device)

    def compute_returns_and_advantages(self, trajectories):
        """Compute per-trajectory discounted returns and advantages.

        Returns are normalized within each trajectory (when it has more than
        one step, since the standard deviation of a single value is NaN), and
        the advantage is the normalized return minus the stored value estimate.

        Returns:
            ``(all_returns, all_advantages)``: two lists with one 1-D tensor
            per trajectory.
        """
        all_returns = []
        all_advantages = []
        for traj in trajectories:
            # Accumulate rewards-to-go backwards, then flip once (O(n)).
            rewards_to_go = []
            running = 0
            for step in reversed(traj):
                running = step[3] + self.gamma * running
                rewards_to_go.append(running)
            rewards_to_go.reverse()

            returns = torch.tensor(rewards_to_go, dtype=torch.float).to(device)
            if returns.numel() > 1:
                returns = (returns - returns.mean()) / (returns.std() + 1e-8)

            baseline = torch.tensor(
                [step[4].item() for step in traj], dtype=torch.float
            ).to(device)

            all_returns.append(returns)
            all_advantages.append(returns - baseline)
        return all_returns, all_advantages

    def _optimize(self, states, actions, old_log_probs, advantages, value_targets):
        """Run ``update_epochs`` gradient steps on the clipped PPO loss.

        All arguments are tensors with one entry per transition; everything
        except ``states`` and ``actions`` must already be detached.
        """
        low = 1 - self.clip_epsilon
        high = 1 + self.clip_epsilon
        for _ in range(self.update_epochs):
            policy, values = self.ac_net(states)
            dist = Categorical(policy)
            new_log_probs = dist.log_prob(actions)

            ratio = torch.exp(new_log_probs - old_log_probs)
            unclipped = ratio * advantages
            clipped = torch.clamp(ratio, low, high) * advantages
            policy_loss = -torch.min(unclipped, clipped).mean()

            value_loss = nn.functional.mse_loss(values.squeeze(-1), value_targets)

            # Entropy is a bonus: subtracting it from the loss encourages
            # exploration.
            entropy = dist.entropy().mean()
            loss = (policy_loss
                    + self.value_coef * value_loss
                    - self.entropy_coef * entropy)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def update(self, trajectories):
        """PPO update from complete trajectories using Monte Carlo returns.

        The critic is regressed onto the per-trajectory normalized returns
        produced by ``compute_returns_and_advantages``. Does nothing when no
        transitions are supplied.
        """
        states, actions, old_log_probs = [], [], []
        returns, advantages = [], []
        for traj in trajectories:
            for step in traj:
                states.append(step[0])
                actions.append(step[1])
                old_log_probs.append(step[2])
            ret, adv = self.compute_returns_and_advantages([traj])
            # The helper returns one tensor per trajectory passed in.
            returns.append(ret[0])
            advantages.append(adv[0])

        if not states:
            return

        self._optimize(
            self._float_batch(states),
            torch.LongTensor(actions).to(device),
            self._flat_detached(old_log_probs),
            torch.cat(advantages),
            torch.cat(returns),
        )

    def update_td(self, transitions):
        """PPO update from a flat list of transitions using TD(0) targets.

        Advantages are the one-step TD errors with respect to the value
        estimates stored at collection time, normalized across the batch when
        it holds more than one transition. Does nothing for an empty batch.
        """
        if len(transitions) == 0:
            return

        states = self._float_batch([t[0] for t in transitions])
        actions = torch.LongTensor([t[1] for t in transitions]).to(device)
        old_log_probs = self._flat_detached([t[2] for t in transitions])
        rewards = torch.FloatTensor([t[3] for t in transitions]).to(device)
        values = self._flat_detached([t[4] for t in transitions])
        next_states = self._float_batch([t[5] for t in transitions])
        dones = torch.FloatTensor([float(t[6]) for t in transitions]).to(device)

        with torch.no_grad():
            _, next_values = self.ac_net(next_states)
            # Terminal transitions must not bootstrap from the next state.
            next_values = next_values.squeeze(-1) * (1 - dones)
            td_targets = rewards + self.gamma * next_values
            advantages = td_targets - values
            if advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        self._optimize(states, actions, old_log_probs, advantages, td_targets)
