In [36]:
import numpy as np
import torch
from torch import nn
from torch.distributions import Categorical
import gymnasium as gym

class FFNetwork(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, out_dim)
        )
    
    def forward(self, input):
        input = torch.tensor(input)
        if len(input.shape) == 1:
            input = input.unsqueeze(0)
        return self.sequential(input)

In [None]:
def create_trajectories(env, actor, timesteps):
    states = []
    state_terminal_indicator = []
    state_terminal_indicator.append(True)
    actions = []
    rewards = []
    logodds = []
    obs, _ = env.reset()
    while (timesteps >= 0) or (not terminated): # let the last episode finish
        timesteps = timesteps - 1
        states.append(obs)
        logits = actor(obs)
        dist = Categorical(logits=logits)
        action = dist.sample()

        logodds.append(dist.log_prob(action).item())
        actions.append(action.item())
        obs, reward, terminated, _, _ = env.step(actions[-1])
        rewards.append(reward)
        state_terminal_indicator.append(terminated)
        if terminated:
            obs, _ = env.reset()

    return np.array(states), np.array(actions), np.array(logodds), np.array(rewards), np.array(state_terminal_indicator)

def validate(env, actor):
    seeds = [3, 17, 42, 8, 29, 567, 91, 1400, 67, 23, 888, 5, 37, 72, 59990]
    score = 0
    for seed in seeds:
        obs, _ = env.reset(seed=seed)
        terminated = False
        while not terminated:
            obs, reward, terminated, _, _ = env.step(actor(obs).argmax().item())
            score += reward
    return score/len(seeds)

In [39]:
def GAE(critic, states, starts, rewards, gamma=0.99, lbda=0.95):
    timesteps = len(states)
    A = np.zeros(timesteps)
    delta = np.zeros(timesteps)
    
    for t in range(timesteps - 1):
        delta[t] = rewards[t] + gamma * critic(states[t+1]).item() * int(not starts[t+1]) - critic(states[t]).item()
    
    for t in reversed(range(timesteps - 1)):
        A[t] = delta[t] + gamma * lbda * A[t+1] * int(not starts[t+1])

    return A

In [40]:
# Figure out that stuff with the adv - states mismatch

def UpdatePPO(actor, critic, optimizer_actor, optimizer_critic,
              states, advantages, actions, rewards, old_logodds,
              clip_ratio=0.2,
              epochs=10,
              batch_size=64):
    for _ in range(epochs):
        indeces = np.arange(len(states))
        np.random.shuffle(indeces)
        for start in range(0, len(states), batch_size):
            end = start + batch_size
            idx = indeces[start:end]
            s = states[idx]
            a = torch.tensor(actions[idx], dtype=torch.int64)
            lo = torch.tensor(old_logodds[idx], dtype=torch.float32) 
            r = torch.tensor(rewards[idx], dtype=torch.float32) 
            adv = torch.tensor(advantages[idx], dtype=torch.float32)

            optimizer_actor.zero_grad()
            optimizer_critic.zero_grad()

            ratio = torch.exp(Categorical(logits=actor(s)).log_prob(a) - lo)
            loss1, loss2 = ratio * adv, torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
            a_loss = -1 * torch.min(loss1, loss2).mean()
            v_loss = ((critic(s) - r)**2).mean()

            a_loss.backward()
            v_loss.backward()
            optimizer_actor.step()
            optimizer_critic.step()

In [42]:
import gymnasium as gym

def train_ppo(env_name="CartPole-v1", render="human", iterations=200, steps_per_iter=2048, k_epochs=10, lr_actor=3e-4, lr_critic=3e-4, clip_ratio=0.2, batch_size=64, gamma=0.99, lbda=0.95, device="cpu"):
    env = gym.make(env_name, render_mode=render)

    actor = FFNetwork(env.observation_space.shape[0], env.action_space.n)
    critic = FFNetwork(env.observation_space.shape[0], 1)
    optimizer_actor = torch.optim.Adam(actor.parameters(), lr=lr_actor)
    optimizer_critic = torch.optim.Adam(critic.parameters(), lr=lr_critic)

    best_score = 0
    best_model = actor.state_dict()

    avg_rewards = []
    for i in range(iterations):
        states, actions, old_logodds, rewards, starts = create_trajectories(env, actor, steps_per_iter)
        advantages = GAE(critic, states, starts, rewards)
        UpdatePPO(actor, critic, optimizer_actor, optimizer_critic,
              states, advantages, actions, rewards, old_logodds,
              clip_ratio=clip_ratio,
              epochs=k_epochs,
              batch_size=batch_size)
        s = validate(env, actor)
        avg_rewards.append(s)
        print(f"Iteration {i} complete. score: {s}")
        if s >= best_score:
            best_score = s
            best_model = actor.state_dict()
    
    actor.load_state_dict(best_model)
    env.close()
    return actor, avg_rewards
        

In [49]:
actor, avg_rewards = train_ppo(render=None, iterations=150, steps_per_iter=2500, k_epochs=10, lr_actor=1e-6, lr_critic=1e-6, clip_ratio=0.2)

Iteration 0 complete. score: 9.4
Iteration 1 complete. score: 9.533333333333333
Iteration 2 complete. score: 9.666666666666666
Iteration 3 complete. score: 9.866666666666667
Iteration 4 complete. score: 10.2
Iteration 5 complete. score: 10.6
Iteration 6 complete. score: 11.0
Iteration 7 complete. score: 11.866666666666667
Iteration 8 complete. score: 12.4
Iteration 9 complete. score: 13.0
Iteration 10 complete. score: 13.733333333333333
Iteration 11 complete. score: 14.333333333333334
Iteration 12 complete. score: 15.466666666666667
Iteration 13 complete. score: 16.333333333333332
Iteration 14 complete. score: 17.2
Iteration 15 complete. score: 18.133333333333333
Iteration 16 complete. score: 18.666666666666668
Iteration 17 complete. score: 20.133333333333333
Iteration 18 complete. score: 22.2
Iteration 19 complete. score: 23.266666666666666
Iteration 20 complete. score: 24.733333333333334
Iteration 21 complete. score: 25.333333333333332
Iteration 22 complete. score: 26.266666666666666

In [None]:
import pandas as pd
import plotly.express as px

df = pd.DataFrame({'Iteration': range(1, len(avg_rewards) + 1), 'Score': avg_rewards})
fig = px.line(df, x='Iteration', y='Score', title='Score per Iteration', markers=True, 
              width=900,
              height=500)

fig.show()

In [52]:
def display(env_name, actor):
    env = gym.make(env_name, render_mode="human")
    obs, _ = env.reset()
    terminated = False
    while not terminated:
        obs, reward, terminated, _, _ = env.step(actor(obs).argmax().item())
    env.close()

display("CartPole-v1", actor)