In [1]:
from typing import Any, Dict  # noqa

import gym
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from ezrl.algorithms.ppo import PPOOptimizer
from ezrl.policy import ACPolicy


def ppo_rollout(
    policy: ACPolicy, env_name: str = None, env=None, env_creation_fn=None
) -> Dict[str, np.array]:
    """Run a single episode with ``policy`` and return the collected trajectory.

    The environment is driven through ``reset()`` / ``step(action)`` until
    ``step`` reports ``done``. ``step`` is expected to return the 4-tuple
    ``(observation, reward, done, info)``.

    Args:
        policy: Policy exposing ``device`` and ``act(obs)``; ``act`` must return
            ``(action, out)`` where ``out`` holds ``"log_probs"`` and
            ``"values"`` tensors.
        env_name: Environment id handed to ``env_creation_fn`` when ``env`` is
            not supplied.
        env: Ready-made environment. Takes precedence over ``env_name``.
        env_creation_fn: Factory called with ``env_name``; defaults to
            ``gym.make``.

    Returns:
        Dict with ``"observations"``, ``"actions"``, ``"rewards"``,
        ``"log_probs"`` and ``"values"`` arrays, one entry per step.

    Raises:
        ValueError: If neither ``env_name`` nor ``env`` is provided.

    Note:
        The environment is always closed before returning -- including an
        ``env`` passed in by the caller, and also when the episode aborts with
        an exception.
    """
    if env_name is None and env is None:
        raise ValueError("env_name or env must be provided!")
    if env is None:
        if env_creation_fn is None:
            env_creation_fn = gym.make
        env = env_creation_fn(env_name)

    observations, actions, rewards, log_probs, values = ([], [], [], [], [])
    try:
        done = False
        observation = env.reset()
        # Rollouts only collect data; no autograd graph is needed.
        with torch.no_grad():
            while not done:
                action, out = policy.act(
                    torch.from_numpy(observation).to(policy.device)
                )
                next_observation, reward, done, info = env.step(action)

                # Store the observation the action was chosen from, not the
                # one produced by the step.
                observations.append(observation)
                actions.append(action)
                rewards.append(reward)
                log_probs.append(out["log_probs"].detach().cpu().numpy())
                values.append(out["values"].detach().cpu().numpy())

                observation = next_observation
    finally:
        # Release the environment even if the policy or the env raised.
        env.close()

    return {
        "observations": np.array(observations),
        "actions": np.array(actions),
        "rewards": np.array(rewards),
        "log_probs": np.array(log_probs),
        "values": np.array(values),
    }

In [2]:
import torch.nn as nn
import torch.distributions as td

class LunarLanderACPolicy(ACPolicy):
    """Actor-critic policy built from two independent tanh MLPs.

    The actor maps an observation to one logit per discrete action; the critic
    maps the same observation to a single value estimate. The default sizes
    (8 inputs, 4 actions, 32 hidden units) reproduce the original hard-coded
    architecture.

    Args:
        input_dims: Length of the observation vector.
        output_dims: Number of discrete actions.
        hidden_dims: Width of each of the two hidden layers.
    """

    def __init__(self, input_dims: int = 8, output_dims: int = 4, hidden_dims: int = 32):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims

        # Actor: observation -> action logits.
        self.policy_net = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.Tanh(),
            nn.Linear(hidden_dims, hidden_dims),
            nn.Tanh(),
            nn.Linear(hidden_dims, output_dims, bias=False)
        )

        # Critic: observation -> scalar value estimate.
        self.critic_net = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.Tanh(),
            nn.Linear(hidden_dims, hidden_dims),
            nn.Tanh(),
            nn.Linear(hidden_dims, 1, bias=False)
        )

    def forward(self, obs: Any) -> Dict[str, Any]:
        """Sample an action for ``obs`` and evaluate the critic.

        Returns:
            Dict with the sampled ``"action"``, the categorical ``"dist"``, the
            ``"log_probs"`` of the sampled action and the critic ``"values"``.
        """
        logits = self.policy_net(obs)
        dist = td.Categorical(logits=logits)
        action = dist.sample()
        log_probs = dist.log_prob(action)
        return {
            "action": action,
            "dist": dist,
            "log_probs": log_probs,
            # Rollout code reads out["values"]; without this key it raises
            # KeyError on the first step.
            "values": self.critic(obs),
        }

    def critic(self, obs: Any):
        """Return the value estimate for ``obs`` with the trailing unit axis removed."""
        # squeeze(-1) drops only the size-1 output axis; a bare squeeze() would
        # also collapse the batch axis when the batch holds a single sample.
        return self.critic_net(obs).squeeze(-1)

    def act(self, obs: Any):
        """Sample an action for a single observation.

        Returns:
            ``(action, out)`` where ``action`` is a Python scalar obtained via
            ``.item()`` (so ``obs`` must yield exactly one action) and ``out``
            is the full ``forward`` output dict.
        """
        out = self.forward(obs)
        return out["action"].item(), out

