In [1]:
import gym
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from typing import Tuple, Type

import torch.distributions as td
import torch
import torch.nn.functional as F
import abc

from ezrl.optimizer import RLOptimizer
from ezrl.policy import GymPolicy

import abc

### Dreamer Policy

In [2]:
from ezrl.algorithms.dreamer.components import RSSM, DiscountPredictor, RewardPredictor, RecurrentModel, RepresentationModel, TransitionPredictor
from ezrl.algorithms.dreamer.utils import get_convs, get_deconvs, NormalDistribution
from ezrl.algorithms.dreamer.world_model import WorldModel

# RGB LunarLander

# Setup 

In [3]:
# Sizes shared by every component constructed in the cells below.
obs_encoding_dim = 64  # handed to the observation encoder and the representation model
hidden_dim = 64  # handed to the recurrent/representation/transition models, decoder and predictors
latent_dim = 64  # handed to the same components alongside hidden_dim
action_dim = 4  # handed to RecurrentModel; equals the last dim of the dummy action tensors

### World Model

In [4]:
class LunarLanderRGBObsEncoder(nn.Module):
    """Encoder that maps a batch of image observations to flat feature vectors.

    The layer stack comes from ``get_convs``; whatever it outputs is flattened
    to one row per batch element.
    """

    def __init__(
        self,
        obs_encoding_dim: int,
        obs_dim: int = 8
    ):
        """Store the sizes and build the convolution stack.

        Args:
            obs_encoding_dim: Forwarded to ``get_convs`` as its last argument
                (presumably the width of the encoding -- TODO confirm).
            obs_dim: Forwarded to ``get_convs`` as its first argument
                (presumably the image side length; the notebook passes 64).
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.obs_encoding_dim = obs_encoding_dim

        # NOTE(review): the positional literals 1 and 3 are opaque here; they
        # look like a target spatial size and the RGB channel count -- confirm
        # against get_convs.
        self.net = get_convs(self.obs_dim, 1, 3, obs_encoding_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run ``obs`` through the network and flatten all non-batch dimensions."""
        features = self.net(obs)
        batch_size = obs.size(0)
        return features.view(batch_size, -1)


