In [1]:
import gym_agent as ga
import gymnasium as gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.distributions import Categorical

import numpy as np

from torchsummary import summary

import os

from tqdm.notebook import tqdm_notebook as tqdm

from ppo import PPO

pygame 2.5.2 (SDL 2.28.2, Python 3.11.0)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class StopExecution(Exception):
    """Exception used to halt a notebook cell quietly.

    IPython asks an exception for ``_render_traceback_`` when displaying it;
    returning no lines means nothing is printed for this exception.
    """

    def _render_traceback_(self):
        traceback_lines = []
        return traceback_lines

In [3]:
class Encoded(nn.Module):
    """MLP trunk made of ``Linear -> ReLU`` pairs.

    Args:
        n_inp: size of the input's last dimension.
        features: hidden layer widths, in order. Any sequence of ints is
            accepted (list or tuple); an empty sequence gives an identity map.

    Every Linear layer, including the last, is followed by a ReLU, so the
    output is non-negative with last dimension ``features[-1]``.
    """

    def __init__(self, n_inp, features):
        super().__init__()

        # `[n_inp] + features` only worked for lists; unpacking also accepts tuples.
        layer_sizes = [n_inp, *features]

        layers = []
        for in_size, out_size in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers.append(nn.Linear(in_size, out_size))
            layers.append(nn.ReLU())

        # Same module indices as before (0: Linear, 1: ReLU, 2: Linear, ...),
        # so existing state_dict checkpoints still load.
        self.encoded = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoded(x)

In [4]:
class Actor(nn.Module):
    """Policy network: ``Encoded`` MLP trunk plus a linear head, softmax-normalised.

    ``forward`` returns action probabilities over the last dimension
    (size ``n_action``).
    """

    def __init__(self, n_inp: int, n_action, features = [128, 128]):
        super().__init__()

        self.encoded = Encoded(n_inp, features)
        head_width = features[-1]
        self.actor = nn.Linear(head_width, n_action)

    def forward(self, X: torch.Tensor):
        hidden = self.encoded(X)
        scores = self.actor(hidden)
        return scores.softmax(dim=-1)

In [5]:
class Critic(nn.Module):
    """State-value network: ``Encoded`` MLP trunk plus a single-output linear head.

    ``forward`` returns a tensor whose last dimension is 1.
    """

    def __init__(self, n_inp: int, features = [128, 128]):
        super().__init__()

        self.encoded = Encoded(n_inp, features)
        trunk_width = features[-1]
        self.critic = nn.Linear(trunk_width, 1)

    def forward(self, X: torch.Tensor):
        trunk_out = self.encoded(X)
        return self.critic(trunk_out)

In [7]:
class Policy(nn.Module):
    """Actor-critic pair with *separate* trunks for the policy and the value.

    NOTE(review): a later cell redefines ``Policy`` (shared trunk); after
    Restart & Run All that later class shadows this one.
    """

    def __init__(self, n_inp, n_action, features = [512, 256, 128, 64]):
        super().__init__()
        self.actor = Actor(n_inp, n_action, features)
        self.critic = Critic(n_inp, features)

    def action(self, X: torch.Tensor):
        """Action probabilities produced by ``Actor`` (softmax output)."""
        policy_net = self.actor
        return policy_net(X)

    def value(self, X: torch.Tensor):
        """State-value estimate produced by ``Critic``."""
        value_net = self.critic
        return value_net(X)

In [11]:
class Policy(nn.Module):
    """Actor-critic with one shared ``Encoded`` trunk and two linear heads.

    ``action`` returns *log*-probabilities (``log_softmax`` over the last
    dimension). These are valid logits for ``Categorical(logits=...)``; they
    are not probabilities.
    ``value`` returns the value head's output (last dimension 1).
    """

    def __init__(self, n_inp, n_action, features = [512, 256, 128, 64]):
        super().__init__()
        self.encoded = Encoded(n_inp, features)

        trunk_width = features[-1]
        self.actor = nn.Linear(trunk_width, n_action)
        self.critic = nn.Linear(trunk_width, 1)

    def action(self, X: torch.Tensor):
        raw_scores = self.actor(self.encoded(X))
        return raw_scores.log_softmax(dim=-1)

    def value(self, X: torch.Tensor):
        shared = self.encoded(X)
        return self.critic(shared)

