# PPO

### Let's run PPO on HalfCheetah!

#### Imports

In [1]:
from optimrl.algorithms.ppo import PPOLossFunction, PPOOptimizer
from optimrl.policy import GymPolicy
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import numpy as np
from tqdm import tqdm
import gymnasium as gym
from torch.distributions.normal import Normal

#### Seeds

In [2]:
# Seed used for training rollouts, and a separate one for evaluation rollouts.
SEED = 1
EVAL_SEED = 2

# Seed every RNG source the notebook touches: stdlib, NumPy, then torch.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x10785c110>

#### Get observation and action dimensions

In [3]:
# A throwaway env instance, used only to read the observation/action sizes.
# NOTE(review): this env is never closed; close it if no later cell uses `env`.
env = gym.make('HalfCheetah-v4', render_mode="rgb_array")
# The last entry of each space's shape is taken as that space's vector size.
obs_dims, act_dims = env.observation_space.shape[-1], env.action_space.shape[-1]
print(f"Obs dims: {obs_dims}, Action dims: {act_dims}")

Obs dims: 17, Action dims: 6


### Create policy

In [4]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a linear layer in place and hand it back.

    The weight matrix gets an orthogonal initialisation scaled by ``std`` (the
    ``gain`` of ``orthogonal_``). Every bias entry is set to ``bias_const``.
    Returning the layer lets this wrap constructors inside ``nn.Sequential``.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer

class PPOContinuousPolicy(GymPolicy):
    """Actor-critic policy for continuous action spaces, used with PPO.

    The critic maps an observation to a single value estimate. The actor produces
    a tanh-bounded mean per action dimension. The standard deviation comes from a
    learned log-std parameter that does not depend on the observation. Actions
    follow a Normal distribution per dimension, and their log-probabilities and
    entropies are summed over the action dimension.
    """

    def __init__(
        self,
        obs_dims,
        act_dims
    ):
        """
        Args:
            obs_dims: size of a flat observation vector (input width of both networks).
            act_dims: number of action dimensions (output width of the actor).
        """
        super().__init__()
        self.obs_dims = obs_dims
        self.act_dims = act_dims

        # Value network: obs -> 128 -> 128 -> 1.
        self.critic = nn.Sequential(
            layer_init(nn.Linear(self.obs_dims, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 1)),
        )
        # Mean network. The final layer uses a small init gain (0.01) so initial
        # means start near zero, and the trailing Tanh bounds means to (-1, 1).
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(self.obs_dims, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, self.act_dims), std=0.01),
            nn.Tanh(),
        )
        # One log-std per action dimension, shared by all states; starts at 0 (std = 1).
        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(self.act_dims)))

    def forward(self, obs: torch.Tensor, actions=None, policy_out=None):
        """Evaluate the action distribution and the value estimate for a batch of observations.

        Args:
            obs: observation tensor. The ``.sum(1)`` reductions below require at
                least 2-D input, presumably (batch, obs_dims); TODO confirm.
            actions: optional actions to score. When ``None``, actions are sampled
                from the current distribution.
            policy_out: not used here; presumably part of the GymPolicy call
                signature. Defaults to ``None`` to avoid a shared mutable default.

        Returns:
            dict with ``actions``, ``values`` (critic output), ``log_probs`` and
            ``entropy`` (both summed over dim 1), and the ``dist`` itself.
        """
        obs = obs.float()
        values = self.critic(obs)
        action_mean = self.actor_mean(obs)
        # Broadcast the (1, act_dims) log-std to the batch shape of the means.
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        dist = Normal(action_mean, action_std)
        entropy = dist.entropy().sum(1)
        if actions is None:
            actions = dist.sample()
        actions = actions.float()
        log_probs = dist.log_prob(actions).sum(1)
        return {"actions": actions, "values": values, "log_probs": log_probs, "dist": dist, "entropy": entropy}

    def act(self, obs: "np.ndarray | torch.Tensor", actions=None, policy_out=None):
        """Pick actions for raw environment observations without tracking gradients.

        Args:
            obs: observations as produced by the env (NumPy array) or an existing tensor.
            actions: optional actions (array or tensor) to score instead of sampling.
            policy_out: forwarded unchanged to ``forward``.

        Returns:
            A pair: the actions as a NumPy array on the CPU, and the full ``forward``
            output dict.
        """
        # Env observations arrive as CPU arrays. Move them onto the device that holds
        # the parameters; otherwise ``forward`` fails once the policy is on CUDA.
        device = self.actor_logstd.device
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
            if actions is not None:
                actions = torch.as_tensor(actions, device=device)
            out = self.forward(obs_t, actions=actions, policy_out=policy_out)
            return out["actions"].detach().cpu().numpy(), out

#### Train

##### Set up the optimizer

In [5]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

policy = PPOContinuousPolicy(obs_dims, act_dims)
loss_fn = PPOLossFunction(clip_ratio=0.2, ent_coef=0.01, vf_coef=0.5)
optimizer = PPOOptimizer(
    policy,
    loss_fn,
    pi_lr=2e-4,
    n_updates=5,
    num_minibatches=32,
)

