In [1]:
import gym
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from typing import Tuple, Type

import torch.distributions as td
import torch
import torch.nn.functional as F
import abc

from ezrl.optimizer import RLOptimizer
from ezrl.policy import GymPolicy

import abc

### Dreamer Policy

In [2]:
from ezrl.algorithms.dreamer.components import RSSM, DiscountPredictor, RewardPredictor, RecurrentModel, RepresentationModel, TransitionPredictor
from ezrl.algorithms.dreamer.utils import get_convs, get_deconvs, NormalDistribution
from ezrl.algorithms.dreamer.world_model import WorldModel
from ezrl.algorithms.dreamer.optimizer import DreamerOptimizer
from ezrl.algorithms.dreamer.policy import DreamerPolicy

# RGB LunarLander

# Setup 

In [3]:
obs_encoding_dim = 64
hidden_dim = 64
latent_dim = 64
action_dim = 4

### World Model

In [4]:
net = nn.Sequential(
    nn.Conv2d(3, 32, 3, 2, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 16, 3, 2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 1, 3, 2, padding=1)
)

In [5]:
t = torch.zeros(1,3,64,64)

In [6]:
net(t).size()

torch.Size([1, 1, 8, 8])

In [7]:
class LunarLanderRGBObsEncoder(nn.Module):
    def __init__(
        self,
        obs_encoding_dim: int,
        obs_dim: int = 8
    ):
        super().__init__()
        self.obs_encoding_dim = obs_encoding_dim
        self.obs_dim = obs_dim

        self.net = get_convs(obs_dim, 1, 3, self.obs_encoding_dim)
        # self.net = nn.Sequential(
        #             nn.Conv2d(3, 32, 3, 2, padding=1),
        #             nn.ReLU(),
        #             nn.Conv2d(32, 16, 3, 2, padding=1),
        #             nn.ReLU(),
        #             nn.Conv2d(16, 1, 3, 2, padding=1)
        #         )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs).view(obs.size(0), -1)


