In [3]:
from dataclasses import dataclass
import abc
from typing import Any, Union

import torch

@dataclass
class Rollout:
    x: int


In [1]:
from typing import Any, Dict, Optional  # noqa

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as td
from optimrl.optimizer import LossFunction, RLOptimizer


### Env

In [2]:
env = gym.make("LunarLander-v2")
observation_dims = env.observation_space.shape[-1]
action_dims = env.action_space.n

In [3]:
action_dims

4

### Policy

In [4]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

In [5]:
class PPOCategoricalPolicy(nn.Module):
    def __init__(
        self,
        obs_dims,
        act_dims
    ):
        super().__init__()
        self.obs_dims = obs_dims
        self.act_dims = act_dims

        self.actor = nn.Sequential(
            layer_init(nn.Linear(self.obs_dims, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, self.act_dims))
        )

        self.critic = nn.Sequential(
            layer_init(nn.Linear(self.obs_dims, 128)),
            nn.ReLU(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 1))
        )

    def forward(self, obs: torch.Tensor):
        logits = self.actor(obs)
        values = self.critic(obs)
        dist = td.Categorical(logits=logits)
        actions = dist.sample()
        log_probs = dist.log_prob(actions)
        return {"actions":actions, "values":values, "log_probs":log_probs, "dist":dist, "logits":logits}


In [6]:
initial, _ = env.reset()

In [7]:
initial.shape

(8,)

In [8]:
policy = PPOCategoricalPolicy(observation_dims, action_dims)

In [9]:
policy(torch.from_numpy(initial).unsqueeze(0))

{'actions': tensor([3]),
 'values': tensor([[-0.3831]], grad_fn=<AddmmBackward0>),
 'log_probs': tensor([-1.4580], grad_fn=<SqueezeBackward1>),
 'dist': Categorical(probs: torch.Size([1, 4]), logits: torch.Size([1, 4])),
 'logits': tensor([[0.2158, 0.0829, 0.1531, 0.0575]], grad_fn=<AddmmBackward0>)}

In [10]:
env.step(env.action_space.sample())

  if not isinstance(terminated, (bool, np.bool8)):


(array([ 0.00816231,  1.4143033 ,  0.4055499 ,  0.08813022, -0.01033811,
        -0.11043775,  0.        ,  0.        ], dtype=float32),
 0.06435684814009052,
 False,
 False,
 {})

### Rollout

In [27]:
envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("LunarLander-v2") for i in range(4)],
)

In [28]:
state, _ = envs.reset()

In [31]:
envs.single_action_space.shape

()

In [41]:
envs.num_envs

4

In [84]:
def ppo_rollout(envs, policy, num_steps: int, seed: int, device):
    num_envs = envs.num_envs
    obs = np.zeros((num_steps, num_envs) + envs.single_observation_space.shape)
    actions = np.zeros((num_steps, num_envs) + envs.single_action_space.shape)
    logprobs = np.zeros((num_steps, num_envs))
    rewards = np.zeros((num_steps, num_envs))
    dones = np.zeros((num_steps, num_envs))
    values = np.zeros((num_steps, num_envs))

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    next_obs, _ = envs.reset(seed=seed)
    next_obs = next_obs
    next_done = np.zeros(num_envs)

    for step in range(num_steps):
        global_step += num_envs
        obs[step] = next_obs
        dones[step] = next_done

        # ALGO LOGIC: action logic
        out = policy(torch.Tensor(next_obs).to(device))
        values[step] = out["values"].squeeze().detach().cpu().numpy()
        action = out["actions"].detach().cpu().numpy()
        actions[step] = action
        logprobs[step] = out["log_probs"].squeeze().detach().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, reward, terminations, truncations, infos = envs.step(action)
        next_done = np.logical_or(terminations, truncations)
        rewards[step] = reward

    return obs, actions, logprobs, rewards, dones, values

In [85]:
envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("LunarLander-v2") for i in range(4)],
)

In [86]:
device = torch.device("cpu")
obs, actions, logprobs, rewards, dones, values = ppo_rollout(envs, policy, 10000, 1, device=device)

