In [None]:
from typing import Any, Dict  # noqa

import gym
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import pybulletgym

from ezrl.algorithms.ppo import PPOOptimizer
from ezrl.policy import ACPolicy

In [None]:
import torch.nn as nn
import torch.distributions as td

class AntACPolicy(ACPolicy):
    """Gaussian actor-critic policy built from small tanh MLPs.

    The actor outputs a tanh-squashed mean for a diagonal Normal whose log
    standard deviation is a learned, state-independent parameter vector; the
    critic outputs one value estimate per observation.

    Args:
        input_dims: observation size fed to both networks (default 28).
        action_dims: number of action dimensions (default 8).
        policy_hidden: width of the actor's two hidden layers (default 64).
        critic_hidden: width of the critic's two hidden layers (default 32).
        init_log_std: initial value of every entry of ``log_std`` (default -0.5).

    With the defaults the actor is 28 -> 64 -> 64 -> 8 and the critic is
    28 -> 32 -> 32 -> 1 (presumably the Ant observation/action sizes).
    """

    def __init__(
        self,
        input_dims: int = 28,
        action_dims: int = 8,
        policy_hidden: int = 64,
        critic_hidden: int = 32,
        init_log_std: float = -0.5,
    ):
        super().__init__()
        self.input_dims = input_dims
        self.action_dims = action_dims

        # Actor: observation -> pre-tanh action mean (output layer has no bias).
        self.policy_net = nn.Sequential(
            nn.Linear(input_dims, policy_hidden),
            nn.Tanh(),
            nn.Linear(policy_hidden, policy_hidden),
            nn.Tanh(),
            nn.Linear(policy_hidden, action_dims, bias=False),
        )

        # Critic: observation -> single state-value output.
        self.critic_net = nn.Sequential(
            nn.Linear(input_dims, critic_hidden),
            nn.Tanh(),
            nn.Linear(critic_hidden, critic_hidden),
            nn.Tanh(),
            nn.Linear(critic_hidden, 1, bias=False),
        )

        # One learned log-std per action dimension, shared across all states.
        # np.full with an explicit dtype keeps the parameter float32 regardless
        # of the scalar type passed as init_log_std.
        log_std = np.full(action_dims, init_log_std, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))

    def log_prob(self, dist: td.Distribution, actions: torch.Tensor) -> torch.Tensor:
        """Log-probability of ``actions`` under ``dist``.

        Categorical distributions return their per-sample log-prob directly;
        for any other distribution the per-dimension log-densities are summed
        over the last axis (independent action dimensions).
        """
        if isinstance(dist, td.Categorical):
            return dist.log_prob(actions)
        return dist.log_prob(actions).sum(dim=-1)

    def forward(self, obs: Any) -> Dict[str, Any]:
        """Sample an action for ``obs``.

        Returns:
            Dict with ``"action"`` (a Normal sample clamped to [-1, 1]),
            ``"dist"`` (the Normal itself) and ``"log_probs"`` (from
            ``log_prob`` evaluated at the clamped action).
        """
        mu = torch.tanh(self.policy_net(obs))
        std = torch.exp(self.log_std)
        dist = td.Normal(mu, std)
        # NOTE(review): the log-density is taken at the *clamped* action, not the
        # raw sample; the two differ whenever clamping actually changes the value.
        action = torch.clamp(dist.sample(), -1.0, 1.0)
        log_probs = self.log_prob(dist, action)
        return {"action": action, "dist": dist, "log_probs": log_probs}

    def critic(self, obs: Any) -> torch.Tensor:
        """Value estimate(s) for ``obs`` with the critic's trailing size-1 dim removed."""
        # squeeze(-1) drops only the output dim; a bare squeeze() would also
        # collapse a batch of one observation into a 0-d tensor.
        return self.critic_net(obs).squeeze(-1)

    def act(self, obs: Any):
        """Run ``forward`` and return ``(numpy action with size-1 dims squeezed, forward output dict)``."""
        out = self.forward(obs)
        return np.squeeze(out["action"].detach().cpu().numpy()), out

In [None]:
def ppo_rollout(
    policy: "ACPolicy", env_name: str = None, env=None, env_creation_fn=None
) -> Dict[str, np.ndarray]:
    """Run one episode with ``policy`` and collect per-step data for PPO.

    Args:
        policy: actor-critic policy exposing ``device``, ``act(obs)`` returning
            ``(action, out)`` with ``out["log_probs"]``, and ``critic(obs)``.
        env_name: environment id; used only when ``env`` is not given.
        env: existing environment to reuse. It is NOT closed here.
        env_creation_fn: factory called as ``env_creation_fn(env_name)``;
            defaults to ``gym.make``. An environment created this way is closed
            before returning (also when an exception is raised).

    Returns:
        Dict of numpy arrays stacked over time steps, under the keys
        ``"observations"``, ``"actions"``, ``"rewards"``, ``"log_probs"`` and
        ``"values"``.

    Raises:
        ValueError: if neither ``env_name`` nor ``env`` is provided.
    """
    if env_name is None and env is None:
        raise ValueError("env_name or env must be provided!")

    # Only an environment created here is ours to close; a caller-supplied one
    # is typically reused across rollouts.
    owns_env = env is None
    if owns_env:
        if env_creation_fn is None:
            env_creation_fn = gym.make
        env = env_creation_fn(env_name)

    observations, actions, rewards, log_probs, values = ([], [], [], [], [])
    try:
        observation = env.reset()
        done = False
        with torch.no_grad():
            while not done:
                # The networks are float32; environments may emit float64
                # observations, which would otherwise raise a dtype mismatch.
                obs = torch.as_tensor(
                    observation, dtype=torch.float32, device=policy.device
                )
                action, out = policy.act(obs)
                v = policy.critic(obs)
                next_observation, reward, done, _info = env.step(action)

                observations.append(observation)
                actions.append(action)
                rewards.append(reward)
                log_probs.append(out["log_probs"].detach().cpu().numpy())
                values.append(v.detach().cpu().numpy())

                observation = next_observation
    finally:
        if owns_env:
            env.close()

    return {
        "observations": np.array(observations),
        "actions": np.array(actions),
        "rewards": np.array(rewards),
        "log_probs": np.array(log_probs),
        "values": np.array(values),
    }

