In [1]:
from __future__ import annotations
import numpy as np
from typing import List, Dict, Any
from dataclasses import dataclass
import torch.nn as nn
import gym
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
from tqdm import tqdm
import random

In [2]:
seed = 1

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x10ba54970>

In [3]:
@dataclass
class EnvStep:
    observation: Any
    action: Any
    reward: Any
    done: Any
    info: Dict[str, Any]
    policy_output: Dict[str, Any]
    step: int

@dataclass
class Rollout:
    steps: List[EnvStep]
    last_obs: Any
    last_done: Any
    num_envs: int
    _stats: Any = None
    episodic_return: Any = None

    @property
    def stats(self):
        if self._stats is None:
            sum_rewards = np.mean(np.sum(np.array([s.reward for s in self.steps]), 0))
            self._stats = {
                "sum_rewards":sum_rewards
            }
        return self._stats

    def __len__(self) -> int:
        return len(self.steps)

    @classmethod
    def rollout(cls, envs, policy, num_steps: int, seed: int, evaluate: bool = False) -> Rollout:
        num_envs = envs.num_envs
        obs, infos = envs.reset(seed=seed)
        done = np.zeros((num_envs))
        env_steps = []
        policy_out = None
        episodic_return = None
        for step in range(num_steps):
            with torch.no_grad():
                action, policy_out = policy.act(obs, policy_out)
            next_obs, reward, terminations, truncations, infos = envs.step(action)
            done = np.logical_or(terminations, truncations)
            env_step = EnvStep(
                observation=obs,
                action=action,
                reward=reward,
                done=done,
                info=infos,
                policy_output=policy_out,
                step=step
            )
            obs = next_obs
            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        episodic_return = info['episode']['r']
            env_steps.append(env_step)
            if evaluate:
                if done:
                    break
        return cls(
            steps=env_steps,
            last_obs=next_obs,
            last_done=done,
            num_envs=num_envs,
            episodic_return=episodic_return
        )


In [4]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class PPOCategoricalPolicy(nn.Module):
    def __init__(
        self,
        obs_dims,
        act_dims
    ):
        super().__init__()
        self.obs_dims = obs_dims
        self.act_dims = act_dims

        self.critic = nn.Sequential(
            layer_init(nn.Linear(self.obs_dims, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 1)),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(self.obs_dims, 128)),
            nn.ReLU(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, self.act_dims)),
        )

    def forward(self, obs: torch.Tensor, actions = None):
        logits = self.actor(obs)
        values = self.critic(obs)
        dist = td.Categorical(logits=logits)
        if actions is None:
            actions = dist.sample()
        entropy = dist.entropy()
        log_probs = dist.log_prob(actions)
        return {"actions":actions, "values":values, "log_probs":log_probs, "dist":dist, "logits":logits, "entropy":entropy}

    def act(self, obs: torch.Tensor, prev_output = {}):
        with torch.no_grad():
            out = self.forward(torch.from_numpy(obs))
            return out["actions"].detach().cpu().numpy(), out


In [5]:
def make_env(env_name):
    env =  gym.make(env_name)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    return env

In [6]:
def convert_to_torch(rollout, device):
    steps = rollout.steps
    num_envs = rollout.num_envs
    num_steps = len(rollout)
    # num_envs x ...
    obs_shape = steps[0].observation.shape[1:]

    if len(steps[0].action.shape) <= 1:
        action_shape = ()
    else:
        action_shape = steps[0].action.shape[1:]

    obs = torch.zeros((num_steps, num_envs) + obs_shape).to(device)
    actions = torch.zeros((num_steps, num_envs) + action_shape).to(device)
    rewards = torch.zeros((num_steps, num_envs)).to(device)
    dones = torch.zeros((num_steps, num_envs)).to(device)

    infos = []
    policy_outs = {}

    for i, step in enumerate(steps):
        obs[i, :] = torch.from_numpy(step.observation).to(device)
        actions[i] = torch.Tensor(step.action).to(device)
        rewards[i] = torch.from_numpy(step.reward).to(device)
        dones[i] = torch.from_numpy(step.done).to(device)
        infos.append(step.info)

        for k in step.policy_output:
            if k not in policy_outs:
                policy_outs[k] = []
            policy_outs[k].append(step.policy_output[k])

    for k in policy_outs:
        if isinstance(policy_outs[k], torch.Tensor):
            policy_outs[k] = torch.stack(policy_outs[k]).to(device)
    
    last_obs = torch.Tensor(rollout.last_obs).to(device)
    last_done = torch.Tensor(rollout.last_done).to(device)

    return {
        "obs":obs,
        "actions":actions,
        "rewards":rewards,
        "dones":dones,
        "infos":infos,
        "policy_outs":policy_outs,
        "last_obs":last_obs,
        "last_done":last_done
    }