In [None]:
obs = np.zeros((num_steps, num_envs) + envs.single_observation_space.shape)
actions = np.zeros((num_steps, num_envs) + envs.single_action_space.shape)
logprobs = np.zeros((num_steps, num_envs))
rewards = np.zeros((num_steps, num_envs))
dones = np.zeros((num_steps, num_envs))
values = np.zeros((num_steps, num_envs))

In [11]:
def ppo_rollout(
    policy,
    device
):
    SEED = None
    env = gym.make("LunarLander-v2")
    max_steps = 10000
    done = False
    observations, actions, rewards, logits, log_probs, values, terminals = (
        [],
        [],
        [],
        [],
        [],
        [],
        [],
    )
    count = 0
    observation, _ = env.reset(seed=SEED)
    with torch.no_grad():
        while not done:
            obs = torch.from_numpy(observation).unsqueeze(0).to(device)
            out = policy(obs)
            action = out["actions"].item()
            next_observation, reward, done, truncated, info = env.step(action)

            observations.append(observation)
            actions.append(action)
            rewards.append(reward)
            logits.append(out["logits"].squeeze().detach().cpu().numpy())
            log_probs.append(out["log_probs"].squeeze().detach().cpu().numpy())
            values.append(out["values"].squeeze().detach().cpu().numpy())
            terminals.append(done)

            observation = next_observation
            if count == max_steps:
                done = True
            count += 1
    env.close()

    return {
        "observations": np.array(observations),
        "actions": np.array(actions),
        "rewards": np.array(rewards),
        "log_probs": np.array(log_probs),
        "logits":np.array(logits),
        "values": np.array(values),
        "terminals": np.array(terminals),
    }



In [12]:
rollout = ppo_rollout(policy, torch.device("cpu"))

### PPO Loss

In [13]:
class PPOLossFunction(LossFunction):
    def __init__(
        self,
        vf_coef: float = 1.0,
        entropy_weight: float = 0.001,
        gamma: float = 0.99,
        clip_ratio: float = 0.2,
        norm_returns: bool = True,
        norm_advantages: bool = True
    ):
        self.vf_coef = vf_coef
        self.entropy_weight = entropy_weight
        self.gamma = gamma
        self.clip_ratio = clip_ratio
        self.norm_returns = norm_returns
        self.norm_advantages = norm_advantages

    def discount_rewards(self, rews: torch.Tensor, normalize=False) -> torch.Tensor:
        n = rews.shape[-1]
        rtgs = torch.zeros_like(rews)
        for i in reversed(range(n)):
            rtgs[:, i] = rews[:, i] + self.gamma * (rtgs[:, i + 1] if i + 1 < n else 0)
        if normalize:
            rtgs = (rtgs - rtgs.mean(dim=-1)) / rtgs.std(dim=-1)
        return rtgs

    def actor_loss(
        self,
        log_probs: torch.Tensor,
        old_logprobs: torch.Tensor,
        advantages: torch.Tensor,
    ) -> torch.Tensor:
        ratio = torch.exp(log_probs - old_logprobs.detach())
        assert tuple(ratio.size()) == tuple(advantages.size())
        surr1 = ratio * advantages
        surr2 = (
            torch.clamp(ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio)
            * advantages
        )
        loss_pi = -torch.min(surr1, surr2)
        return loss_pi

    def value_loss(self, values: torch.Tensor, returns: torch.Tensor) -> torch.Tensor:
        assert tuple(values.squeeze().size()) == tuple(returns.squeeze().size())
        return F.mse_loss(returns.squeeze(), values.squeeze())

    def __call__(
        self,
        log_probs: torch.Tensor,
        dist,
        old_log_probs: torch.Tensor,
        rewards: torch.Tensor,
        values: torch.Tensor,
        terminals: Optional[torch.Tensor] = None,
    ):
        returns = self.discount_rewards(rewards, normalize=self.norm_returns)
        advantages = returns - values
        if self.norm_advantages:
            advantages = (advantages - advantages.mean(dim=-1)) / advantages.std(dim=-1)
        actor_loss = self.actor_loss(log_probs, old_log_probs, advantages).mean()
        value_loss = self.value_loss(values, returns).mean()
        entropy_loss = dist.entropy().sum(1)
        loss = (
            actor_loss + self.vf_coef * value_loss - self.entropy_weight * entropy_loss
        )
        return loss.mean()


