In [3]:
from collections import OrderedDict
from tqdm.auto import tqdm
from typing import Union

from IPython import display
from matplotlib import pyplot as plt
import numpy as np

from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchrl.collectors import SyncDataCollector
from torchrl.data import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

import sys
sys.path.insert(0, r'/home/ztzhu/codes/rl-toolbox')
from rl_toolbox.utils.backend import get_device
from rl_toolbox.utils.model_utils import save_model
from rl_toolbox.utils.network_utils import mlp
from rl_toolbox.visualization.monitor import plot_loss

# Device shared by the collector, the replay-buffer storage and the `.to(device)`
# calls further down in this notebook.
device = get_device()

In [4]:
def make_env(
    env_name="LunarLander-v2",
    seed=0,
    init_stats_param: Union[int, OrderedDict] = 1000,
    check_env=False,
    **env_cfg
):
    """Create a seeded Gym environment wrapped with observation transforms.

    Args:
        env_name: Environment id handed to ``GymEnv``.
        seed: Value passed to ``set_seed`` on the wrapped environment.
        init_stats_param: Forwarded unchanged to ``init_env``; either an
            iteration count (``int``) or an ``OrderedDict`` state dict for
            the ``ObservationNorm`` transform.
        check_env: If truthy, ``check_env_specs`` is run on the result.
        **env_cfg: Additional keyword arguments forwarded to ``GymEnv``.

    Returns:
        The ``TransformedEnv`` built around the base environment.
    """
    base_env = GymEnv(env_name, **env_cfg)

    # ObservationNorm is applied before DoubleToFloat; both act on "observation".
    obs_norm = ObservationNorm(in_keys=["observation"], standard_normal=True)
    to_float = DoubleToFloat(in_keys=["observation"])
    wrapped = TransformedEnv(base_env, Compose(obs_norm, to_float))

    wrapped.set_seed(seed)
    init_env(wrapped, init_stats_param)

    if check_env:
        check_env_specs(wrapped)
    return wrapped


def init_env(env, init_stats_param=1000):
    """Initialise the first ``ObservationNorm`` transform attached to ``env``.

    Args:
        env: Object whose ``transform`` attribute is iterable; the first
            element that is an ``ObservationNorm`` is the one configured.
        init_stats_param: When an ``int``, it is used as ``num_iter`` for
            ``init_stats`` (skipped if the transform reports ``initialized``).
            Otherwise it must be an ``OrderedDict`` and is loaded through
            ``load_state_dict``.

    Returns:
        None. Nothing happens when no ``ObservationNorm`` is present.
    """
    obs_norm = next(
        (tr for tr in env.transform if isinstance(tr, ObservationNorm)),
        None,
    )
    if obs_norm is None:
        return

    # Non-int argument: treat it as a previously saved state dict.
    if not isinstance(init_stats_param, int):
        assert isinstance(init_stats_param, OrderedDict)
        obs_norm.load_state_dict(init_stats_param)
        return

    # Int argument: gather statistics only if that has not been done yet.
    if obs_norm.initialized:
        return
    obs_norm.init_stats(num_iter=init_stats_param)


def make_actor(hidden_sizes, env):
    """Build the stochastic policy for ``env``.

    Args:
        hidden_sizes: List of hidden-layer widths for the MLP.
        env: Environment providing ``observation_spec`` and ``action_spec``.

    Returns:
        A ``ProbabilisticActor`` that reads ``"observation"``, writes
        ``"loc"``/``"scale"`` and ``"action"``, uses a ``TanhNormal``
        distribution bounded by the action spec, and returns log-probs.
    """
    obs_dim = env.observation_spec["observation"].shape[0]
    act_dim = env.action_spec.shape[0]
    # Output layer is twice the action size: one half for loc, one for scale.
    layer_sizes = [obs_dim] + hidden_sizes + [act_dim * 2]

    backbone = mlp(layer_sizes, nn.Tanh, nn.Identity)
    backbone.append(NormalParamExtractor())
    param_module = TensorDictModule(backbone, ["observation"], ["loc", "scale"])

    action_bounds = {
        "min": env.action_spec.space.minimum,
        "max": env.action_spec.space.maximum,
    }
    return ProbabilisticActor(
        param_module,
        ["loc", "scale"],
        ["action"],
        distribution_class=TanhNormal,
        distribution_kwargs=action_bounds,
        return_log_prob=True,
    )


def make_critic(hidden_sizes, env):
    """Build the state-value network for ``env``.

    Args:
        hidden_sizes: List of hidden-layer widths for the MLP.
        env: Environment whose ``"observation"`` spec fixes the input width.

    Returns:
        A ``ValueOperator`` reading ``"observation"`` and wrapping an MLP
        with a single output unit.
    """
    obs_dim = env.observation_spec["observation"].shape[0]
    layer_sizes = [obs_dim] + hidden_sizes + [1]
    value_net = mlp(layer_sizes, nn.Tanh, nn.Identity)
    return ValueOperator(module=value_net, in_keys=["observation"])


def make_collector(env, actor, epochs, steps_per_epoch, max_steps_per_traj):
    """Create the synchronous data collector that drives training.

    Args:
        env: Environment to roll out.
        actor: Policy used to pick actions.
        epochs: Number of batches to collect in total.
        steps_per_epoch: Frames per collected batch.
        max_steps_per_traj: Passed as ``max_frames_per_traj``.

    Returns:
        A ``SyncDataCollector`` placed on the module-level ``device`` that
        yields ``steps_per_epoch`` frames per iteration, for
        ``steps_per_epoch * epochs`` frames overall.
    """
    collector_opts = dict(
        frames_per_batch=steps_per_epoch,
        total_frames=steps_per_epoch * epochs,
        max_frames_per_traj=max_steps_per_traj,
        reset_at_each_iter=True,
        split_trajs=False,
        device=device,
    )
    return SyncDataCollector(env, actor, **collector_opts)