In [7]:
def get_returns_advantages(
    rewards: torch.Tensor,
    values: torch.Tensor,
    dones: torch.Tensor,
    gamma: float = 0.99,
    normalize_returns: bool = False,
    normalize_advantages: bool = True
):
    
    with torch.no_grad():
        returns = torch.zeros_like(rewards)
        num_steps = returns.shape[0]

        for t in reversed(range(num_steps)):
            if t == num_steps - 1:
                R = torch.zeros_like(rewards[t])
            else:
                R = returns[t+1]
            returns[t] = rewards[t] + (1.0-dones[t])*R * gamma

        if normalize_returns:
            # normalize over num_steps
            returns = (returns - returns.mean(dim=0, keepdim=True)) / returns.std(dim=0, keepdim=True)

        advantages = returns - values.detach()
        if normalize_advantages:
            advantages = (advantages - advantages.mean(dim=0, keepdim=True)) / advantages.std(dim=0, keepdim=True)
        return returns, advantages

class PPOLossFunction:
    def __init__(
        self,
        vf_coef: float = 0.5,
        ent_coef: float = 0.001,
        clip_ratio: float = 0.2,
        clip_vloss: bool = True,
    ):
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.clip_ratio = clip_ratio
        self.clip_vloss = clip_vloss

    def __call__(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        values: Optional[torch.Tensor],
        returns: Optional[torch.Tensor],
        advantages: Optional[torch.Tensor],
        old_values: Optional[torch.Tensor] = None,
        entropy: Optional[torch.Tensor] = None,
    ):
        logratio = log_probs - old_log_probs
        ratio = logratio.exp()

        # pgloss
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio)
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # value loss
        if self.clip_vloss and old_values is not None:
            v_loss_unclipped = (values - returns) ** 2
            v_clipped = old_values.detach() + torch.clamp(
                values - old_values.detach(),
                -self.clip_ratio,
                self.clip_ratio,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((values - returns) ** 2).mean()

        if entropy is not None:
            entropy_loss = entropy.mean()
        else:
            entropy = 0.0
        loss = pg_loss - self.ent_coef * entropy_loss + self.vf_coef * v_loss
        return loss, {"pg_loss":pg_loss.item(), "entropy_loss":entropy_loss.item(), "v_loss":v_loss.item()}

In [8]:
device = torch.device("cpu")
envs = gym.vector.SyncVectorEnv(
    [lambda: make_env("LunarLander-v2") for i in range(1)],
)
policy = PPOCategoricalPolicy(8, 4)
rollout = Rollout.rollout(envs, policy, 256, 1, evaluate=True)
print(len(rollout))

116


  if not isinstance(terminated, (bool, np.bool8)):


In [9]:
rollouts = convert_to_torch(rollout, device)
rewards = rollouts["rewards"].view(len(rollout), 1)
np_rewards = rewards.detach().cpu().numpy()

dones = rollouts["dones"]
old_values = torch.stack(rollouts["policy_outs"]["values"]).detach().view(len(rollout), 1)

old_log_probs = torch.stack(rollouts["policy_outs"]["log_probs"]).detach()
observations = rollouts["obs"]
actions = rollouts["actions"]

In [10]:
old_values.shape

torch.Size([116, 1])

In [11]:
def calculate_returns_original(rewards, discount_factor, normalize = False):
    
    returns = []
    R = 0
    
    for r in reversed(rewards):
        R = r + R * discount_factor
        returns.insert(0, R)
        
    returns = torch.stack(returns).squeeze()

    if normalize:
        returns = (returns - returns.mean(dim=0, keepdim=True)) / returns.std(dim=0, keepdim=True)

    return returns

def calculate_advantages_original(returns, values, normalize = True):

    advantages = returns - values
    
    if normalize:
        
        advantages = (advantages - advantages.mean(dim=0, keepdim=True)) / advantages.std(dim=0, keepdim=True)

    return advantages

In [12]:
returns, advantages = get_returns_advantages(
    rewards=rewards,
    values=old_values,
    dones=dones,
    gamma=0.99,
)
original_returns = calculate_returns_original(rewards.squeeze(), 0.99, False).squeeze()
original_adv = calculate_advantages_original(original_returns, old_values.squeeze()).squeeze()

In [13]:
print(torch.sum((returns.squeeze()-original_returns)))
print(torch.sum((advantages.squeeze()-original_adv)))

tensor(0.)
tensor(0.)


### Optimizer

In [14]:
convert = convert_to_torch(rollout, torch.device("cpu"))

In [15]:
old_log_probs = torch.stack(convert["policy_outs"]["log_probs"]).detach()

In [16]:
class PPOOptimizer:
    def __init__(
        self,
        policy,
        loss_fn,
        num_minibatches: int = 4,
        pi_lr: float = 0.0002,
        n_updates: int = 4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        max_grad_norm: float = 0.5,
        norm_returns: bool = False,
        norm_advantages: bool = True
    ):
        self.policy = policy
        self.loss_fn = loss_fn
        self.num_minibatches = num_minibatches
        self.pi_lr = pi_lr
        self.n_updates = n_updates
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.max_grad_norm = max_grad_norm
        self.norm_returns = norm_returns
        self.norm_advantages = norm_advantages

        # setup optimizer
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.pi_lr, eps=1e-5)

    def step(
        self,
        rollouts,
        device,
    ):
        num_envs = rollouts.num_envs
        num_steps = len(rollouts)
        batch_size = num_envs * num_steps
        minibatch_size = int(batch_size // self.num_minibatches)

        rollouts = convert_to_torch(rollouts, device)
        rewards = rollouts["rewards"].view(num_steps, num_envs).detach()
        np_rewards = rewards.detach().cpu().numpy()

        dones = rollouts["dones"].view(num_steps, num_envs).detach()
        old_values = torch.stack(rollouts["policy_outs"]["values"]).detach().view(num_steps, num_envs)

        old_log_probs = torch.stack(rollouts["policy_outs"]["log_probs"]).detach()
        observations = rollouts["obs"].detach()
        actions = rollouts["actions"].detach()

        with torch.no_grad():
            returns, advantages = get_returns_advantages(
                rewards=rewards,
                values=old_values,
                dones=dones,
                gamma=self.gamma,
                normalize_returns=self.norm_returns,
                normalize_advantages=self.norm_advantages
            )
            # flatten stuff

        rewards = rewards.view(-1)
        dones = dones.view(-1)
        old_values = old_values.view(-1)

        observations = torch.flatten(observations, 0, 1)
        old_log_probs = torch.flatten(old_log_probs, 0,1)
        actions = torch.flatten(actions, 0,1)

        returns = returns.view(batch_size,)
        advantages = advantages.view(batch_size,)


        b_inds = np.arange(batch_size)
        for _ in range(self.n_updates):
            np.random.shuffle(b_inds)
            for start in range(0, batch_size, minibatch_size):
                end = start + minibatch_size
                mb_inds = b_inds[start:end]

                self.optimizer.zero_grad()
                out = self.policy(observations[mb_inds], actions[mb_inds])
                log_probs = out["log_probs"].view(minibatch_size, -1).squeeze(dim=-1)
                entropy = out["entropy"].view(minibatch_size,)
                values = out["values"].view(minibatch_size,)
                loss, stats = self.loss_fn(
                    log_probs=log_probs,
                    old_log_probs=old_log_probs[mb_inds],
                    values=values,
                    old_values=old_values[mb_inds],
                    returns=returns[mb_inds],
                    advantages=advantages[mb_inds],
                    entropy=entropy
                )
                loss.backward()
                # nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()
        return loss.item(), np_rewards, stats


In [17]:
policy(torch.zeros((4,8)))

{'actions': tensor([1, 0, 1, 1]),
 'values': tensor([[0.],
         [0.],
         [0.],
         [0.]], grad_fn=<AddmmBackward0>),
 'log_probs': tensor([-1.3863, -1.3863, -1.3863, -1.3863], grad_fn=<SqueezeBackward1>),
 'dist': Categorical(probs: torch.Size([4, 4]), logits: torch.Size([4, 4])),
 'logits': tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]], grad_fn=<AddmmBackward0>),
 'entropy': tensor([1.3863, 1.3863, 1.3863, 1.3863], grad_fn=<NegBackward0>)}