In [14]:
env = gym.make("LunarLander-v2")

In [15]:
loss_fn = PPOLossFunction(entropy_weight=0.01, norm_advantages=False)

In [16]:
def prepare_inputs(
    rollouts,
    device
):
    observations = torch.stack([torch.from_numpy(r["observations"]) for r in rollouts]).to(device)
    actions = torch.stack([torch.from_numpy(r["actions"]) for r in rollouts]).to(device)
    rewards = torch.stack([torch.from_numpy(r["rewards"]) for r in rollouts]).to(device)
    terminals = torch.stack([torch.from_numpy(r["terminals"]) for r in rollouts]).to(device)
    log_probs = torch.stack([torch.from_numpy(r["log_probs"]) for r in rollouts]).to(device)
    logits = torch.stack([torch.from_numpy(r["logits"]) for r in rollouts]).to(device)
    values = torch.stack([torch.from_numpy(r["values"]) for r in rollouts]).to(device)
    dist = td.Categorical(logits=logits)
    return {
        "observations":observations,
        "actions":actions,
        "rewards":rewards,
        "terminals":terminals,
        "log_probs":log_probs,
        "logits":logits,
        "dist":dist,
        "values":values,
    }


In [17]:
device = torch.device("cpu")

In [18]:
prepared_inputs = prepare_inputs([rollout], device=device)

In [19]:
prepared_inputs["log_probs"].shape

torch.Size([1, 87])

In [20]:
out = policy(prepared_inputs["observations"].squeeze())

In [21]:
out["log_probs"].unsqueeze(0).shape

torch.Size([1, 87])

In [22]:
loss = loss_fn(
    log_probs=out["log_probs"].unsqueeze(0),
    dist=prepared_inputs["dist"],
    old_log_probs=prepared_inputs["log_probs"],
    rewards=prepared_inputs["rewards"],
    values=prepared_inputs["values"],
    terminals=prepared_inputs["terminals"]
)


### Optimize step

In [23]:
import torch.optim as optim

class PPOOptimizer:
    def __init__(
        self,
        policy,
        loss_fn,
        pi_lr: float = 0.0005,
        n_updates: int = 4
    ):
        self.policy = policy
        self.loss_fn = loss_fn
        self.pi_lr = pi_lr
        self.n_updates = n_updates
        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.pi_lr)

    def update(
        self,
        rollouts,
        device,
    ):
        prepared_inputs = prepare_inputs(rollouts, device=device)
        for i in range(self.n_updates):
            self.optimizer.zero_grad()
            out = self.policy(prepared_inputs["observations"].squeeze())
            loss = self.loss_fn(
                log_probs=out["log_probs"].unsqueeze(0),
                dist=prepared_inputs["dist"],
                old_log_probs=prepared_inputs["log_probs"],
                rewards=prepared_inputs["rewards"],
                values=prepared_inputs["values"],
                terminals=prepared_inputs["terminals"]
            )
            loss.backward()
            self.optimizer.step()
        return loss.item()


In [24]:
loss_fn = PPOLossFunction(entropy_weight=0.01, norm_advantages=False)
optimizer = PPOOptimizer(
    policy=policy,
    loss_fn=loss_fn,
    pi_lr=0.002,
    n_updates=4
)

In [25]:
from tqdm import tqdm

bar = tqdm(np.arange(50000))
device = torch.device("cpu")

rewards = []
for i in bar:
    with torch.no_grad():
        rollout = ppo_rollout(policy, device)
    sum_reward = np.sum(rollout["rewards"])
    loss = optimizer.update([rollout], device)
    rewards.append(sum_reward)
    # bar.set_description(f"Loss: {loss}, Sum reward: {np.mean(rewards[-20:])}")
    if i % 500 == 0:
        print(f"Loss: {loss}, Sum reward: {np.mean(rewards[-25:])}")

  if not isinstance(terminated, (bool, np.bool8)):
  0%|                               | 9/50000 [00:00<20:05, 41.46it/s]

