In [None]:
import warnings
# Suppress every warning for the rest of the session so cell outputs stay
# uncluttered. Comment this line out while debugging: it also hides warnings
# that may point at real problems.
warnings.filterwarnings('ignore')

### Run in Colab
<a href="https://colab.research.google.com/github/racousin/rl_introduction/blob/master/notebooks/5_policy_gradient-reinforce.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import gymnasium as gym

In [None]:
# Instantiate the CartPole-v0 environment that the algorithm is experimented on.
# `gym` is whatever module is currently bound to that name (the cell above
# binds it to gymnasium).
env = gym.make('CartPole-v0')

# module9_exercise6: Actor Critic

### Objective
Here we present an alternative to Q-learning: the policy gradient algorithm.

**Complete the TODO steps! Good luck!**

# Policy gradient
In policy gradient, we parametrize the policy $\pi_\theta$ directly. This is especially useful when the action space is continuous: in that case, a greedy policy based on Q-learning needs to compute $\arg\max_a Q(s,a)$, which can be tedious. More generally, policy gradient algorithms are better suited to exploring large state-action spaces.

$J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}}[{G(\tau)}]$

We can prove that:


$\nabla_{\theta} J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}}[{\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t |s_t) G(\tau)}]$ 

1. In discrete action space

we parametrize $\pi$ with $\theta$, such as $\pi_\theta : S \rightarrow [0,1]^{dim(A)}$ and $\forall s$ $\sum \pi_\theta(s) = 1$.


2. In continuous action space

we parametrize $\pi$ with $\theta$, such as $\pi_\theta : S \rightarrow \mu^{dim(A)} \times \sigma^{dim(A)} =  \mathbb{R}^{dim(A)} \times \mathbb{R}_{+,*}^{dim(A)}$



In torch, it is easier to pass the loss than the gradient.
1. It is possible to show that the loss for discrete actions ($1,...,N$) with a softmax policy is a weighted negative binary cross-entropy:
$-G\sum_{j=1}^N[a^j\log(\hat{a}^j) + (1-a^j)\log(1 - \hat{a}^j)]$

with:
$a^j=1$ if $a_t = j$, $0$ otherwise.

$\hat{a}^j = \pi_\theta(s_t)^j$.

$G$ is the discounted empirical return $G_t = \sum_{k=0}^{T-t-1} \gamma^k R_{t+k+1}$ from state $s_t$ and action $a_t$.


2. It is possible to show that the loss for continuous actions ($1,...,N$) with a multivariate Gaussian policy (identity covariance) is given by:

$-G\sum_{j=1}^N[(a^j - \hat{a}^j)^2]$

$\hat{a}^j = \pi_\theta(s_t)^j$.



See https://aleksispi.github.io/assets/pg_autodiff.pdf for a more detailed explanation.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Categorical
import numpy as np
from typing import Tuple, List, Optional, Union
import gym
from pathlib import Path

