# Recurrent PPO

### Let's run Recurrent PPO on LunarLander!

#### Imports

In [1]:
import random
from typing import Any, Dict, Optional

import gym
import numpy as np
import torch
import torch.distributions as td
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from optimrl.algorithms.ppo import PPOLossFunction, PPOOptimizer

#### Seeds

In [2]:
# SEED seeds the Python / NumPy / PyTorch RNGs below and is passed to the
# training rollouts; EVAL_SEED is passed to the evaluation rollouts instead.
SEED, EVAL_SEED = 1, 2

# Seed every RNG source the notebook draws from (order is irrelevant: each
# library keeps its own independent generator).
_ = torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

#### Get observation and action dimensions

In [3]:
env = gym.make("LunarLander-v2")
# Feature size is the last axis of the observation space; `.n` is the number
# of actions of a discrete action space.
obs_dims, act_dims = env.observation_space.shape[-1], env.action_space.n
print(f"Obs dims: {obs_dims}, Action dims: {act_dims}")

Obs dims: 8, Action dims: 4


### Create policy

In [4]:
from typing import Dict, Any

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise a layer's weights and fill its bias with a constant.

    Args:
        layer: Module exposing a ``weight`` tensor and optionally a ``bias``
            tensor (e.g. ``nn.Linear``). It is modified in place.
        std: Gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: Value written into every bias element.

    Returns:
        The same ``layer`` object, so the call can be nested inside
        ``nn.Sequential(...)`` constructors.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with bias=False have ``bias = None``; skip them instead
    # of crashing inside ``constant_``.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer

class PPORecurrentCategoricalPolicy(nn.Module):
    """GRU-based actor-critic policy over a discrete (categorical) action space.

    Each observation batch passes through a two-layer ReLU MLP, then a
    ``GRUCell`` whose hidden state is carried between calls in the
    ``policy_out`` dict under the key ``"hidden"``. The GRU output feeds a
    linear actor head (action logits) and a linear critic head (value).

    Args:
        obs_dims: Length of the flat observation vector.
        act_dims: Number of discrete actions.
        obs_hidden_dim: Width of both MLP backbone layers.
        gru_hidden_dim: Size of the GRU hidden state.
    """

    def __init__(
        self,
        obs_dims,
        act_dims,
        obs_hidden_dim: int = 256,
        gru_hidden_dim: int = 128
    ):
        super().__init__()
        self.obs_dims = obs_dims
        self.act_dims = act_dims
        self.obs_hidden_dim = obs_hidden_dim
        self.gru_hidden_dim = gru_hidden_dim

        self.backbone = nn.Sequential(
            layer_init(nn.Linear(self.obs_dims, obs_hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(obs_hidden_dim, obs_hidden_dim)),
            nn.ReLU(),
        )
        self.gru = nn.GRUCell(obs_hidden_dim, self.gru_hidden_dim)
        # GRU parameters: zero biases, orthogonal (gain 1.0) weight matrices.
        for name, param in self.gru.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                nn.init.orthogonal_(param, 1.0)

        self.critic = nn.Sequential(
            layer_init(nn.Linear(self.gru_hidden_dim, 1)),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(self.gru_hidden_dim, self.act_dims)),
        )

    def forward(
        self, obs: torch.Tensor, policy_out: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run one recurrent step and sample one action per batch row.

        Args:
            obs: Observation batch, assumed (batch, obs_dims). The zero
                hidden state is sized from ``obs.shape[0]``.
            policy_out: Previous step's output dict. Its ``"hidden"`` entry is
                the GRU state to continue from. If the key is missing, a zero
                state is created and stored back into this dict, so a
                caller-supplied dict is updated in place. ``None`` starts
                from a zero state without touching any shared object.

        Returns:
            Dict with the sampled ``actions``, critic ``values`` (batch, 1),
            their ``log_probs``, the ``dist`` and its ``logits``, the per-row
            ``entropy`` and the new GRU state ``hidden``.
        """
        # Build a fresh dict per call. The previous ``= {}`` default was one
        # dict shared by every call, so the zero state created for the first
        # batch size/device leaked into all later default calls.
        if policy_out is None:
            policy_out = {}
        if "hidden" not in policy_out:
            policy_out["hidden"] = torch.zeros(
                (obs.shape[0], self.gru_hidden_dim), device=obs.device
            )
        features = self.backbone(obs)
        hidden = self.gru(features, policy_out["hidden"]).view(obs.shape[0], -1)
        logits = self.actor(hidden)
        values = self.critic(hidden)
        dist = td.Categorical(logits=logits)
        actions = dist.sample()
        entropy = dist.entropy()
        log_probs = dist.log_prob(actions)
        return {
            "actions": actions,
            "values": values,
            "log_probs": log_probs,
            "dist": dist,
            "logits": logits,
            "entropy": entropy,
            "hidden": hidden,
        }

    def train_forward(
        self, obs: torch.Tensor, policy_out: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Unroll the GRU over a time-major rollout and flatten the outputs.

        Args:
            obs: Observations of shape (num_steps, num_envs, obs_dims),
                consumed one time step (``obs[i]``) at a time.
            policy_out: Optional dict holding the initial ``"hidden"`` state.
                ``None`` starts from zeros.

        Returns:
            Dict with ``dist`` (a categorical over num_steps * num_envs
            rows), its ``entropy``, and ``values`` of shape
            (num_steps * num_envs, 1).
        """
        num_steps, num_envs = obs.shape[0], obs.shape[1]
        logits, values = [], []
        # NOTE(review): the hidden state is threaded straight through every
        # step; nothing here resets it at episode boundaries. Confirm the
        # rollout/optimizer accounts for that.
        for step_obs in obs:
            policy_out = self.forward(obs=step_obs, policy_out=policy_out)
            logits.append(policy_out["logits"])
            values.append(policy_out["values"])
        logits = torch.stack(logits).view(num_steps * num_envs, -1)
        values = torch.stack(values).view(num_steps * num_envs, -1)
        dist = td.Categorical(logits=logits)
        return {"dist": dist, "entropy": dist.entropy(), "values": values}

    def act(self, obs: np.ndarray, policy_out: Optional[Dict[str, Any]] = None):
        """Sample actions for an observation array without tracking gradients.

        Args:
            obs: Observation batch as a NumPy array (or anything accepted by
                ``torch.as_tensor``).
            policy_out: Previous output dict carrying ``"hidden"``, or
                ``None`` to start from a zero state.

        Returns:
            Tuple of (actions as a NumPy array, full ``forward`` output dict).
        """
        # Place observations on the parameters' device. ``torch.from_numpy``
        # always produced a CPU tensor, which broke once the policy had been
        # moved to CUDA.
        device = next(self.parameters()).device
        with torch.no_grad():
            out = self.forward(torch.as_tensor(obs, device=device), policy_out)
            return out["actions"].detach().cpu().numpy(), out

#### Train

##### Set up optimizer

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Size the network from the probed environment (obs_dims / act_dims) rather
# than hard-coding LunarLander's 8 and 4, so swapping envs cannot silently
# mismatch the policy.
policy = PPORecurrentCategoricalPolicy(obs_dims, act_dims)
loss_fn = PPOLossFunction(clip_ratio=0.2, ent_coef=0.01, vf_coef=1.0)
optimizer = PPOOptimizer(policy, loss_fn, pi_lr=0.0002, n_updates=5, num_minibatches=4, recurrent=True)

# Module.to moves parameters in place, so the optimizer built above keeps
# pointing at the same parameter objects.
policy = policy.to(device)
optimizer = optimizer.to(device)

In [6]:
# obs = torch.zeros((100,4,8))
# out = policy.train_forward(obs)

##### Create vectorized environments

In [7]:
def make_env(env_name):
    """Create a gym environment wrapped to record per-episode statistics.

    Args:
        env_name: Registered gym environment id, e.g. ``"LunarLander-v2"``.

    Returns:
        The ``RecordEpisodeStatistics``-wrapped environment.
    """
    base_env = gym.make(env_name)
    return gym.wrappers.RecordEpisodeStatistics(base_env)

# Each factory call builds a new, independent env instance, so repeating the
# same factory yields 4 distinct training envs and 1 evaluation env.
train_envs = gym.vector.SyncVectorEnv(
    [(lambda: make_env("LunarLander-v2"))] * 4,
)
test_envs = gym.vector.SyncVectorEnv(
    [(lambda: make_env("LunarLander-v2"))] * 1,
)

##### Main loop

In [8]:
# Number of collect-then-update iterations run by the main loop.
NUM_STEPS = 20_000
# Passed to optimizer.rollout; presumably the per-rollout env-step budget.
MAX_ENV_STEPS = 128

In [9]:
# Per-iteration summed rollout rewards, kept for the moving averages below.
mean_train_rewards = []
mean_test_rewards = []
bar = tqdm(np.arange(NUM_STEPS))

for i in bar:
    # Collect experience without building autograd graphs.
    # NOTE(review): the same SEED / EVAL_SEED is passed on every iteration;
    # confirm optimizer.rollout only uses it for the first reset, otherwise
    # every rollout may start from an identical initial state.
    with torch.no_grad():
        train_rollout = optimizer.rollout(train_envs, policy, MAX_ENV_STEPS, SEED, evaluate=False)
        eval_rollout = optimizer.rollout(test_envs, policy, MAX_ENV_STEPS, EVAL_SEED, evaluate=True)
    # One optimisation step on the training rollout, moved to the optimizer's device.
    loss, rewards, stats = optimizer.step(
        train_rollout.to_torch(
            device=optimizer.device
        )
    )
    mean_train_rewards.append(train_rollout.stats["sum_rewards"])
    mean_test_rewards.append(eval_rollout.stats["sum_rewards"])
    if i % 50 == 0:
        # bar.write prints above the progress bar; a bare print() inside a
        # tqdm loop tears the bar into fragments, as seen in the cell output.
        bar.write(f"Train: {np.mean(mean_train_rewards[-20:])} Test: {np.mean(mean_test_rewards[-20:])}")

  if not isinstance(terminated, (bool, np.bool8)):
  0%|        | 1/20000 [00:01<5:36:35,  1.01s/it]

Train: -283.38802796338604 Test: -351.7716423374975


  0%|       | 51/20000 [00:51<5:30:49,  1.01it/s]

Train: -220.67450572555362 Test: -183.13972498388426


  1%|      | 101/20000 [01:42<5:40:51,  1.03s/it]

Train: -328.91074254282694 Test: -327.82695803002645


  1%|      | 151/20000 [02:33<5:39:12,  1.03s/it]

Train: -252.60447238373345 Test: -268.64087931362303


  1%|      | 201/20000 [03:24<5:33:56,  1.01s/it]

Train: -229.0054707994876 Test: -275.4381407863875


  1%|      | 251/20000 [04:15<5:40:05,  1.03s/it]

Train: -171.6786216463678 Test: -218.7152887048382


  2%|      | 301/20000 [05:07<5:35:38,  1.02s/it]

Train: -152.4680836772776 Test: -102.16467991276221


  2%|      | 351/20000 [05:57<5:32:06,  1.01s/it]

Train: -200.76129927960625 Test: -160.37237612115808


  2%|      | 401/20000 [06:48<5:29:04,  1.01s/it]

Train: -176.9245750354272 Test: -125.3152656422795


  2%|▏     | 451/20000 [07:39<5:28:50,  1.01s/it]

Train: -233.17718569486752 Test: -191.8921308103421


  3%|▏     | 501/20000 [08:30<5:25:57,  1.00s/it]

Train: -136.61746023884405 Test: -100.74066402507184


  3%|▏     | 551/20000 [09:20<5:26:38,  1.01s/it]

Train: -160.1869261930952 Test: -140.96230036830258


  3%|▏     | 601/20000 [10:12<5:35:10,  1.04s/it]

Train: -37.953300438924465 Test: -39.53791004235152


  3%|▏     | 651/20000 [11:03<5:37:50,  1.05s/it]

Train: 47.472094859206905 Test: 46.803826162979576


  4%|▏     | 701/20000 [11:56<5:34:54,  1.04s/it]

Train: 48.04196212242352 Test: 41.79740739833571


  4%|▏     | 751/20000 [12:51<6:00:41,  1.12s/it]

Train: 58.05068104048663 Test: 42.458287019921


  4%|▏     | 801/20000 [13:47<5:43:32,  1.07s/it]

Train: 29.283964569821602 Test: 26.705742052682847


  4%|▎     | 851/20000 [14:40<5:52:05,  1.10s/it]

Train: 72.81296918228631 Test: 65.45512327715426


  5%|▎     | 901/20000 [15:35<5:47:18,  1.09s/it]

Train: 82.5053943720081 Test: 77.58127555058391


  5%|▎     | 951/20000 [16:28<5:48:17,  1.10s/it]

Train: 68.13808334017737 Test: 57.24776575481277


  5%|▎    | 1001/20000 [17:22<5:37:02,  1.06s/it]

Train: 94.12179593303318 Test: 91.85174944557367


  5%|▎    | 1051/20000 [18:15<5:33:46,  1.06s/it]

Train: 84.18363104518971 Test: 77.35691845089896


  6%|▎    | 1101/20000 [19:08<5:37:19,  1.07s/it]

Train: 30.830862013056183 Test: 37.79910470573002


  6%|▎    | 1151/20000 [20:00<5:28:46,  1.05s/it]

Train: 80.11213663613907 Test: 75.11123038778608


  6%|▎    | 1201/20000 [20:52<5:29:00,  1.05s/it]

Train: 80.25686902855259 Test: 72.32302746378653


  6%|▎    | 1251/20000 [21:45<5:25:01,  1.04s/it]

Train: 96.17489524653845 Test: 92.74830054998141


  7%|▎    | 1301/20000 [22:37<5:27:27,  1.05s/it]

Train: 102.65010306008408 Test: 99.47390418026507


  7%|▎    | 1351/20000 [23:30<5:28:15,  1.06s/it]

Train: 101.36607195906683 Test: 95.08595013763059


  7%|▎    | 1401/20000 [24:22<5:18:42,  1.03s/it]

Train: 103.15119740317922 Test: 108.8120442451584


  7%|▎    | 1451/20000 [25:16<5:55:14,  1.15s/it]

Train: 91.12271898815295 Test: 77.64129720953416


  8%|▍    | 1501/20000 [26:11<5:38:45,  1.10s/it]

Train: 64.06061968611424 Test: 66.94758163273802


  8%|▍    | 1551/20000 [27:03<5:27:36,  1.07s/it]

Train: 97.1541736469885 Test: 94.56967380660693


  8%|▍    | 1601/20000 [27:55<5:25:48,  1.06s/it]

Train: 106.68964699412534 Test: 99.8114868264649


  8%|▍    | 1651/20000 [28:48<5:35:28,  1.10s/it]

Train: 105.41449841450057 Test: 99.59567907052995


  9%|▍    | 1701/20000 [29:42<5:33:27,  1.09s/it]

Train: 88.57169737394683 Test: 97.37401649312261


  9%|▍    | 1751/20000 [30:36<5:37:31,  1.11s/it]

Train: 110.41773025592654 Test: 102.14071884415799


  9%|▍    | 1752/20000 [30:38<5:19:07,  1.05s/it]

KeyboardInterrupt