Loss: -0.9123230576515198, Sum reward: -121.96433337664078


  1%|▎                            | 505/50000 [00:12<21:42, 38.00it/s]

Loss: 0.021598339080810547, Sum reward: -204.8033545896234


  2%|▌                           | 1008/50000 [01:16<20:51, 39.13it/s]

Loss: -0.4012258052825928, Sum reward: -183.74630516808452


  3%|▊                           | 1506/50000 [01:28<20:03, 40.30it/s]

Loss: -0.688403844833374, Sum reward: -199.08353628364716


  4%|█                           | 2008/50000 [01:41<20:44, 38.57it/s]

Loss: -0.04335963726043701, Sum reward: -202.9304515309879


  5%|█▍                          | 2508/50000 [01:54<21:03, 37.58it/s]

Loss: -0.4709970951080322, Sum reward: -181.32710764372013


  6%|█▋                          | 3004/50000 [02:49<19:39, 39.83it/s]

Loss: -0.03006112575531006, Sum reward: -178.67277722775285


  7%|█▉                          | 3507/50000 [03:02<19:39, 39.41it/s]

Loss: -0.16471195220947266, Sum reward: -209.54099177786821


  8%|██▏                         | 4007/50000 [03:14<18:18, 41.87it/s]

Loss: -0.7212000489234924, Sum reward: -179.5099133810155


  9%|██▌                         | 4507/50000 [04:12<18:12, 41.66it/s]

Loss: -0.7187682390213013, Sum reward: -145.33330618750836


 10%|██▊                         | 4997/50000 [04:25<19:36, 38.27it/s]

Loss: -0.2751656770706177, Sum reward: -217.40698286073314


 11%|███                         | 5509/50000 [05:21<17:44, 41.81it/s]

Loss: -0.29213380813598633, Sum reward: -178.85057108977242


 12%|███▎                        | 6006/50000 [05:34<19:10, 38.24it/s]

Loss: -0.9155663847923279, Sum reward: -194.15881861547305


 13%|███▋                        | 6509/50000 [05:47<17:36, 41.16it/s]

Loss: -0.5538540482521057, Sum reward: -192.42256082158644


 14%|███▉                        | 7008/50000 [06:41<19:18, 37.09it/s]

Loss: -0.04770141839981079, Sum reward: -163.59124793778173


 15%|████▏                       | 7505/50000 [06:55<19:52, 35.63it/s]

Loss: -0.007029712200164795, Sum reward: -188.26077372963923


 16%|████▍                       | 8007/50000 [07:19<18:20, 38.17it/s]

Loss: -0.9566903114318848, Sum reward: -159.8437578707421


 17%|████▊                       | 8509/50000 [07:32<18:03, 38.31it/s]

Loss: -0.47236037254333496, Sum reward: -152.440325300596


 18%|█████                       | 9007/50000 [09:53<23:35, 28.95it/s]

Loss: -0.38170188665390015, Sum reward: -189.53242978136825


 19%|█████▎                      | 9506/50000 [10:06<16:24, 41.14it/s]

Loss: -0.18168425559997559, Sum reward: -172.39094077539124


 20%|█████▍                     | 10008/50000 [10:19<16:30, 40.37it/s]

Loss: -0.5760697722434998, Sum reward: -173.48255802675712


 21%|█████▋                     | 10507/50000 [12:24<18:39, 35.28it/s]

Loss: -0.25595808029174805, Sum reward: -178.39891476233103


 22%|█████▉                     | 11006/50000 [13:22<16:11, 40.15it/s]

Loss: -0.6077889800071716, Sum reward: -153.15266437739535


 23%|██████▏                    | 11506/50000 [13:35<15:41, 40.87it/s]

Loss: 0.021827340126037598, Sum reward: -176.83745776366112


 24%|██████▍                    | 12006/50000 [13:48<16:39, 38.03it/s]

Loss: -0.8123911023139954, Sum reward: -197.06400232156386


 25%|██████▊                    | 12506/50000 [14:01<14:11, 44.04it/s]

Loss: -0.4749137759208679, Sum reward: -206.74661070133467


 26%|███████                    | 13007/50000 [14:15<16:32, 37.29it/s]