# Move both the network and the optimizer wrapper onto the chosen device.
policy = policy.to(device)
optimizer = optimizer.to(device)

##### Create vector envs

In [6]:
def make_env(env_name):
    """Create one environment by name and wrap it for training.

    The raw env (``render_mode="rgb_array"``) is wrapped first in ``ClipAction``
    and then in ``RecordEpisodeStatistics``. The outermost wrapper is returned.
    """
    base_env = gym.make(env_name, render_mode="rgb_array")
    clipped_env = gym.wrappers.ClipAction(base_env)
    return gym.wrappers.RecordEpisodeStatistics(clipped_env)

# 16 independent training copies; each thunk builds a fresh wrapped env when called.
train_envs = gym.vector.SyncVectorEnv(
    [lambda: make_env('HalfCheetah-v4') for _ in range(16)],
)
# A single copy used only for evaluation rollouts.
test_envs = gym.vector.SyncVectorEnv(
    [lambda: make_env('HalfCheetah-v4')],
)

##### Main loop

In [7]:
# Number of training iterations (one rollout and one optimizer step each).
NUM_STEPS = 20_000
# Step budget passed to each rollout call.
MAX_ENV_STEPS = 128

In [8]:
# Per-iteration ``stats["sum_rewards"]`` of the train and eval rollouts.
mean_train_rewards = []
mean_test_rewards = []
bar = tqdm(range(NUM_STEPS))

try:
    for i in bar:
        # Collect experience without building autograd graphs.
        # NOTE(review): SEED / EVAL_SEED are passed on every call. If rollout()
        # re-seeds the envs each time, every iteration starts from the same initial
        # states. Confirm this is intended.
        with torch.no_grad():
            train_rollout = optimizer.rollout(train_envs, policy, MAX_ENV_STEPS, SEED, evaluate=False)
            eval_rollout = optimizer.rollout(test_envs, policy, MAX_ENV_STEPS, EVAL_SEED, evaluate=True)
        loss, rewards, stats = optimizer.step(
            train_rollout.to_torch(
                device=optimizer.device
            )
        )
        mean_train_rewards.append(train_rollout.stats["sum_rewards"])
        mean_test_rewards.append(eval_rollout.stats["sum_rewards"])
        if i % 50 == 0:
            # Report the mean over the last 20 iterations. tqdm.write prints above the
            # bar instead of splitting it the way a plain print() does.
            tqdm.write(f"Train: {np.mean(mean_train_rewards[-20:])} Test: {np.mean(mean_test_rewards[-20:])}")
finally:
    # Close the bar even when the cell is interrupted, so later tqdm bars render cleanly.
    bar.close()

  0%|                                                               | 1/20000 [00:00<1:39:50,  3.34it/s]

Train: -15.560292669697743 Test: -13.475637007546174


  0%|▏                                                             | 51/20000 [00:13<1:27:51,  3.78it/s]

Train: -16.080341515267936 Test: -6.749094461644224


  1%|▎                                                            | 101/20000 [00:26<1:31:02,  3.64it/s]

Train: -6.286536013956112 Test: -13.297368424675323


  1%|▍                                                            | 151/20000 [00:40<1:27:05,  3.80it/s]

Train: 17.12188588134856 Test: 15.293884992322026


  1%|▌                                                            | 201/20000 [00:53<1:28:16,  3.74it/s]

Train: 31.66193736147408 Test: 44.57335925271683


  1%|▊                                                            | 251/20000 [01:06<1:28:01,  3.74it/s]

Train: 59.09524559879422 Test: 63.267981526227594


  2%|▉                                                            | 301/20000 [01:20<1:27:07,  3.77it/s]

Train: 80.53920860057913 Test: 83.72588062746762


  2%|█                                                            | 351/20000 [01:33<1:26:30,  3.79it/s]

Train: 105.90150698669012 Test: 101.48889425174976


  2%|█▏                                                           | 401/20000 [01:46<1:26:08,  3.79it/s]

Train: 129.70300324488798 Test: 139.9520531120921


  2%|█▍                                                           | 451/20000 [01:59<1:26:04,  3.79it/s]

Train: 151.3613456470404 Test: 159.96881396537844


  3%|█▌                                                           | 501/20000 [02:12<1:24:41,  3.84it/s]

Train: 170.90139629529875 Test: 184.07970351334555


  3%|█▋                                                           | 551/20000 [02:26<1:24:14,  3.85it/s]

Train: 181.16992217906565 Test: 176.17593808762498


  3%|█▊                                                           | 601/20000 [02:39<1:24:17,  3.84it/s]

Train: 193.51148027488517 Test: 187.37595691456278


  3%|█▉                                                           | 651/20000 [02:52<1:24:07,  3.83it/s]

Train: 202.93133398651773 Test: 203.12433703594655


  4%|██▏                                                          | 701/20000 [03:05<1:23:36,  3.85it/s]

Train: 224.13374121995813 Test: 223.46556959174646


  4%|██▏                                                          | 733/20000 [03:14<1:25:06,  3.77it/s]

KeyboardInterrupt