def make_buf(steps_per_epoch):
    """Create a replay buffer with capacity for one collected batch.

    Args:
        steps_per_epoch: Capacity given to the ``LazyTensorStorage``.

    Returns:
        A ``ReplayBuffer`` whose storage lives on the module-level ``device``
        and which uses ``SamplerWithoutReplacement``.
    """
    storage = LazyTensorStorage(steps_per_epoch, device=device)
    sampler = SamplerWithoutReplacement()
    return ReplayBuffer(storage=storage, sampler=sampler)


def make_loss(actor, critic, gamma=0.99, lam=0.95):
    """Build the advantage estimator and the clipped-PPO loss module.

    Args:
        actor: Policy module passed to ``ClipPPOLoss``.
        critic: Value module used by both ``GAE`` and ``ClipPPOLoss``.
        gamma: Passed to ``GAE`` as ``gamma``.
        lam: Passed to ``GAE`` as ``lmbda``.

    Returns:
        A ``(gae, loss_module)`` tuple.
    """
    advantage_module = GAE(
        gamma=gamma,
        lmbda=lam,
        value_network=critic,
        average_gae=True,
    )
    ppo_loss = ClipPPOLoss(actor, critic)
    return advantage_module, ppo_loss


In [5]:
# --- Hyper-parameters (previously repeated as inline literals) ---------------
EPOCHS = 300               # collector iterations; also T_max of the LR schedule
STEPS_PER_EPOCH = 4000     # frames per collected batch / replay-buffer capacity
MAX_STEPS_PER_TRAJ = 1000
UPDATES_PER_EPOCH = 80     # gradient steps taken on each collected batch
SAVE_EVERY = 50            # checkpoint period, in epochs
MODEL_DIR = r'/home/ztzhu/codes/rl-toolbox/rl_toolbox/models/testppo'

# --- Environment, networks, data pipeline ------------------------------------
env = make_env(continuous=True, render_mode=None)
actor = make_actor([128, 128], env)
critic = make_critic([128, 128], env)
collector = make_collector(env, actor, EPOCHS, STEPS_PER_EPOCH, MAX_STEPS_PER_TRAJ)
buf = make_buf(STEPS_PER_EPOCH)

env.to(device)
actor.to(device)
critic.to(device)

# --- Loss, optimiser, LR schedule --------------------------------------------
gae, loss_module = make_loss(actor, critic)
opt = Adam(
    [
        {"params": actor.parameters(), "lr": 0.0004},
        {"params": critic.parameters(), "lr": 0.001},
    ]
)
scheduler = CosineAnnealingLR(opt, EPOCHS, 0.0)

# --- Training ----------------------------------------------------------------
fig, ax = plt.subplots()
losses = []
# `total` lets tqdm render a real progress bar; enumerate() has no length.
for epoch, data in tqdm(enumerate(collector, 1), total=EPOCHS):
    epoch_losses = []
    for _ in range(UPDATES_PER_EPOCH):
        # Advantages are recomputed before every update because the critic
        # changes after each optimiser step.
        with torch.no_grad():
            gae(data)
        # NOTE(review): the buffer is filled here but never sampled in this
        # cell; the loss below is computed on the full batch.
        buf.extend(data)

        opt.zero_grad()

        loss_vals = loss_module(data)
        loss = (
            loss_vals["loss_objective"]
            + loss_vals["loss_critic"]
            + loss_vals["loss_entropy"]
        )
        loss.backward()
        epoch_losses.append(loss.item())

        opt.step()

    # Advance the cosine schedule once per epoch (T_max == EPOCHS). Without
    # this call the learning rates stay at their initial values.
    scheduler.step()

    losses.append(np.mean(epoch_losses))
    plot_loss(ax, epoch, losses)
    plt.show()

    if epoch % SAVE_EVERY == 0:
        save_model(
            MODEL_DIR,
            epoch,
            {},
            actor=actor.state_dict(),
            critic=critic.state_dict(),
            env=env.state_dict(),
        )



0it [00:00, ?it/s]

KeyboardInterrupt: 

In [6]:
# Pull a single batch from the collector for inspection: the loop exits after
# the first item, which stays bound to both `data` and `a`.
for data in collector:
    a=data
    break

In [7]:
# Display the batch captured in the previous cell.
a

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4000, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([4000]), device=cuda:0, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([4000]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([4000, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        loc: Tensor(shape=torch.Size([4000, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([4000, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: Tensor(shape=torch.Size([4000, 8]), device=cuda:0, dtype=torch.float32, is_shared=True),
                reward: Tensor(shape=torch.Size([4000, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                step

In [8]:
# Run the GAE module on the batch and display the result; the output below
# includes an `advantage` entry.
gae(data)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4000, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
        advantage: Tensor(shape=torch.Size([4000, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([4000]), device=cuda:0, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([4000]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([4000, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        loc: Tensor(shape=torch.Size([4000, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([4000, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: Tensor(shape=torch.Size([4000, 8]), device=cuda:0, dtype=torch.float32, is_shared=True),
                reward: T

In [10]:
# List the top-level keys of the batch.
data.keys()

_TensorDictKeysView(['action', 'step_count', 'observation', 'truncated', 'done', 'next', 'loc', 'scale', 'sample_log_prob', 'collector', 'state_value', 'advantage', 'value_target'],
    include_nested=False,
    leaves_only=False)