Loss: -0.4754785895347595, Sum reward: -175.43541736061474


 27%|███████▎                   | 13509/50000 [14:28<15:00, 40.50it/s]

Loss: -0.09537428617477417, Sum reward: -193.55826600545598


 28%|███████▌                   | 14009/50000 [14:41<15:11, 39.48it/s]

Loss: -0.3441488742828369, Sum reward: -205.76604493534845


 29%|███████▊                   | 14505/50000 [15:38<15:38, 37.83it/s]

Loss: 0.268878698348999, Sum reward: -204.77879589118882


 30%|████████                   | 15005/50000 [15:51<15:15, 38.24it/s]

Loss: -0.11684024333953857, Sum reward: -159.86590427408845


 31%|████████▎                  | 15506/50000 [16:47<14:13, 40.43it/s]

Loss: -0.009740829467773438, Sum reward: -169.72070354326752


 32%|████████▋                  | 16009/50000 [18:25<15:02, 37.66it/s]

Loss: -0.6441897749900818, Sum reward: -191.45646535155785


 33%|████████▉                  | 16506/50000 [18:38<15:54, 35.08it/s]

Loss: -1.0159235000610352, Sum reward: -187.47968234931076


 34%|█████████▏                 | 17008/50000 [18:56<13:32, 40.63it/s]

Loss: -0.5579879283905029, Sum reward: -177.09693163884765


 35%|█████████▍                 | 17509/50000 [20:41<13:03, 41.46it/s]

Loss: -0.17307913303375244, Sum reward: -166.4807481155379


 36%|█████████▋                 | 18008/50000 [20:55<14:48, 36.01it/s]

Loss: -0.6376703381538391, Sum reward: -173.85312752724357


 37%|█████████▉                 | 18507/50000 [21:08<14:20, 36.60it/s]

Loss: -0.8682670593261719, Sum reward: -167.03836165686317


 38%|██████████▎                | 19006/50000 [21:21<13:38, 37.86it/s]

Loss: -0.015137135982513428, Sum reward: -173.9607325789318


 39%|██████████▌                | 19506/50000 [23:15<14:40, 34.62it/s]

Loss: -0.06990623474121094, Sum reward: -171.63975707754855


 40%|██████████▊                | 20006/50000 [23:54<12:52, 38.81it/s]

Loss: -0.597529411315918, Sum reward: -170.09377775182796


 41%|███████████                | 20508/50000 [24:08<13:41, 35.89it/s]

Loss: -0.8028754591941833, Sum reward: -210.65928536224513


 42%|███████████▎               | 21007/50000 [24:21<13:00, 37.13it/s]

Loss: -0.08732020854949951, Sum reward: -229.07104938236458


 43%|███████████▌               | 21507/50000 [24:34<11:15, 42.20it/s]

Loss: -0.9078831672668457, Sum reward: -182.68645827793435


 44%|███████████▉               | 22008/50000 [24:47<12:04, 38.64it/s]

Loss: -0.6860682368278503, Sum reward: -178.71364660571047


 45%|████████████▏              | 22506/50000 [25:00<13:24, 34.17it/s]

Loss: -0.2521415948867798, Sum reward: -182.30642785244422


 46%|████████████▍              | 23006/50000 [25:13<11:27, 39.26it/s]

Loss: 0.005267143249511719, Sum reward: -212.3939948509723


 47%|████████████▋              | 23505/50000 [25:26<11:39, 37.86it/s]

Loss: 0.06499230861663818, Sum reward: -169.00638867388898


 48%|████████████▉              | 24009/50000 [26:23<10:31, 41.14it/s]

Loss: -0.668048620223999, Sum reward: -200.58149547614457


 49%|█████████████▏             | 24506/50000 [26:36<10:23, 40.90it/s]

Loss: -0.5781775116920471, Sum reward: -205.68720948755237


 50%|█████████████▌             | 25008/50000 [26:48<10:20, 40.29it/s]

Loss: 0.07474678754806519, Sum reward: -152.88325408055786


 51%|█████████████▊             | 25508/50000 [27:41<11:16, 36.19it/s]