In [12]:
class PPO_Lunar(ga.AgentBase):
    """PPO agent (clipped surrogate objective) for discrete-action environments.

    Uses the shared-trunk ``Policy`` whose ``action`` head returns
    log-probabilities, so action distributions are built with
    ``Categorical(logits=...)``.

    Args:
        state_shape: observation shape; ``state_shape[0]`` is the policy input size.
        action_shape: forwarded to ``ga.AgentBase``.
        n_action: number of discrete actions.
        n_epochs: optimisation passes over each batch in ``learn``.
        gamma: discount factor.
        lr: Adam learning rate.
        gae_lambda: GAE lambda; exactly 1 selects plain discounted returns
            (``calc_returns``) instead of the GAE recursion.
        vf_coef: weight of the critic (value) loss.
        ent_coef: weight of the entropy bonus subtracted from the loss.
        policy_clip: clipping range for the probability ratio.
        batch_size: forwarded to ``ga.AgentBase``.
        device: device for the policy and training tensors.
        **kwargs: forwarded to ``ga.AgentBase``.
    """

    def __init__(
            self, 
            state_shape, 
            action_shape, 
            n_action,
            n_epochs,
            gamma = 0.99,
            lr = 3e-5,
            gae_lambda = 0.9,
            vf_coef = 0.5,
            ent_coef = 0.01,
            policy_clip = 0.2,
            batch_size = 64, 
            device = 'cuda', 
            **kwargs
        ):
        super().__init__(state_shape, action_shape, batch_size, True, device=device, **kwargs)

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.n_epochs = n_epochs
        self.policy_clip = policy_clip
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef

        features = [128, 128]

        self.lr = lr

        # act/learn move their inputs to self.device. The network is created after
        # the base-class __init__, so move it explicitly rather than rely on the base.
        self.policy = Policy(state_shape[0], n_action, features).to(self.device)

        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

    def _distribution(self, states: torch.Tensor) -> Categorical:
        """Action distribution for ``states``; ``Policy.action`` yields log-probs, i.e. logits."""
        return Categorical(logits=self.policy.action(states))

    @staticmethod
    def _normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Standardise to zero mean / unit std; safe for one element or constant input."""
        centered = x - x.mean()
        if x.numel() < 2:
            # The unbiased std of a single element is NaN; only centre it.
            return centered
        return centered / (x.std() + eps)

    def act(self, state: np.ndarray):
        """Sample one action for a single (unbatched) observation."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)

        with torch.no_grad():
            action = self._distribution(state).sample()
        return action.cpu().numpy()[0]

    def learn(self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, terminals: torch.Tensor):
        """Run ``n_epochs`` full-batch PPO updates on one collected batch.

        ``actions`` and ``rewards`` are squeezed on dim 1, so they presumably
        arrive as (T, 1). TODO confirm against gym_agent's buffer.
        """
        states = states.detach().to(self.device)
        actions = actions.detach().squeeze(1).to(self.device)
        rewards = rewards.detach().squeeze(1).to(self.device)
        next_states = next_states.detach().to(self.device)
        terminals = terminals.detach().to(self.device)

        # Behaviour-policy log-probs and targets are fixed for the whole update.
        with torch.no_grad():
            old_log_probs = self._distribution(states).log_prob(actions)
            advantages, returns = self.calc_advantages_and_returns(states, rewards, terminals, next_states[-1])

        for _ in range(self.n_epochs):
            critic_value: torch.Tensor = self.policy.value(states).squeeze(1)
            dist = self._distribution(states)

            new_log_probs = dist.log_prob(actions)
            prob_ratio = (new_log_probs - old_log_probs).exp()

            surrogate = prob_ratio * advantages
            clipped_surrogate = torch.clamp(prob_ratio, 1 - self.policy_clip, 1 + self.policy_clip) * advantages
            actor_loss = -torch.min(surrogate, clipped_surrogate).mean()

            critic_loss = F.smooth_l1_loss(critic_value, returns)

            # Entropy bonus: ent_coef was previously stored but never applied.
            entropy = dist.entropy().mean()

            total_loss = actor_loss + self.vf_coef * critic_loss - self.ent_coef * entropy

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

    def calc_advantages_and_returns(self, states: torch.Tensor, rewards: torch.Tensor, terminals: torch.Tensor, last_state: torch.Tensor):
        """Return ``(advantages, returns)``, both standardised and without gradient.

        With ``gae_lambda == 1`` this uses discounted Monte-Carlo returns.
        NOTE(review): that branch ignores ``terminals`` and does not bootstrap
        from ``last_state``, so it assumes the batch is one complete episode.

        Otherwise it runs the GAE recursion, bootstrapping from ``last_state``
        and cutting at terminal steps. ``terminals`` is flattened and must hold
        one flag per step (assumed (T,) or (T, 1) — TODO confirm).
        """
        with torch.no_grad():
            if self.gae_lambda == 1:
                returns = self.calc_returns(rewards).detach().to(self.device)
                values = self.policy.value(states).squeeze(-1).detach().to(self.device)
                advantages = self.calc_advantages(returns, values).detach().to(self.device)

                return advantages, returns

            T = len(states)
            last_state = last_state.to(self.device)
            values = self.policy.value(torch.cat([states, last_state.unsqueeze(0)])).squeeze(-1)

            # .bool() accepts float or bool flags; 1.0 where the episode continues.
            not_done = (~terminals.to(self.device).reshape(-1).bool()).float()

            advantages = torch.zeros((T, ), device=self.device)
            last_gae = 0.0
            for t in reversed(range(T)):
                delta = rewards[t] + self.gamma * values[t + 1] * not_done[t] - values[t]
                last_gae = delta + self.gamma * self.gae_lambda * not_done[t] * last_gae
                advantages[t] = last_gae

            # Returns must come from the raw advantages. Previously they were
            # built after normalisation, which corrupted the critic targets.
            returns = advantages + values[:-1]

            advantages = self._normalize(advantages)
            # NOTE(review): normalising returns keeps the critic on the same scale as
            # the gae_lambda == 1 branch, but values[t+1] above are then on that
            # normalised scale while rewards are raw — revisit if GAE is used.
            returns = self._normalize(returns)

        return advantages, returns

    def calc_returns(self, rewards):
        """Discounted reward-to-go for every step, standardised. Does not reset at episode ends."""
        returns = torch.zeros_like(rewards).to(self.device)
        running = 0

        for i in reversed(range(len(rewards))):
            running = rewards[i] + running * self.gamma
            returns[i] = running

        return self._normalize(returns)

    def calc_advantages(self, returns, values):
        """Standardised ``returns - values``."""
        return self._normalize(returns - values)