class RGBObsDecoder(nn.Module):
    def __init__(self, hidden_dim: int, latent_dim: int, obs_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.obs_dim = obs_dim
        self.input_channels = hidden_dim + latent_dim
        self.net = get_deconvs(1, obs_dim, self.input_channels, 6)
        self.distribution = NormalDistribution


    def forward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        inp = torch.cat([hidden_state, latent_state], dim=-1).view(hidden_state.size(0), self.input_channels, 1, 1)
        logits = self.net(inp)
        return self.distribution(logits)

In [8]:
recurrent_model = RecurrentModel(hidden_dim, latent_dim, action_dim)
representation_model = RepresentationModel(hidden_dim, latent_dim, obs_encoding_dim)
transition_predictor = TransitionPredictor(hidden_dim, latent_dim)

rssm = RSSM(recurrent_model, representation_model, transition_predictor)

In [9]:
obs_encoder = LunarLanderRGBObsEncoder(obs_encoding_dim, 64)
obs_decoder = RGBObsDecoder(hidden_dim, latent_dim, 64)
reward_predictor = RewardPredictor(hidden_dim, latent_dim)
discount_predictor = DiscountPredictor(hidden_dim, latent_dim)

In [10]:
world_model = WorldModel(rssm, obs_encoder, obs_decoder, reward_predictor, discount_predictor)
dreamer = DreamerPolicy(world_model)
dreamer_optimizer = DreamerOptimizer(dreamer)

In [11]:
observations = torch.zeros(15, 32, 3, 64, 64)
actions = torch.zeros(15,32,4)
rewards = torch.zeros(15, 32, 1)
nonterminals = torch.zeros(15, 32, 1)

In [12]:
dreamer_optimizer.representation_loss(observations, actions, rewards)

(tensor(8047.6289, grad_fn=<AddBackward0>),
 tensor(8046.6958, grad_fn=<NegBackward>),
 tensor(0.6472, grad_fn=<NegBackward>),
 tensor(2.8628, grad_fn=<AddBackward0>),
 Independent(Normal(loc: torch.Size([15, 32, 3, 64, 64]), scale: torch.Size([15, 32, 3, 64, 64])), 3))

In [13]:
hidden_states, latent_states, actions, rewards, values, discounts = dreamer.unroll(torch.zeros(1,3,64,64))

In [14]:
hidden_states, posteriors, priors, decoded, rewards, discounts = dreamer.unroll_with_posteriors(
    torch.zeros(15,32,3,64,64),
    torch.randn((15,32,4))
)

print(hidden_states.size())
print(posteriors.logits.size())
print(priors.logits.size())
print(decoded.logits.size())
print(rewards.logits.size())
print(discounts.logits.size())

torch.Size([15, 32, 64])
torch.Size([15, 32, 128])
torch.Size([15, 32, 128])
torch.Size([15, 32, 6, 64, 64])
torch.Size([15, 32, 2])
torch.Size([15, 32, 2])


In [15]:
import time

In [16]:
time.sleep(0.1)

In [17]:
import PIL
import torch
import torchvision


class RGBObservationWrapper(gym.ObservationWrapper):
    def __init__(self, env, grayscale=False):
        super().__init__(env)
        self.grayscale = grayscale

    def observation(self, obs):
        if len(obs.shape) < 3:
            # does not have rgb obs
            obs = self.env.render(mode="rgb_array")
        # modify obs
        to_pil = torchvision.transforms.ToPILImage()
        resize = torchvision.transforms.Resize(
            (64, 64), interpolation=PIL.Image.BICUBIC
        )
        grayscale = torchvision.transforms.Grayscale()
        to_tensor = torchvision.transforms.ToTensor()
        obs = to_pil(torch.from_numpy(obs.copy().transpose(2, 0, 1)))
        if self.grayscale:
            obs = grayscale(obs)
        obs = resize(obs)  # 3, 64, 64
        obs = to_tensor(obs)
        return obs.float().unsqueeze(0).detach().numpy()


### Lunar Lander Dreamer Policy

In [18]:
class LunarLanderDreamerPolicy(DreamerPolicy):
    def __init__(self, world_model: WorldModel):
        super().__init__(world_model)

    def dist(self, action_probs: torch.Tensor) -> td.Distribution:
        dist = td.Categorical(probs=action_probs)
        return dist

    def act(self, latent_state: torch.Tensor):
        out = self.forward(latent_state)
        return np.squeeze(out["action"].detach().cpu().numpy()).item(), out


In [19]:
from typing import Dict
import pandas as pd

def dreamer_rollout(
    policy: DreamerPolicy, env_name: str = None, env=None, env_creation_fn=None
) -> Dict[str, np.array]:
    if env_name is None and env is None:
        raise ValueError("env_name or env must be provided!")
    if env is None:
        if env_creation_fn is None:
            env_creation_fn = gym.make
        env = env_creation_fn(env_name)
    done = False
    observations, actions, rewards, log_probs, values = ([], [], [], [], [])
    observation = env.reset()

    with torch.no_grad():
        while not done:
            obs = torch.from_numpy(observation).to(policy.device)

            hidden_state = policy.initialize_hidden_state(1)
            encoded = policy.world_model.encode_obs(obs)
            latent_state = policy.world_model.posterior(hidden_state, encoded)

            action, out = policy.act(latent_state.sample())
            value = out["value"]
            next_observation, reward, done, info = env.step(action)

            observations.append(observation)
            actions.append(action)
            rewards.append(reward)
            values.append(value.detach().cpu().numpy())

            observation = next_observation
    env.close()
    if isinstance(action, int):
        return {
            "observations": np.array(observations),
            "actions": np.array(pd.get_dummies(np.array(actions))).astype(float),
            "rewards": np.array(rewards),
            "values": np.array(values),
        }
    return {
        "observations": np.array(observations),
        "actions": np.array(actions),
        "rewards": np.array(rewards),
        "values": np.array(values),
    }


In [20]:
from torch.utils.tensorboard import SummaryWriter
import os
from datetime import datetime

def get_tensorboard_logger(experiment_name: str, base_log_path: str = "tensorboard_logs"):
    log_path = "{}/{}_{}".format(
            base_log_path, experiment_name, datetime.now()
        )
    train_writer = SummaryWriter(log_path, flush_secs=10)
    full_log_path = os.path.join(os.getcwd(), log_path)
    print(
        "Follow tensorboard logs with: tensorboard --logdir '{}'".format(
            full_log_path
        )
    )
    return train_writer

In [21]:
from tqdm import tqdm

In [22]:
obs_encoding_dim = 64
hidden_dim = 64
latent_dim = 64
action_dim = 4

recurrent_model = RecurrentModel(hidden_dim, latent_dim, action_dim)
representation_model = RepresentationModel(hidden_dim, latent_dim, obs_encoding_dim)
transition_predictor = TransitionPredictor(hidden_dim, latent_dim)

rssm = RSSM(recurrent_model, representation_model, transition_predictor)

obs_encoder = LunarLanderRGBObsEncoder(obs_encoding_dim, 64)
obs_decoder = RGBObsDecoder(hidden_dim, latent_dim, 64)
reward_predictor = RewardPredictor(hidden_dim, latent_dim)
discount_predictor = DiscountPredictor(hidden_dim, latent_dim)

world_model = WorldModel(rssm, obs_encoder, obs_decoder, reward_predictor, discount_predictor)

dreamer_policy = LunarLanderDreamerPolicy(world_model)


In [23]:
gym.make("LunarLander-v2")

<TimeLimit<LunarLander<LunarLander-v2>>>

In [24]:
env = RGBObservationWrapper(gym.make("LunarLander-v2"))

In [25]:
env.reset()
env.close()



In [26]:
np.random.randint(0, 10, 20)

array([8, 4, 7, 8, 4, 4, 2, 0, 5, 6, 6, 1, 0, 5, 7, 9, 5, 8, 0, 6])

In [27]:
np.stack([np.zeros((15,3,64,64)),np.zeros((15,3,64,64)),np.zeros((15,3,64,64))], axis=1).shape

(15, 3, 3, 64, 64)

In [28]:
class ReplayMemory:
    def __init__(
        self,
    ):
        self.observations = []
        self.actions = []
        self.rewards = []

    def store_episode(self, observations: np.array, actions: np.array, rewards: np.array) -> None:
        self.observations.append(observations)
        self.actions.append(actions)
        self.rewards.append(rewards)

    def sample(self, num_samples: int = 1, horizon: int = 40) -> Tuple[np.array, np.array, np.array]:
        episode_indices = np.random.randint(0, len(self.observations), num_samples)

        observations = []
        actions = []
        rewards = []

        for episode_idx in episode_indices:
            observation = np.squeeze(self.observations[episode_idx])
            action = np.squeeze(self.actions[episode_idx])
            reward = np.squeeze(self.rewards[episode_idx])

            observation_len = observation.shape[0]

            start_idx = np.random.randint(0, observation_len - horizon)

            observations.append(observation[start_idx:start_idx + horizon])
            actions.append(action[start_idx:start_idx + horizon])
            rewards.append(reward[start_idx:start_idx + horizon])

        return np.stack(observations, axis=1), np.stack(actions, axis=1), np.stack(rewards, axis=1)


In [29]:
def initialize_replay(policy, env, num_episodes: int = 50):
    replay_memory = ReplayMemory()
    for _ in range(num_episodes):
        out = dreamer_rollout(policy, env=env)

        observations = out["observations"]
        actions = out["actions"]
        rewards = out["rewards"]

        replay_memory.store_episode(observations, actions, rewards)

    return replay_memory

In [30]:
obs_encoding_dim = 64
hidden_dim =  128
latent_dim = 128
action_dim = 4

recurrent_model = RecurrentModel(hidden_dim, latent_dim, action_dim)
representation_model = RepresentationModel(hidden_dim, latent_dim, obs_encoding_dim)
transition_predictor = TransitionPredictor(hidden_dim, latent_dim)

rssm = RSSM(recurrent_model, representation_model, transition_predictor)

obs_encoder = LunarLanderRGBObsEncoder(obs_encoding_dim, 64)
obs_decoder = RGBObsDecoder(hidden_dim, latent_dim, 64)
reward_predictor = RewardPredictor(hidden_dim, latent_dim)
discount_predictor = DiscountPredictor(hidden_dim, latent_dim)

world_model = WorldModel(rssm, obs_encoder, obs_decoder, reward_predictor, discount_predictor)

dreamer_policy = LunarLanderDreamerPolicy(world_model)

device = torch.device('cuda')

dreamer_policy = dreamer_policy.to(device)

In [None]:
bar = tqdm(np.arange(50000))


env = RGBObservationWrapper(gym.make("LunarLander-v2"))

writer = get_tensorboard_logger("DreamerLunarLander")
replay_memory = initialize_replay(dreamer_policy, env, num_episodes=50)

optimizer = DreamerOptimizer(dreamer_policy, world_model_lr=0.001)

for i in bar:

    observations, actions, rewards = replay_memory.sample(50, horizon=50)

    torch_observations = torch.from_numpy(observations).to(dreamer_policy.device)
    torch_actions = torch.from_numpy(actions).float().to(dreamer_policy.device)
    torch_rewards = torch.from_numpy(rewards).float().unsqueeze(-1).to(dreamer_policy.device)

    optimizer.zero_grad()
    loss, image_loss, reward_loss, kl_loss, image_dist = optimizer.loss_fn(
        torch_observations,
        torch_actions,
        torch_rewards
    )

    loss.backward()
    # torch.nn.utils.clip_grad_norm_(dreamer_policy.world_model.parameters(), 100.0)
    optimizer.step()

    # metrics
    grad_dict = {}
    for n, W in dreamer_policy.world_model.named_parameters():
        if W.grad is not None:
            grad_dict["{}_grad".format(n)] = float(torch.sum(W.grad).item())

    avg_reward = np.mean(np.sum(rewards, axis=0), axis=0)

    metrics_dict = {
        "loss":loss.item(),
        "image_loss":image_loss.item(),
        "reward_loss":reward_loss.item(),
        "kl_loss":kl_loss.item(),
        "avg_reward":avg_reward,
        **grad_dict
    }

    for key in metrics_dict:
        writer.add_scalar(key, metrics_dict[key], i)

    metric_string = ""
    for key in metrics_dict:
        metric_string += "{}: {}, ".format(key, metrics_dict[key])

    initial_observations = np.expand_dims(observations[:, 0], 0)
    predicted_observations = np.expand_dims(torch.clip(image_dist.sample()[:, 0], 0.0, 1.0).detach().cpu().numpy(), 0)

    writer.add_video(
        "initial_observations",
        initial_observations,
        global_step=i,
        fps=32
    )

    writer.add_video(
        "predicted_observations",
        predicted_observations,
        global_step=i,
        fps=32
    )

    bar.set_description(metric_string)


  0%|                                                                                                                                                                                                                                                                                                                                                                                                  | 0/50000 [00:00<?, ?it/s]

Follow tensorboard logs with: tensorboard --logdir '/home/kokkgoblin/Code/ez-rl/examples/tensorboard_logs/DreamerLunarLander_2022-04-30 23:04:08.539420'


loss: -12371.6240234375, image_loss: -12373.9306640625, reward_loss: 2.3055663108825684, kl_loss: 0.00717823626473546, avg_reward: -71.62826876996297, rssm.recurrent_model.rnn.weight_ih_grad: 49.337677001953125, rssm.recurrent_model.rnn.weight_hh_grad: 0.07300981879234314, rssm.recurrent_model.rnn.bias_ih_grad: -1.4645345211029053, rssm.recurrent_model.rnn.bias_hh_grad: -0.1030837893486023, rssm.representation_mod