Loss: -0.5879594683647156, Sum reward: -191.68500274722703


 52%|██████████████             | 26007/50000 [29:35<10:13, 39.08it/s]

Loss: -0.5767504572868347, Sum reward: -188.10948604292327


 53%|██████████████▎            | 26508/50000 [30:30<09:10, 42.69it/s]

Loss: -0.044829487800598145, Sum reward: -181.1037499573044


 54%|██████████████▌            | 27007/50000 [32:12<08:57, 42.79it/s]

Loss: -0.01955169439315796, Sum reward: -195.84400316258584


 55%|██████████████▊            | 27507/50000 [33:33<09:37, 38.95it/s]

Loss: -0.20869207382202148, Sum reward: -192.7821516579379


 56%|███████████████            | 28006/50000 [33:46<08:51, 41.35it/s]

Loss: -0.5874777436256409, Sum reward: -184.97362269385525


 57%|███████████████▍           | 28506/50000 [33:59<08:48, 40.67it/s]

Loss: 0.22786706686019897, Sum reward: -183.3198141435553


 58%|███████████████▋           | 29004/50000 [35:04<46:01,  7.60it/s]

Loss: -0.04240858554840088, Sum reward: -201.64007059432902


 59%|███████████████▉           | 29507/50000 [35:58<09:22, 36.42it/s]

Loss: -0.4824875593185425, Sum reward: -198.01681806104725


 60%|████████████████▏          | 30006/50000 [36:12<09:09, 36.40it/s]

Loss: -0.8718962073326111, Sum reward: -219.0751360548372


 61%|████████████████▍          | 30506/50000 [36:25<08:07, 39.98it/s]

Loss: -0.6969972252845764, Sum reward: -159.35739241447186


 62%|████████████████▋          | 31007/50000 [37:36<07:41, 41.14it/s]

Loss: 0.16172856092453003, Sum reward: -166.91438535514126


 63%|█████████████████          | 31506/50000 [37:49<08:15, 37.31it/s]

Loss: -0.08011507987976074, Sum reward: -184.67959889922415


 64%|█████████████████▎         | 32008/50000 [38:43<06:52, 43.66it/s]

Loss: -0.6458133459091187, Sum reward: -141.825443018448


 65%|█████████████████▌         | 32508/50000 [38:56<07:50, 37.17it/s]

Loss: -0.21897780895233154, Sum reward: -207.2308283669893


 66%|█████████████████▊         | 33007/50000 [40:01<07:02, 40.22it/s]

Loss: 0.06228804588317871, Sum reward: -164.1968520341123


 67%|██████████████████         | 33506/50000 [40:14<07:12, 38.12it/s]

Loss: -0.22419899702072144, Sum reward: -150.89828819229638


 68%|██████████████████▎        | 34007/50000 [41:54<07:43, 34.53it/s]

Loss: -0.546815812587738, Sum reward: -208.886424612692


 69%|██████████████████▋        | 34507/50000 [42:50<06:23, 40.37it/s]

Loss: -0.4485434889793396, Sum reward: -210.72312600143061


 70%|██████████████████▉        | 35006/50000 [43:03<06:26, 38.79it/s]

Loss: -0.4707925319671631, Sum reward: -164.57747046801381


 71%|███████████████████▏       | 35505/50000 [43:58<06:33, 36.79it/s]

Loss: -0.39556485414505005, Sum reward: -218.3399202483202


 72%|███████████████████▍       | 36005/50000 [44:11<05:49, 40.06it/s]

Loss: -0.42740529775619507, Sum reward: -170.68179750361304


 73%|███████████████████▋       | 36506/50000 [44:24<06:06, 36.81it/s]

Loss: -0.4497363567352295, Sum reward: -169.9811581160207


 74%|███████████████████▉       | 37008/50000 [44:37<05:34, 38.79it/s]

Loss: -0.9926982522010803, Sum reward: -166.5406638296108


 75%|████████████████████▎      | 37507/50000 [44:50<05:31, 37.68it/s]

Loss: 0.035171449184417725, Sum reward: -201.19791849110368


 76%|████████████████████▌      | 38007/50000 [45:45<04:57, 40.33it/s]