class RGBObsDecoder(nn.Module):
    """Decoder that turns (hidden state, latent state) pairs into an observation distribution.

    The two state vectors are concatenated, reshaped to a 1x1 feature map and
    passed through the stack built by ``get_deconvs``; the result is wrapped
    in ``NormalDistribution``.
    """

    def __init__(self, hidden_dim: int, latent_dim: int, obs_dim: int):
        """Store the sizes and build the deconvolution stack.

        Args:
            hidden_dim: Width of the hidden-state vector.
            latent_dim: Width of the latent-state vector.
            obs_dim: Forwarded to ``get_deconvs`` as its second argument
                (presumably the output image side length -- TODO confirm).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.obs_dim = obs_dim
        # Concatenated state is fed to the deconvs as channels of a 1x1 map.
        self.input_channels = hidden_dim + latent_dim
        # NOTE(review): 6 is presumably 2 distribution parameters x 3 RGB
        # channels (the loss splits the logits with logit_dim=2) -- confirm.
        self.net = get_deconvs(1, obs_dim, self.input_channels, 6)
        self.distribution = NormalDistribution

    def forward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> "NormalDistribution":
        """Decode one batch of states.

        Args:
            hidden_state: Tensor whose first dimension is the batch size.
            latent_state: Tensor concatenated to ``hidden_state`` along the
                last dimension; the combined width must equal
                ``self.input_channels``.

        Returns:
            The object produced by ``self.distribution(logits)`` -- a single
            distribution wrapper, not a tuple.
        """
        batch_size = hidden_state.size(0)
        inp = torch.cat([hidden_state, latent_state], dim=-1).view(batch_size, self.input_channels, 1, 1)
        logits = self.net(inp)
        return self.distribution(logits)

In [5]:
from ezrl.optimizer import RLOptimizer
from ezrl.algorithms.dreamer.policy import DreamerPolicy

from torch.distributions.kl import kl_divergence

class DreamerOptimizer(RLOptimizer):
    """Adam optimizers for a ``DreamerPolicy`` plus the world-model loss.

    One optimizer each is created for the policy's ``world_model``,
    ``policy_net`` and ``critic_net`` parameters. ``zero_grad``, ``loss_fn``
    and ``step`` are placeholders that currently do nothing.
    """

    def __init__(
        self,
        policy: DreamerPolicy,
        horizon: int = 15,
        kl_beta: float = 0.1,
        kl_alpha: float = 0.8,
        world_model_lr: float = 2e-4,
        actor_lr: float = 2e-5,
        critic_lr: float = 1e-4
    ):
        """Store hyperparameters and build the optimizers.

        Args:
            policy: Policy exposing ``world_model``, ``policy_net`` and
                ``critic_net``.
            horizon: Stored on the instance; not read by any method here.
            kl_beta: Weight of the KL term in ``representation_loss``.
            kl_alpha: Mixing weight between the two KL terms.
            world_model_lr: Adam learning rate for the world model.
            actor_lr: Adam learning rate for ``policy_net``.
            critic_lr: Adam learning rate for ``critic_net``.
        """
        # NOTE(review): the base-class __init__ is not invoked here -- confirm
        # RLOptimizer does not require it.
        self.policy = policy
        self.world_model = policy.world_model

        self.horizon, self.kl_beta, self.kl_alpha = horizon, kl_beta, kl_alpha

        # The learning rates have to exist before setup_optimizer() reads them.
        self.world_model_lr, self.actor_lr, self.critic_lr = world_model_lr, actor_lr, critic_lr

        self.setup_optimizer()

    def setup_optimizer(self):
        """Create one Adam optimizer per sub-network from the stored learning rates."""
        policy = self.policy
        self.world_model_optimizer = optim.Adam(policy.world_model.parameters(), lr=self.world_model_lr)
        self.actor_optimizer = optim.Adam(policy.policy_net.parameters(), lr=self.actor_lr)
        self.critic_optimizer = optim.Adam(policy.critic_net.parameters(), lr=self.critic_lr)

    def representation_loss(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        nonterminals: torch.Tensor
    ):
        """Compute the world-model loss for a batch of sequences.

        Inputs should be shaped T x B x shape, where T = timesteps and
        B = batch.

        Returns:
            ``image_loss + reward_loss + discount_loss + kl_beta * kl_loss``,
            where the first three terms are mean negative log-likelihoods of
            the targets under the predicted distributions.
        """
        unrolled = self.policy.unroll_with_posteriors(observations, actions)
        _, posterior, prior, decoded, reward_pred, discount_pred = unrolled

        def mean_nll(distribution, target):
            # -ln p(target), averaged over whatever batch dims remain.
            # NOTE(review): Independent(..., 1) folds only the LAST dimension
            # into the event; for image-shaped targets the other non-batch
            # dims are averaged rather than summed -- confirm this is intended.
            return -torch.mean(td.Independent(distribution, 1).log_prob(target))

        # - ln p(x_t | h_t, z_t)
        image_loss = mean_nll(decoded.dist(logit_dim=2), observations)
        reward_loss = mean_nll(reward_pred.dist(), rewards)
        discount_loss = mean_nll(discount_pred.dist(), nonterminals)

        # First KL term: posterior logits are detached, so gradients reach
        # only the prior.
        frozen_posterior = td.Independent(posterior.dist(logits=posterior.logits.detach()), 1)
        live_prior = td.Independent(prior.dist(), 1)
        kl_prior = kl_divergence(frozen_posterior, live_prior)

        # Second KL term: prior logits are detached, so gradients reach only
        # the posterior.
        live_posterior = td.Independent(posterior.dist(), 1)
        frozen_prior = td.Independent(prior.dist(logits=prior.logits.detach()), 1)
        kl_posterior = kl_divergence(live_posterior, frozen_prior)

        alpha = self.kl_alpha
        kl_loss = alpha*torch.mean(kl_prior) + (1.0 - alpha)*torch.mean(kl_posterior)

        return image_loss + reward_loss + discount_loss + self.kl_beta*kl_loss

    def zero_grad(self):
        """Placeholder; currently a no-op (TODO: not yet implemented)."""
        return None

    def loss_fn(self):
        """Placeholder; currently a no-op (TODO: not yet implemented)."""
        return None

    def step(self):
        """Placeholder; currently a no-op (TODO: not yet implemented)."""
        return None

In [6]:
# The three RSSM components, all sized from the setup constants.
recurrent_model = RecurrentModel(hidden_dim, latent_dim, action_dim)
representation_model = RepresentationModel(hidden_dim, latent_dim, obs_encoding_dim)
transition_predictor = TransitionPredictor(hidden_dim, latent_dim)

# Bundle the three components into a single RSSM.
rssm = RSSM(recurrent_model, representation_model, transition_predictor)

In [7]:
# Observation encoder/decoder (both given 64 as obs_dim) and the reward/discount heads.
obs_encoder = LunarLanderRGBObsEncoder(obs_encoding_dim, 64)
obs_decoder = RGBObsDecoder(hidden_dim, latent_dim, 64)
reward_predictor = RewardPredictor(hidden_dim, latent_dim)
discount_predictor = DiscountPredictor(hidden_dim, latent_dim)

In [8]:
# Combine the RSSM with the encoder, decoder and the two predictors.
world_model = WorldModel(rssm, obs_encoder, obs_decoder, reward_predictor, discount_predictor)

In [9]:
# Base DreamerPolicy built on the world model above.
dreamer = DreamerPolicy(world_model)

In [10]:
# Uses the constructor defaults (horizon=15, kl_beta=0.1, kl_alpha=0.8 and the default learning rates).
dreamer_optimizer = DreamerOptimizer(dreamer)

In [11]:
# All-zero dummy batch for a smoke test. Leading dims are 15 x 32, i.e.
# T x B in the layout described by the shape comment in representation_loss.
observations = torch.zeros(15, 32, 3, 64, 64)  # 3 x 64 x 64 per step
actions = torch.zeros(15,32,4)  # last dim equals action_dim
rewards = torch.zeros(15, 32, 1)
nonterminals = torch.zeros(15, 32, 1)

In [12]:
# Smoke test: evaluate the world-model loss on the dummy batch (displays a scalar tensor with a grad_fn).
dreamer_optimizer.representation_loss(observations, actions, rewards, nonterminals)

tensor(47.4836, grad_fn=<AddBackward0>)

In [13]:
# Unroll the policy from a single all-zero observation (leading dim of 1).
# NOTE(review): this rebinds `actions` and `rewards` from the dummy-batch cell,
# so re-running the representation_loss cell afterwards would see these values instead.
hidden_states, latent_states, actions, rewards, values, discounts = dreamer.unroll(torch.zeros(1,3,64,64))

In [14]:
# Posterior unroll over a 15 x 32 batch of zero observations with random actions.
# NOTE(review): overwrites `hidden_states`, `rewards` and `discounts` from the previous cell.
hidden_states, posteriors, priors, decoded, rewards, discounts = dreamer.unroll_with_posteriors(
    torch.zeros(15,32,3,64,64),
    torch.randn((15,32,4))
) 

In [15]:
# Display the unrolled hidden states; in the saved output the first slice along dim 0 is all zeros.
hidden_states

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.1597, -0.3902,  0.3930,  ..., -0.0427, -0.0534,  0.2580],
         [-0.0546,  0.2429,  0.3388,  ...,  0.1884, -0.2718,  0.0208],
         [-0.2143, -0.2296,  0.0388,  ..., -0.0649,  0.1539,  0.1711],
         ...,
         [-0.2278, -0.2504,  0.1835,  ...,  0.2579, -0.1190,  0.1198],
         [ 0.1663,  0.0622,  0.5049,  ...,  0.2101, -0.1024,  0.2112],
         [ 0.1251,  0.1553,  0.1113,  ..., -0.0945, -0.2539, -0.2374]],

        [[-0.0245, -0.0133,  0.2176,  ..., -0.4099, -0.2140, -0.1269],
         [-0.2297, -0.3966,  0.2965,  ...,  0

In [16]:
# Print the shape of the hidden states, then the logits shape of each
# remaining output of unroll_with_posteriors, one shape per line.
print(hidden_states.size())
print(
    *(part.logits.size() for part in (posteriors, priors, decoded, rewards, discounts)),
    sep="\n",
)

torch.Size([15, 32, 64])
torch.Size([15, 32, 128])
torch.Size([15, 32, 128])
torch.Size([15, 32, 6, 64, 64])
torch.Size([15, 32, 2])
torch.Size([15, 32, 2])


In [17]:
import PIL
import torch
import torchvision
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace


class RGBObservationWrapper(gym.ObservationWrapper):
    """Observation wrapper that replaces every observation with a 64x64 image array.

    If the wrapped environment's observation has fewer than three dimensions,
    the environment is rendered with ``mode="rgb_array"`` and that frame is
    used instead. The image is optionally converted to grayscale, resized to
    64x64 with bicubic interpolation and returned as a float NumPy array with
    a leading singleton dimension.

    NOTE(review): ``observation_space`` is not updated to describe the new
    observation shape; code that inspects the space will still see the
    wrapped environment's original space.
    """

    def __init__(self, env, grayscale=False):
        """Wrap ``env``.

        Args:
            env: Environment to wrap.
            grayscale: When true, images are passed through
                ``torchvision.transforms.Grayscale`` before resizing.
        """
        super().__init__(env)
        self.grayscale = grayscale
        # Build the transforms once here; they used to be re-created on every
        # call to observation(), i.e. on every environment step.
        self._to_pil = torchvision.transforms.ToPILImage()
        self._resize = torchvision.transforms.Resize(
            (64, 64), interpolation=PIL.Image.BICUBIC
        )
        self._grayscale = torchvision.transforms.Grayscale()
        self._to_tensor = torchvision.transforms.ToTensor()

    def observation(self, obs):
        """Convert ``obs`` (or a rendered frame) into a ``(1, C, 64, 64)`` float array.

        Args:
            obs: Observation from the wrapped environment. Arrays with three
                or more dimensions are treated as H x W x C images.

        Returns:
            Float NumPy array with a leading dimension of 1 followed by the
            channel dimension and the 64x64 image (3 channels for an RGB
            frame; presumably 1 when ``grayscale`` is set -- confirm against
            ``Grayscale``'s default).
        """
        if len(obs.shape) < 3:
            # The environment does not emit image observations: render a frame.
            obs = self.env.render(mode="rgb_array")
        # H x W x C array -> C x H x W tensor -> PIL image. The copy avoids
        # handing torch a view onto a buffer owned by the renderer.
        image = self._to_pil(torch.from_numpy(obs.copy().transpose(2, 0, 1)))
        if self.grayscale:
            image = self._grayscale(image)
        image = self._resize(image)  # still a PIL image, now 64 x 64
        tensor = self._to_tensor(image)  # C x 64 x 64
        return tensor.float().unsqueeze(0).detach().numpy()


In [18]:
# LunarLander-v2 wrapped so that observations come back as 64x64 image arrays.
env = RGBObservationWrapper(gym.make("LunarLander-v2"))

In [19]:
# First observation of a fresh episode.
obs = env.reset()



In [20]:
# Sanity check: the leading 1 comes from the wrapper's unsqueeze(0), followed by channels and the 64x64 image.
obs.shape

(1, 3, 64, 64)

### Lunar Lander Dreamer Policy

In [47]:
class LunarLanderDreamerPolicy(DreamerPolicy):
    """``DreamerPolicy`` variant that samples discrete actions from a categorical distribution."""

    def __init__(self, world_model: WorldModel):
        """Forward ``world_model`` unchanged to the base-class constructor."""
        super().__init__(world_model)

    def dist(self, action_probs: torch.Tensor) -> td.Distribution:
        """Return a ``Categorical`` distribution parameterised by ``action_probs``."""
        return td.Categorical(probs=action_probs)

    def act(self, latent_state: torch.Tensor):
        """Run the policy on ``latent_state`` and extract a scalar action.

        Returns:
            A pair ``(action, out)``: ``action`` is the single element of
            ``out["action"]`` converted to a Python scalar, and ``out`` is the
            full dictionary returned by ``forward``.
        """
        policy_out = self.forward(latent_state)
        action_array = policy_out["action"].detach().cpu().numpy()
        # squeeze + item() only succeeds when exactly one action was produced.
        scalar_action = np.squeeze(action_array).item()
        return scalar_action, policy_out


In [48]:
# Same sizes as the setup cell near the top, restated so this cell is self-contained.
obs_encoding_dim = 64
hidden_dim = 64
latent_dim = 64
action_dim = 4

# Fresh RSSM components: new instances, not the ones inside the earlier `dreamer` policy.
recurrent_model = RecurrentModel(hidden_dim, latent_dim, action_dim)
representation_model = RepresentationModel(hidden_dim, latent_dim, obs_encoding_dim)
transition_predictor = TransitionPredictor(hidden_dim, latent_dim)

rssm = RSSM(recurrent_model, representation_model, transition_predictor)

# Fresh encoder/decoder and prediction heads.
obs_encoder = LunarLanderRGBObsEncoder(obs_encoding_dim, 64)
obs_decoder = RGBObsDecoder(hidden_dim, latent_dim, 64)
reward_predictor = RewardPredictor(hidden_dim, latent_dim)
discount_predictor = DiscountPredictor(hidden_dim, latent_dim)

# NOTE(review): this rebinds `world_model` and the component names from the earlier cells.
world_model = WorldModel(rssm, obs_encoder, obs_decoder, reward_predictor, discount_predictor)

# Policy variant with categorical actions, used for the rollout below.
dreamer_policy = LunarLanderDreamerPolicy(world_model)

In [49]:
from typing import Dict
import pandas as pd

def dreamer_rollout(
    policy: "DreamerPolicy", env_name: str = None, env=None, env_creation_fn=None
) -> Dict[str, np.ndarray]:
    """Run one episode with ``policy`` and return the collected trajectory.

    Args:
        policy: Policy providing ``device``, ``initialize_hidden_state``,
            ``world_model.encode_obs``, ``world_model.posterior`` and ``act``.
        env_name: Name passed to ``env_creation_fn`` when ``env`` is not given.
        env: Ready-made environment; takes precedence over ``env_name``.
        env_creation_fn: Factory called with ``env_name``; defaults to
            ``gym.make``.

    Returns:
        Dict with the arrays ``"observations"``, ``"actions"``, ``"rewards"``
        and ``"values"``, each with one entry per environment step. When the
        policy emits Python ``int`` actions, ``"actions"`` is one-hot encoded.
        Its width is ``env.action_space.n`` when the action space exposes
        ``n``, so the shape does not depend on which actions happened to be
        sampled; otherwise one column is produced per distinct action seen.

    Raises:
        ValueError: If neither ``env_name`` nor ``env`` is provided.
    """
    if env_name is None and env is None:
        raise ValueError("env_name or env must be provided!")
    if env is None:
        if env_creation_fn is None:
            env_creation_fn = gym.make
        env = env_creation_fn(env_name)

    observations, actions, rewards, values = [], [], [], []
    observation = env.reset()
    done = False

    with torch.no_grad():
        while not done:
            obs = torch.from_numpy(observation).to(policy.device)

            # NOTE(review): the hidden state is re-initialised on every step,
            # so nothing recurrent is carried across timesteps -- confirm
            # this is intended.
            hidden_state = policy.initialize_hidden_state(1)
            encoded = policy.world_model.encode_obs(obs)
            latent_state = policy.world_model.posterior(hidden_state, encoded)

            action, out = policy.act(latent_state.sample())
            value = out["value"]
            next_observation, reward, done, _ = env.step(action)

            observations.append(observation)
            actions.append(action)
            rewards.append(reward)
            values.append(value.detach().cpu().numpy())

            observation = next_observation
    # NOTE(review): this also closes an environment supplied by the caller.
    env.close()

    if isinstance(action, int):
        n_actions = getattr(getattr(env, "action_space", None), "n", None)
        if n_actions is not None:
            # Fixed-width one-hot: column j is 1.0 exactly where action j was taken.
            action_array = np.eye(int(n_actions), dtype=float)[np.asarray(actions, dtype=int)]
        else:
            # No known action count: fall back to one column per action seen.
            action_array = np.array(pd.get_dummies(np.array(actions))).astype(float)
    else:
        action_array = np.array(actions)

    return {
        "observations": np.array(observations),
        "actions": action_array,
        "rewards": np.array(rewards),
        "values": np.array(values),
    }


In [50]:
# Roll out one episode in the wrapped environment; dreamer_rollout calls env.close() when the episode ends.
out = dreamer_rollout(dreamer_policy, env=env)



In [52]:
# One row per environment step, one column per one-hot action category.
out['actions'].shape

(89, 4)