In [3]:
# Instantiate the policy and place it on the GPU when one is available;
# fall back to the CPU so the notebook also runs on CPU-only machines.
policy = LunarLanderACPolicy()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy = policy.to(device)

In [4]:
from ezrl.utils import get_tensorboard_logger

In [5]:
# Build the PPO optimizer around the policy. The training cell below drives it
# through `rollout`, `zero_grad`, `loss_fn` and `step`.
optimizer = PPOOptimizer(policy)

In [None]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

NUM_ITERATIONS = 50000  # rollout/update iterations to run
EPOCHS_PER_ROLLOUT = 5  # gradient steps taken on each collected rollout
MAX_GRAD_NORM = 10.0  # gradient-norm clipping threshold

# The loop logs through `writer`, which was never created. A plain
# SummaryWriter provides the add_scalar(tag, value, step) call used below.
# NOTE(review): `get_tensorboard_logger` is imported earlier in this notebook
# and may be the intended factory -- its signature is not visible here.
writer = SummaryWriter()

bar = tqdm(range(NUM_ITERATIONS))

for iteration in bar:
    rollouts = optimizer.rollout(ppo_rollout, env_name="LunarLander-v2")
    # NOTE(review): the return type of optimizer.rollout is not visible here;
    # accept either one rollout dict or an iterable of rollout dicts.
    rollout_batches = [rollouts] if isinstance(rollouts, dict) else rollouts

    losses = []
    episode_rewards = []
    grad_dict = {}

    for rollout in rollout_batches:
        observations = torch.from_numpy(rollout["observations"]).to(policy.device)
        actions = torch.from_numpy(rollout["actions"]).to(policy.device)
        returns = torch.from_numpy(rollout["returns"]).to(policy.device)
        advantages = torch.from_numpy(rollout["advantages"]).to(policy.device)
        log_probs = torch.from_numpy(rollout["log_probs"]).to(policy.device)

        # assumes the per-step "rewards" emitted by ppo_rollout are still
        # present in what optimizer.rollout returns -- TODO confirm
        episode_rewards.append(float(np.sum(rollout["rewards"])))

        # Use a throwaway loop variable so the outer iteration index, which is
        # the TensorBoard step, is not clobbered.
        for _ in range(EPOCHS_PER_ROLLOUT):
            optimizer.zero_grad()
            loss = optimizer.loss_fn(observations, actions, returns, advantages, log_probs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            losses.append(float(loss.item()))

        # Gradients from the last backward pass are still attached to the
        # parameters, so summarising them once per rollout is enough.
        grad_dict = {
            "{}_grad".format(name): float(torch.sum(param.grad).item())
            for name, param in policy.named_parameters()
            if param.grad is not None
        }

    avg_loss = float(np.mean(losses)) if losses else float("nan")
    avg_reward = float(np.mean(episode_rewards)) if episode_rewards else float("nan")

    metrics_dict = {"loss": avg_loss, "sum_reward": avg_reward, **grad_dict}

    for key, value in metrics_dict.items():
        writer.add_scalar(key, value, iteration)

    bar.set_description("Loss: {}, Reward: {}".format(avg_loss, avg_reward))

# Flush pending events to disk once training finishes.
writer.close()