In [None]:
# Seed torch and numpy before building the policy so weight initialisation and
# action sampling are reproducible. NOTE(review): the environment itself is not
# seeded here.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to CPU so the
# notebook still runs end to end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy = AntACPolicy().to(device)

In [None]:
from ezrl.utils import get_tensorboard_logger

In [None]:
# PPO optimizer wrapping `policy`. The training loop uses its `rollout`,
# `loss_fn`, `zero_grad` and `step` methods. Presumably it builds a torch
# optimizer over `policy.parameters()` internally — confirm in ezrl.algorithms.ppo.
optimizer = PPOOptimizer(policy)

In [None]:
# Metrics logger (presumably a TensorBoard SummaryWriter, going by the helper's
# name); the training loop calls `writer.add_scalar(key, value, i)` on it.
# NOTE(review): "PPOOptimizer" looks like a run / log-dir name — confirm in ezrl.utils.
writer = get_tensorboard_logger("PPOOptimizer")

In [None]:
# Single environment instance, passed as `env=` to `optimizer.rollout` on every
# training iteration. The id is presumably registered by the otherwise-unused
# `pybulletgym` import in the first cell — NOTE(review): confirm.
env = gym.make("AntPyBulletEnv-v0")

In [None]:
from tqdm.auto import tqdm

# Training hyper-parameters.
N_UPDATES = 50000  # outer iterations; each collects one batch of rollouts
EPOCHS_PER_EPISODE = 5  # gradient steps taken on every collected episode
MAX_GRAD_NORM = 100.0  # threshold for global gradient-norm clipping


def _to_tensor(array: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert one rollout array to a float32 tensor on ``device``."""
    return torch.from_numpy(array).float().to(device)


def _grad_sums(model: torch.nn.Module) -> Dict[str, float]:
    """Sum of each parameter's current gradient, keyed ``"<param name>_grad"``.

    Parameters whose ``.grad`` is ``None`` are skipped.
    """
    return {
        "{}_grad".format(name): torch.sum(param.grad).item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }


bar = tqdm(range(N_UPDATES))

for i in bar:
    # NOTE(review): each rollout dict is expected to also carry "returns" and
    # "advantages", which ppo_rollout does not produce — presumably added by
    # optimizer.rollout; confirm.
    rollouts = optimizer.rollout(ppo_rollout, env=env)

    losses = []
    actor_losses = []
    value_losses = []
    rewards = []
    grad_dict = {}  # reset every iteration so stale gradients are never logged

    for r in rollouts:
        observations = _to_tensor(r["observations"], policy.device)
        actions = _to_tensor(r["actions"], policy.device)
        returns = _to_tensor(r["returns"], policy.device)
        advantages = _to_tensor(r["advantages"], policy.device)
        log_probs = _to_tensor(r["log_probs"], policy.device)

        for _ in range(EPOCHS_PER_EPISODE):
            optimizer.zero_grad()
            loss, actor_loss, value_loss = optimizer.loss_fn(
                observations, actions, returns, advantages, log_probs
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), MAX_GRAD_NORM)
            optimizer.step()

        # Clipped gradients from the final step on this episode (computing this
        # once per episode yields the same value the per-step loop ended with).
        grad_dict = _grad_sums(policy)

        rewards.append(np.sum(r["rewards"]))
        # Losses from the final gradient step on this episode.
        losses.append(loss.item())
        actor_losses.append(actor_loss.item())
        value_losses.append(value_loss.item())

    if not losses:
        # No episodes were returned: skip logging instead of crashing on
        # undefined names or writing NaN means.
        continue

    metrics_dict = {
        "loss": np.mean(losses),
        "actor_loss": np.mean(actor_losses),
        "value_loss": np.mean(value_losses),
        "sum_reward": np.mean(rewards),
        **grad_dict,
    }

    for key, value in metrics_dict.items():
        writer.add_scalar(key, value, i)

    # Show plain numbers; formatting the `loss` tensor printed its full repr.
    bar.set_description(
        "Loss: {:.4f}, Reward: {:.2f}".format(
            metrics_dict["loss"], metrics_dict["sum_reward"]
        )
    )