class Actor(nn.Module):
    """Policy network supporting discrete and continuous action spaces.

    A stack of ``Linear -> ReLU`` layers (``self.base``) feeds one of two heads:

    * discrete (``continuous=False``): ``self.policy``, whose output is passed
      through a softmax to give action probabilities;
    * continuous (``continuous=True``): ``self.mean`` plus a learnable,
      state-independent ``self.log_std`` parameter of shape ``(1, action_dim)``
      initialised to zero.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Number of discrete actions, or dimensionality of the
            continuous action vector.
        hidden_dims: Width of each hidden layer, in order.
        continuous: Selects the continuous (Gaussian) head when True.
    """
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int] = [64, 64],
        continuous: bool = False
    ):
        super().__init__()

        # Pair consecutive widths so each (in, out) couple becomes one Linear+ReLU stage.
        widths = [state_dim, *hidden_dims]
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.ReLU())

        self.base = nn.Sequential(*stages)
        self.continuous = continuous

        # With no hidden layers the heads read the raw state directly.
        feature_dim = widths[-1]
        if not continuous:
            # Discrete head: one logit per action.
            self.policy = nn.Linear(feature_dim, action_dim)
        else:
            # Continuous head: per-dimension mean and a shared log standard deviation.
            self.mean = nn.Linear(feature_dim, action_dim)
            self.log_std = nn.Parameter(torch.zeros(1, action_dim))

    def forward(self, state: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the policy on ``state``.

        Returns:
            ``(mean, std)`` for the continuous head, where ``std`` is
            ``exp(log_std)``; otherwise the softmax action probabilities taken
            over the last dimension.
        """
        features = self.base(state)

        if not self.continuous:
            logits = self.policy(features)
            return F.softmax(logits, dim=-1)

        return self.mean(features), self.log_std.exp()

class Critic(nn.Module):
    """State-value network.

    A stack of ``Linear -> ReLU`` layers followed by a single-unit linear
    output, all held in ``self.value_net``.

    Args:
        state_dim: Size of the observation vector.
        hidden_dims: Width of each hidden layer, in order.
    """
    def __init__(self, state_dim: int, hidden_dims: List[int] = [64, 64]):
        super().__init__()

        # Consecutive widths define the (in, out) sizes of each hidden stage.
        widths = [state_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]

        # Final projection to one scalar value per state.
        modules.append(nn.Linear(widths[-1], 1))
        self.value_net = nn.Sequential(*modules)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the value estimate with the trailing size-1 dimension removed."""
        value = self.value_net(state)
        return value.squeeze(-1)

class ActorCriticAgent:
    """On-policy actor-critic agent using GAE advantages.

    The agent owns an ``Actor`` (policy) and a ``Critic`` (state value), each
    with its own Adam optimizer. The caller appends transitions to
    ``self.trajectory`` as ``(state, action, log_prob, reward, done, next_state)``
    tuples; :meth:`update` consumes them and clears the list.

    Args:
        env: Environment used to size the networks. A ``Box`` action space
            selects the continuous (Gaussian) policy, anything else the
            discrete (categorical) one.
        actor_lr: Learning rate of the actor optimizer.
        critic_lr: Learning rate of the critic optimizer.
        gamma: Discount factor.
        gae_lambda: GAE smoothing parameter.
        entropy_coef: Weight of the entropy bonus in the total loss.
        value_loss_coef: Weight of the critic loss in the total loss.
        max_grad_norm: Gradient-norm clipping threshold (applied to the actor
            and the critic separately).
        device: Torch device the networks and tensors live on.
    """
    def __init__(
        self,
        env: "gym.Env",
        actor_lr: float = 3e-4,
        critic_lr: float = 1e-3,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        entropy_coef: float = 0.01,
        value_loss_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.env = env
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        self.device = device

        # Determine action space type
        self.continuous = isinstance(env.action_space, gym.spaces.Box)
        state_dim = env.observation_space.shape[0]

        if self.continuous:
            action_dim = env.action_space.shape[0]
            # Bounds used to clamp sampled actions in select_action.
            self.action_low = torch.tensor(env.action_space.low, device=device)
            self.action_high = torch.tensor(env.action_space.high, device=device)
        else:
            action_dim = env.action_space.n

        # Initialize networks
        self.actor = Actor(state_dim, action_dim, continuous=self.continuous).to(device)
        self.critic = Critic(state_dim).to(device)

        # Setup optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Transitions collected since the last update.
        self.trajectory: List = []

    def select_action(self, state: np.ndarray) -> Tuple[Union[int, np.ndarray], float]:
        """Sample an action from the current policy without tracking gradients.

        Args:
            state: A single (unbatched) observation.

        Returns:
            ``(action, log_prob)``. ``action`` is an ``int`` for discrete
            policies and a NumPy array clamped to the action-space bounds for
            continuous ones; ``log_prob`` is the log-probability of the sample.
        """
        state_tensor = torch.as_tensor(
            np.asarray(state, dtype=np.float32)
        ).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if self.continuous:
                mean, std = self.actor(state_tensor)
                dist = Normal(mean, std)
                action = dist.sample()
                # Independent dimensions: the joint log-prob is the sum over them.
                log_prob = dist.log_prob(action).sum(dim=-1)
                # NOTE(review): the log-prob is taken before clamping, while the
                # clamped action is what gets returned (and re-scored in update);
                # the two can disagree for out-of-bounds samples. Confirm intent.
                action = torch.clamp(action, self.action_low, self.action_high)
                return action.cpu().numpy()[0], log_prob.cpu().item()
            else:
                probs = self.actor(state_tensor)
                dist = Categorical(probs)
                action = dist.sample()
                return action.cpu().item(), dist.log_prob(action).cpu().item()

    def compute_gae(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        next_value: torch.Tensor,
        dones: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Generalized Advantage Estimation and discounted returns.

        Args:
            rewards: Reward received at each step.
            values: Critic estimate for the state visited at each step.
            next_value: Critic estimate for the state following the last step
                (used to bootstrap the tail of the trajectory).
            dones: ``1.0`` where the transition taken at that step ended the
                episode, ``0.0`` otherwise.

        Returns:
            ``(advantages, returns)``, both shaped like ``rewards``.
        """
        advantages = torch.zeros_like(rewards)
        returns = torch.zeros_like(rewards)
        running_return = next_value
        running_advantage = 0
        last_step = len(rewards) - 1

        for t in reversed(range(len(rewards))):
            # dones[t] says whether the transition taken at step t terminated the
            # episode, so it is the flag that must cut the bootstrap for step t.
            # (Masking with dones[t + 1] would leak value across episode
            # boundaries and truncate the step *before* the terminal one.)
            next_non_terminal = 1.0 - dones[t]
            bootstrap_value = next_value if t == last_step else values[t + 1]

            running_return = rewards[t] + self.gamma * next_non_terminal * running_return
            returns[t] = running_return

            td_error = rewards[t] + self.gamma * next_non_terminal * bootstrap_value - values[t]
            running_advantage = td_error + self.gamma * self.gae_lambda * next_non_terminal * running_advantage
            advantages[t] = running_advantage

        return advantages, returns

    def _trajectory_column(self, index: int, dtype=np.float32) -> torch.Tensor:
        """Stack field ``index`` of every stored transition into one tensor on ``self.device``."""
        # Going through a single NumPy array avoids the slow element-wise
        # conversion torch performs on a list of arrays.
        column = np.asarray([step[index] for step in self.trajectory], dtype=dtype)
        return torch.as_tensor(column).to(self.device)

    def _episode_step_limit(self) -> Optional[int]:
        """Return the environment's episode step limit, or None when it exposes none."""
        limit = getattr(self.env, "_max_episode_steps", None)
        if limit is None:
            limit = getattr(getattr(self.env, "spec", None), "max_episode_steps", None)
        return limit

    def update(self) -> Tuple[float, float, float]:
        """Update policy and value function from the collected trajectory.

        Performs one gradient step on
        ``actor_loss + value_loss_coef * value_loss - entropy_coef * entropy``
        and then clears ``self.trajectory``.

        Returns:
            ``(actor_loss, value_loss, entropy)`` as Python floats.

        Raises:
            ValueError: If no transition has been collected.
        """
        if not self.trajectory:
            raise ValueError("Cannot update from an empty trajectory.")

        # Convert trajectory to tensors
        states = self._trajectory_column(0)
        actions = self._trajectory_column(1, np.float32 if self.continuous else np.int64)
        log_probs = self._trajectory_column(2)
        rewards = self._trajectory_column(3)
        dones = self._trajectory_column(4)

        # Compute values and next value
        with torch.no_grad():
            values = self.critic(states)
            step_limit = self._episode_step_limit()
            # Bootstrap from the critic unless the trajectory already spans the
            # episode step limit; with no known limit, always bootstrap (a
            # terminal last step is masked out by `dones` anyway).
            if step_limit is None or len(self.trajectory) < step_limit:
                last_next_state = torch.as_tensor(
                    np.asarray(self.trajectory[-1][5], dtype=np.float32)
                ).unsqueeze(0).to(self.device)
                next_value = self.critic(last_next_state).squeeze(0)
            else:
                next_value = torch.zeros((), device=self.device)

        # Compute advantages and returns
        advantages, returns = self.compute_gae(rewards, values, next_value, dones)

        # Normalize advantages. The standard deviation of a single sample is
        # NaN, which would poison every parameter, so skip it in that case.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Compute actor loss
        if self.continuous:
            mean, std = self.actor(states)
            dist = Normal(mean, std)
            new_log_probs = dist.log_prob(actions).sum(dim=-1)
        else:
            probs = self.actor(states)
            dist = Categorical(probs)
            new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()

        # Importance ratio between the current policy and the one that sampled.
        ratio = torch.exp(new_log_probs - log_probs)
        actor_loss = -(advantages * ratio).mean()

        # Compute critic loss
        value_pred = self.critic(states)
        value_loss = F.mse_loss(value_pred, returns)

        # Compute total loss
        total_loss = actor_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy

        # Update actor and critic
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        total_loss.backward()

        # Clip gradients
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)

        self.actor_optimizer.step()
        self.critic_optimizer.step()

        # Clear trajectory
        self.trajectory = []

        return actor_loss.item(), value_loss.item(), entropy.item()

    def save(self, path: str):
        """Save actor, critic and both optimizer states to ``path``."""
        torch.save({
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
        }, path)

    def load(self, path: str):
        """Restore actor, critic and both optimizer states from ``path``."""
        # map_location lets a checkpoint written on one device (e.g. CUDA) be
        # restored on whatever device this agent runs on.
        checkpoint = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

def train(
    env_name: str,
    num_episodes: int = 1000,
    max_steps: int = 1000,
    update_frequency: int = 2048
) -> Tuple[ActorCriticAgent, List[float]]:
    """Train an Actor-Critic agent on the environment named ``env_name``.

    Works with both environment API conventions: the legacy one
    (``reset() -> obs``, ``step() -> (obs, reward, done, info)``) and the
    newer one (``reset() -> (obs, info)``,
    ``step() -> (obs, reward, terminated, truncated, info)``).

    Args:
        env_name: Environment id passed to ``gym.make``.
        num_episodes: Number of episodes to run.
        max_steps: Hard cap on the number of steps per episode.
        update_frequency: Number of collected steps that triggers an update in
            the middle of an episode. An update is also run at the end of every
            episode that left transitions in the buffer.

    Returns:
        ``(agent, episode_rewards)`` where ``episode_rewards`` holds the
        undiscounted reward sum of each episode.
    """
    env = gym.make(env_name)
    agent = ActorCriticAgent(env)

    episode_rewards = []
    step_counter = 0

    try:
        for episode in range(num_episodes):
            # Newer APIs return (observation, info); older ones just the observation.
            reset_result = env.reset()
            if isinstance(reset_result, tuple) and len(reset_result) == 2:
                state = reset_result[0]
            else:
                state = reset_result

            episode_reward = 0
            done = False
            step = 0

            while not done and step < max_steps:
                # Select action
                action, log_prob = agent.select_action(state)

                # Take action
                step_result = env.step(action)
                if len(step_result) == 5:
                    next_state, reward, terminated, truncated, _ = step_result
                    # Either signal ends the episode for the purposes of this loop.
                    done = bool(terminated or truncated)
                else:
                    next_state, reward, done, _ = step_result

                # Store transition
                agent.trajectory.append((state, action, log_prob, reward, done, next_state))
                episode_reward += reward

                # Update if we have enough steps
                step_counter += 1
                if step_counter >= update_frequency:
                    actor_loss, value_loss, entropy = agent.update()
                    step_counter = 0
                    print(f"Episode {episode}, Actor Loss: {actor_loss:.4f}, Value Loss: {value_loss:.4f}, Entropy: {entropy:.4f}")

                state = next_state
                step += 1

            # Update at end of episode if we have transitions
            if len(agent.trajectory) > 0:
                actor_loss, value_loss, entropy = agent.update()
                step_counter = 0

            episode_rewards.append(episode_reward)
            if episode % 10 == 0:
                avg_reward = np.mean(episode_rewards[-10:])
                print(f"Episode {episode}, Average Reward: {avg_reward:.2f}")
    finally:
        # Release the environment's resources even if training is interrupted.
        env.close()

    return agent, episode_rewards

if __name__ == "__main__":
    # Run the training loop on a continuous-action environment.
    agent, rewards = train(env_name="LunarLanderContinuous-v2", num_episodes=1000)

    # Draw the per-episode reward curve.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.plot(rewards)
    ax.set_title("Training Rewards")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Reward")
    plt.show()