In [18]:
# policy = PPOCategoricalPolicy(8, 4)
# loss_fn = PPOLossFunction(clip_ratio=0.2, ent_coef=0.001, vf_coef=0.5)
# optimizer = PPOOptimizer(policy, loss_fn, pi_lr=2.5e-4, n_updates=10)

In [19]:
# envs = gym.vector.SyncVectorEnv(
#     [lambda: gym.make("LunarLander-v2") for i in range(1)],
# )
# rollout = Rollout.rollout(envs, policy, 256, 1, evaluate=False)
# print(len(rollout))

In [20]:
# device = torch.device("cpu")
# convert = convert_to_torch(rollout, device)
# loss, rewards, stats = optimizer.step(rollout, device)

In [None]:
train_envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("LunarLander-v2") for i in range(4)],
)
test_envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("LunarLander-v2") for i in range(1)],
)

bar = tqdm(np.arange(20000))
device = torch.device("cpu")

policy = PPOCategoricalPolicy(8, 4)
loss_fn = PPOLossFunction(clip_ratio=0.2, ent_coef=0.01, vf_coef=1.0)
optimizer = PPOOptimizer(policy, loss_fn, pi_lr=0.0002, n_updates=5, num_minibatches=1)

global_steps = 0
mean_train_rewards = []
mean_test_rewards = []
for i in bar:
    with torch.no_grad():
        rollout = Rollout.rollout(envs, policy, 256, seed, evaluate=False)
        eval_rollout = Rollout.rollout(test_envs, policy, 256, seed+1, evaluate=True)
    global_steps += rollout.num_envs*len(rollout)
    loss, rewards, stats = optimizer.step(rollout, device)
    mean_train_rewards.append(rollout.stats["sum_rewards"])
    mean_test_rewards.append(eval_rollout.stats["sum_rewards"])
    if i % 50 == 0:
        print(f"Train: {np.mean(mean_train_rewards[-20:])} Test: {np.mean(mean_test_rewards[-20:])}")
    # bar.set_description(f"Loss: {loss}, Sum reward: {np.mean(rewards[-20:])}")
    # bar.set_description(f"{global_steps} -- L: {loss} S: {np.mean(mean_rewards[-25:])} M: {max_rew}")

  if not isinstance(terminated, (bool, np.bool8)):
  0%|            | 3/20000 [00:00<1:20:51,  4.12it/s]