Loss: -0.5881744027137756, Sum reward: -196.3165103511095


 77%|████████████████████▊      | 38507/50000 [45:58<04:48, 39.77it/s]

Loss: 0.11915385723114014, Sum reward: -191.97146959234965


 78%|█████████████████████      | 39010/50000 [46:11<04:06, 44.67it/s]

Loss: -0.5734822750091553, Sum reward: -184.6337318738487


 79%|█████████████████████▎     | 39507/50000 [47:05<04:13, 41.45it/s]

Loss: -0.3181988596916199, Sum reward: -157.30695473330974


 80%|█████████████████████▌     | 40009/50000 [47:18<04:11, 39.75it/s]

Loss: -0.09143698215484619, Sum reward: -213.11822522956402


 81%|█████████████████████▊     | 40508/50000 [47:30<04:18, 36.69it/s]

Loss: -0.9784424304962158, Sum reward: -228.25272118281143


 82%|██████████████████████▏    | 41007/50000 [48:28<03:41, 40.62it/s]

Loss: -0.8008904457092285, Sum reward: -188.56613374319443


 83%|██████████████████████▍    | 41508/50000 [50:06<03:32, 40.05it/s]

Loss: -0.13125699758529663, Sum reward: -182.91543329069384


 84%|██████████████████████▋    | 42009/50000 [50:19<03:21, 39.59it/s]

Loss: -0.7567464709281921, Sum reward: -196.69618719660866


 85%|██████████████████████▉    | 42507/50000 [52:15<04:44, 26.36it/s]

Loss: -0.6271012425422668, Sum reward: -204.25379922211053


 86%|███████████████████████▏   | 43008/50000 [52:28<02:58, 39.11it/s]

Loss: -0.6759347319602966, Sum reward: -211.008525728671


 87%|███████████████████████▍   | 43506/50000 [52:41<02:56, 36.70it/s]

Loss: -0.2601165771484375, Sum reward: -151.04755494880916


 88%|███████████████████████▊   | 44007/50000 [53:35<02:22, 42.17it/s]

Loss: -0.4664231538772583, Sum reward: -210.25096459556616


 89%|████████████████████████   | 44508/50000 [55:19<02:09, 42.32it/s]

Loss: 0.17987185716629028, Sum reward: -173.06133549554414


 90%|████████████████████████▎  | 45006/50000 [55:31<02:04, 40.09it/s]

Loss: -0.7794811129570007, Sum reward: -196.5146857917186


 91%|████████████████████████▌  | 45505/50000 [55:43<01:38, 45.48it/s]

Loss: -0.6166368722915649, Sum reward: -170.39161682557847


 92%|████████████████████████▊  | 46005/50000 [55:54<01:34, 42.09it/s]

Loss: -0.3479241728782654, Sum reward: -172.44371940459519


 93%|█████████████████████████  | 46510/50000 [56:44<01:17, 44.75it/s]

Loss: -0.471635639667511, Sum reward: -189.4804783209033


 94%|█████████████████████████▍ | 47007/50000 [58:28<01:04, 46.52it/s]

Loss: -0.8474149703979492, Sum reward: -195.81645173561176


 95%|███████████████████████▊ | 47509/50000 [1:48:20<01:00, 41.39it/s]

Loss: -0.6737468242645264, Sum reward: -204.91964127972315


 96%|████████████████████████ | 48008/50000 [1:48:32<00:45, 44.09it/s]

Loss: -0.3036153316497803, Sum reward: -210.04335067274292


 97%|████████████████████████▎| 48508/50000 [1:49:55<26:35,  1.07s/it]

Loss: -0.1140822172164917, Sum reward: -215.0054140504317


 98%|████████████████████████▌| 49008/50000 [1:50:44<12:59,  1.27it/s]

Loss: -0.5504346489906311, Sum reward: -209.754436534842


 99%|████████████████████████▊| 49509/50000 [1:51:33<00:11, 41.42it/s]

Loss: -0.07728588581085205, Sum reward: -142.20115018727003


100%|█████████████████████████| 50000/50000 [1:51:45<00:00,  7.46it/s]
