# Policy Gradients
While Q learning was all about value functoins and Bellman equations, policy gradients are all about (as you may have guessed) estimating policy gradients to iterate to the best possible policy.

The derivation is quite fun but we know that $\Delta_\theta J(\theta)$  = $= \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ \sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)R(\tau) \right] \quad \text{Expression for grad-log-prob}$
- The estimate we'll use is $\hat{g} = \frac{1}{|\mathcal{D}|} \sum_{\tau \in \mathcal{D}} \sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t |s_t) R(\tau)$

**The REINFORCE algorithm** is quite simple really
1. Generate sample trajectories 
2. Generate policy gradient
3. Improve policy: $\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)$ 

In [16]:
import gymnasium as gym
env = gym.make('Acrobot-v1', render_mode="human")
# Testing


In [5]:

observation, info = env.reset()

for _ in range(10):
    action = env.action_space.sample()  # agent policy that uses the observation and info
    observation, reward, terminated, truncated, info = env.step(action)

    if terminated or truncated:
        observation, info = env.reset()

env.close()

error: display Surface quit

In [32]:
import torch.nn as nn
class PolicyModel(nn.Module):
    
    def __init__(self, input_dim, action_space):
        super(PolicyModel, self).__init__()
        
        self.input_dim = input_dim
        self.action_space = action_space
        self.model =nn.Sequential(
            nn.Linear(input_dim, 200),
            nn.ReLU(),
            nn.Linear(200, 300),
            nn.ReLU(),
            nn.Linear(300, action_space),
            nn.Softmax()
        )
        
    def forward(self, input):
        return self.model(input)
     
        

In [33]:
model = PolicyModel(6,3)

In [34]:
import torch
import torch.optim as optim

optimizer = optim.Adam(model.parameters())
episodes=100
epochs = 10
for e in range(epochs):
    d = []
    
    observation, info = env.reset()
    
    for i in range(100):
        
        action = torch.argmax(model(torch.tensor(observation)))
        
        observation, reward, terminated, truncated, info = env.step(action)
        d.append((observation, action, reward))
        if terminated or truncated:
            break
    
    
    print(d[0])
        
        



  input = module(input)


(array([ 0.9991336 ,  0.04161806,  0.9985474 , -0.05388077, -0.03040618,
        0.067844  ], dtype=float32), tensor(1), -1.0)


KeyboardInterrupt: 

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F

# Assume 'model' is your policy network

optimizer = optim.Adam(model.parameters())
episodes = 100
gamma = 0.99  # discount factor

for e in range(episodes):
    episode_data = []
    
    state, info = env.reset()
    
    # Generate an episode
    for t in range(1000):  # 100 is an arbitrary choice for max episode length
        state_tensor = torch.tensor(state, dtype=torch.float32)
        action = torch.argmax(model(state_tensor)).item()
#         action = torch.multinomial(action_probs, 1).item()

        next_state, reward, terminated, truncated, info = env.step(action)
        episode_data.append((state, action, reward))
        
        state = next_state
        if terminated or truncated:
            break

    # Calculate policy gradient and update model
    R = 0
    policy_loss = []
    returns = []
    for s, a, r in reversed(episode_data):
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-7)  # normalize

    
    # Calculate policy gradient and update model
    policy_loss = []
    for (s, a, r), R in zip(episode_data, returns):
        state_tensor = torch.tensor(s, dtype=torch.float32)
        action_probs = model(state_tensor)
        action_prob = action_probs[a]
        # Using unsqueeze to add a dimension
        policy_loss.append(-torch.log(action_prob) * R)
    optimizer.zero_grad()

    # Concatenating the policy loss list
    policy_loss = torch.cat([loss.unsqueeze(0) for loss in policy_loss]).sum()


    policy_loss.backward()
    optimizer.step()

    print(f"Episode {e+1}/{episodes}, Total Reward: {sum([x[2] for x in episode_data])}")


Episode 1/100, Total Reward: -500.0
Episode 2/100, Total Reward: -500.0
Episode 3/100, Total Reward: -500.0
Episode 4/100, Total Reward: -500.0
Episode 5/100, Total Reward: -500.0
Episode 6/100, Total Reward: -500.0
Episode 7/100, Total Reward: -500.0
Episode 8/100, Total Reward: -500.0
Episode 9/100, Total Reward: -500.0
Episode 10/100, Total Reward: -500.0
Episode 11/100, Total Reward: -500.0
Episode 12/100, Total Reward: -500.0
Episode 13/100, Total Reward: -500.0
Episode 14/100, Total Reward: -500.0
Episode 15/100, Total Reward: -500.0
Episode 16/100, Total Reward: -500.0
Episode 17/100, Total Reward: -500.0
Episode 18/100, Total Reward: -500.0
Episode 19/100, Total Reward: -500.0
Episode 20/100, Total Reward: -500.0
Episode 21/100, Total Reward: -500.0
Episode 22/100, Total Reward: -500.0
Episode 23/100, Total Reward: -500.0
Episode 24/100, Total Reward: -500.0
Episode 25/100, Total Reward: -500.0
Episode 26/100, Total Reward: -500.0
Episode 27/100, Total Reward: -500.0
Episode 28