Train: -380.43291922377927 Test: -105.69929568387434


  0%|             | 52/20000 [00:05<31:01, 10.72it/s]

Train: -213.01842134329863 Test: -128.81345930204162


  1%|            | 102/20000 [00:10<36:48,  9.01it/s]

Train: -70.06668096896189 Test: -59.95181995586487


  1%|            | 152/20000 [00:17<41:09,  8.04it/s]

Train: -9.62809992148549 Test: 9.640140132074452


  1%|            | 202/20000 [00:24<56:19,  5.86it/s]

Train: 45.7800647398678 Test: 7.545780652048009


  1%|▏           | 253/20000 [00:32<39:49,  8.26it/s]

Train: 52.427269298417684 Test: 68.43331958323128


  2%|▏           | 302/20000 [00:38<36:32,  8.98it/s]

Train: 57.463573048848296 Test: 37.84676170400623


  2%|▏           | 352/20000 [00:44<53:51,  6.08it/s]

Train: 54.54485660174275 Test: 39.61533167077871


  2%|▏           | 402/20000 [00:52<52:18,  6.24it/s]

Train: 46.89263871209961 Test: 67.89653184160812


  2%|▏         | 452/20000 [01:01<1:03:46,  5.11it/s]

Train: 82.19050620095227 Test: 63.84792730128164


  3%|▎           | 502/20000 [01:11<57:04,  5.69it/s]

Train: 56.744948740242386 Test: 41.634315454950745


  3%|▎           | 552/20000 [01:19<56:29,  5.74it/s]

Train: 91.730189605878 Test: 68.17809923873928


  3%|▎           | 602/20000 [01:28<54:08,  5.97it/s]

Train: 68.55919064427371 Test: 92.29942112923788


  3%|▍           | 652/20000 [01:36<43:26,  7.42it/s]

Train: 113.20790381155487 Test: 85.23925180384803


  4%|▍           | 702/20000 [01:43<49:24,  6.51it/s]

Train: 116.77999487001657 Test: 61.963898917872484


  4%|▍           | 752/20000 [01:52<56:58,  5.63it/s]

Train: 106.72311073095173 Test: 77.02371375021043


  4%|▍           | 802/20000 [02:01<59:35,  5.37it/s]

Train: 107.74015088194116 Test: 89.48017189998922


  4%|▍           | 830/20000 [02:07<55:33,  5.75it/s]

In [None]:
len(eval_rollout)

In [None]:
rews = [s.done for s in eval_rollout.steps]

In [None]:
rews