In [13]:
# PPO agent for LunarLander: 8-dim observations, 4 discrete actions.
# gae_lambda=1 selects the discounted-returns branch of calc_advantages_and_returns.
# NOTE(review): batch_size=None is passed through to ga.AgentBase — confirm its meaning there.
agent = PPO_Lunar(
    state_shape=(8, ),
    action_shape=(1, ),
    n_action=4,
    n_epochs=10,
    batch_size=None,
    gamma=0.99,
    gae_lambda=1,
    lr=1e-4,
    policy_clip=0.2,
)

agent.apply(ga.init_weights)

In [14]:
# Training environment (no rendering).
# NOTE(review): 'LunarLander-v2' is the pre-1.0 gymnasium id; gymnasium 1.0 registers
# 'LunarLander-v3' instead — confirm against the installed version.
env = ga.make('LunarLander-v2')

In [15]:
# Train the agent, checkpointing to save_dir.
# NOTE(review): the positional args are opaque here. 1000 matches the progress-bar
# total (presumably episodes) and 500 is presumably a per-episode step cap; check the
# meaning of True/False against gym_agent's AgentBase.fit and pass them by keyword.
scores = agent.fit(env, 1000, 500, True, False, save_dir='checkpoints/PPO_Lunar', progress_bar = tqdm)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [17]:
# Watch the trained agent: load the saved checkpoint and play one run in a window.
# A separate name keeps the training `env` intact for re-running the fit cell.
render_env = ga.make('LunarLander-v2', render_mode='human')
agent.load('checkpoints/PPO_Lunar')
try:
    score = agent.play(render_env)
finally:
    # Close the render window even if play() raises.
    render_env.close()
score

55.476866666118724

In [None]:
# Reference baseline: Stable-Baselines3 PPO on the same task.
# Aliased import so the local `PPO` (imported from `ppo` in the first cell) is not shadowed.
from stable_baselines3 import DQN, PPO as SB3PPO

SB3_TOTAL_TIMESTEPS = 100_000  # training budget for the baseline; tune as needed

sb3_env = gym.make('LunarLander-v2')
# SB3's PPO requires a policy spec and an environment; learn() requires a step budget.
# The original bare `PPO()` / `learn()` calls raised TypeError.
sb3_model = SB3PPO('MlpPolicy', sb3_env, verbose=0)
sb3_model.learn(total_timesteps=SB3_TOTAL_TIMESTEPS)
sb3_env.close()