# PPO (Proximal Policy Optimization) on Gymnasium Blackjack-v1

In [3]:
import torch
import torch.nn as nn
import gymnasium as gym 
import numpy as np 

import torch.optim as optim

from sklearn.preprocessing import StandardScaler

# distribution functions 
from torch.distributions import Categorical 

import sys
sys.path.append('../')


In [None]:
env = gym.make("Blackjack-v1")
policy = PPOPolicy(3, 2)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)

clip_eps = 0.2
gamma = 0.99
epochs = 100_000
batch_size = 32
ppo_epochs = 4

for epoch in range(epochs):
    states, actions, log_probs, rewards, dones, values = [], [], [], [], [], []
    
    for _ in range(batch_size):
        obs, _ = env.reset()
        done = False
        while not done:
            state_tensor = torch.FloatTensor(obs).unsqueeze(0)
            action_probs, state_value = policy(state_tensor)
            dist = Categorical(action_probs)
            action = dist.sample()

            next_obs, reward, done, _, _ = env.step(action.item())
            reward *= 10

            states.append(state_tensor)
            actions.append(action)
            log_probs.append(dist.log_prob(action))
            rewards.append(reward)
            dones.append(done)
            
            # Detach value to prevent graph reuse error
            values.append(state_value.squeeze().detach())

            obs = next_obs

    # Compute returns and advantages
    returns = []
    discounted_sum = 0
    for r, d in zip(reversed(rewards), reversed(dones)):
        if d: discounted_sum = 0
        discounted_sum = r + gamma * discounted_sum
        returns.insert(0, discounted_sum)

    returns = torch.tensor(returns, dtype=torch.float32)
    values = torch.stack(values)
    advantages = returns - values  # values already detached above

    # Convert to tensors
    states = torch.cat(states)
    actions = torch.tensor(actions)
    old_log_probs = torch.stack(log_probs).detach()

    # PPO update
    for _ in range(ppo_epochs):
        action_probs, state_values = policy(states)
        dist = Categorical(action_probs)
        new_log_probs = dist.log_prob(actions)
        ratio = (new_log_probs - old_log_probs).exp()

        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages

        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = nn.functional.mse_loss(state_values.squeeze(), returns)

        loss = policy_loss + 0.5 * value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % 100 == 0:
        avg_return = returns.mean().item()
        print(f"Epoch {epoch} | Avg Return: {avg_return:.2f}")


Epoch 0 | Avg Return: -7.20
Epoch 100 | Avg Return: -2.82
Epoch 200 | Avg Return: -1.41
Epoch 300 | Avg Return: -2.19
Epoch 400 | Avg Return: -3.85
Epoch 500 | Avg Return: -4.12
Epoch 600 | Avg Return: -2.88
Epoch 700 | Avg Return: -3.33
Epoch 800 | Avg Return: -3.84
Epoch 900 | Avg Return: -4.86
Epoch 1000 | Avg Return: 1.33
Epoch 1100 | Avg Return: -1.86
Epoch 1200 | Avg Return: -0.41
Epoch 1300 | Avg Return: -4.18
Epoch 1400 | Avg Return: -2.11
Epoch 1500 | Avg Return: 0.01
Epoch 1600 | Avg Return: -3.26
Epoch 1700 | Avg Return: -4.57
Epoch 1800 | Avg Return: -1.98
Epoch 1900 | Avg Return: -2.50
