In [2]:
import gym
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from typing import Tuple, Type

import torch.distributions as td
import torch
import torch.nn.functional as F
import abc

from ezrl.optimizer import RLOptimizer
from ezrl.policy import GymPolicy

import abc

In [50]:
from typing import Iterable

import torch
from torch.nn import Module


def get_convs(
    initial_size,
    output_size,
    initial_channels,
    out_channels,
    channel_dimension=32,
    bias=True,
):
    """
    Build a stack of stride-2 convolutions that shrinks a square feature map.

    The tracked spatial size is halved (integer division) once per
    convolution until it is no longer greater than ``output_size``. The
    convolution whose halved size equals ``output_size`` emits
    ``out_channels`` and has no activation; every other convolution emits
    ``channel_dimension`` channels and is followed by an ELU.

    Args:
        initial_size: spatial size of the input feature map.
        output_size: spatial size to reduce to.
        initial_channels: number of input channels.
        out_channels: number of channels produced by the final convolution.
        channel_dimension: number of channels used by intermediate layers.
        bias: whether the convolutions carry a bias term.

    Returns:
        torch.nn.Sequential: the layers in application order. When
        ``initial_size == output_size`` it holds a single 1x1 convolution;
        when ``initial_size < output_size`` it is empty.
    """
    # Nothing to downsample: only remap the channel count with a 1x1 conv.
    if initial_size == output_size:
        return torch.nn.Sequential(
            torch.nn.Conv2d(initial_channels, out_channels, 1, bias=bias)
        )

    layers = []
    current_channels = initial_channels
    remaining = initial_size
    while remaining > output_size:
        remaining = remaining // 2
        is_last = remaining == output_size
        next_channels = out_channels if is_last else channel_dimension
        layers.append(
            torch.nn.Conv2d(
                current_channels, next_channels, 3, 2, padding=1, bias=bias
            )
        )
        if not is_last:
            layers.append(torch.nn.ELU())
            current_channels = channel_dimension
    return torch.nn.Sequential(*layers)

def get_deconvs(
    initial_size,
    output_size,
    initial_channels,
    out_channels,
    channel_dimension=32,
    bias=True,
):
    """
    Build a stack of stride-2 transposed convolutions that grows a square feature map.

    The tracked spatial size is doubled once per transposed convolution
    until it is no longer smaller than ``output_size``. The layer whose
    doubled size equals ``output_size`` emits ``out_channels`` and has no
    activation; every other layer emits ``channel_dimension`` channels and
    is followed by an ELU.

    Args:
        initial_size: spatial size of the input feature map.
        output_size: spatial size to grow to.
        initial_channels: number of input channels.
        out_channels: number of channels produced by the final layer.
        channel_dimension: number of channels used by intermediate layers.
        bias: whether the layers carry a bias term.

    Returns:
        torch.nn.Sequential: the layers in application order. When
        ``initial_size == output_size`` it holds a single 1x1 convolution;
        when ``initial_size > output_size`` it is empty.
    """
    # Nothing to upsample: only remap the channel count with a 1x1 conv.
    if initial_size == output_size:
        return torch.nn.Sequential(
            torch.nn.Conv2d(initial_channels, out_channels, 1, bias=bias)
        )

    layers = []
    current_channels = initial_channels
    reached = initial_size
    while reached < output_size:
        reached = reached * 2
        is_last = reached == output_size
        next_channels = out_channels if is_last else channel_dimension
        layers.append(
            torch.nn.ConvTranspose2d(
                current_channels,
                next_channels,
                3,
                2,
                padding=1,
                output_padding=1,
                bias=bias,
            )
        )
        if not is_last:
            layers.append(torch.nn.ELU())
            current_channels = channel_dimension
    return torch.nn.Sequential(*layers)

### Recurrent model

In [3]:
import torch.nn as nn
import torch

class RecurrentModel(nn.Module):
    """
    Deterministic recurrent state update.

    Defined as:
        h_t = f(h_t-1, z_t-1, a_t-1)

        h_t: output hidden state at timestep t

        f: rnn model (a single GRU cell)
        h_t-1: previous hidden state
        z_t-1: previous latent state
        a_t-1: previous action
    """

    def __init__(self, hidden_dim: int, latent_dim: int, action_dim: int):
        """
        Args:
            hidden_dim: size of the GRU hidden state.
            latent_dim: size of the latent state vector.
            action_dim: size of the action vector.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.action_dim = action_dim
        # The cell consumes the action and the latent state concatenated together.
        self.recurrent_input_dim = action_dim + latent_dim
        self.rnn = nn.GRUCell(self.recurrent_input_dim, hidden_dim)

    def initialize_state(self, batch_size: int = 1, *args, **kwargs) -> torch.Tensor:
        """
        Return an all-zero hidden state.

        Extra positional and keyword arguments are forwarded to
        ``torch.zeros`` after the leading ``(batch_size, hidden_dim)`` sizes.
        """
        sizes = (batch_size, self.hidden_dim, *args)
        return torch.zeros(*sizes, **kwargs)

    def forward(
        self,
        prev_hidden: torch.Tensor,
        prev_latent_state: torch.Tensor,
        prev_action: torch.Tensor
    ) -> torch.Tensor:
        """Advance the hidden state by one step and return the new hidden state."""
        # Order matters: the action comes first, then the latent state.
        rnn_input = torch.cat((prev_action, prev_latent_state), dim=-1)
        next_hidden = self.rnn(rnn_input, prev_hidden)
        return next_hidden


### Distribution Wrapper

In [4]:
class DistributionModel(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base for modules that turn network outputs into a distribution.

    Subclasses must implement ``forward``; the class cannot be instantiated
    until they do.
    """

    def __init__(self):
        super(DistributionModel, self).__init__()

    @abc.abstractmethod
    def forward(self, logits: torch.Tensor):
        """
        Sample from a distribution parameterised by ``logits``.
        """

class NormalDistributionModel(DistributionModel):
    """
    Wrap a network so that its output parameterises a diagonal Normal.

    The wrapped network's output is split in half along the last dimension:
    the first half is the mean, the second half is mapped to a standard
    deviation via ``softplus(.) + 0.1``.
    """

    def __init__(self, logit_net: nn.Module):
        """
        Args:
            logit_net: module whose output's last dimension holds the mean
                followed by the raw (pre-softplus) standard deviation.
        """
        super().__init__()
        self.logit_net = logit_net

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, td.Distribution]:
        """
        Run the wrapped network on the given arguments and sample from the result.

        Returns:
            A ``(sample, distribution)`` pair; the sample is drawn with
            ``rsample`` so it is reparameterised.
        """
        raw = self.logit_net(*args, **kwargs)
        loc, raw_scale = raw.chunk(2, dim=-1)
        # The +0.1 offset keeps the standard deviation at or above 0.1.
        scale = F.softplus(raw_scale) + 0.1
        normal = td.Normal(loc, scale)
        sample = normal.rsample()
        return sample, normal

### Backend Module

This abstraction makes it easy to swap out the underlying network (e.g. a linear or a convolutional model).

In [7]:
class BackendModule(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base for interchangeable feature networks.

    Subclasses must implement ``forward``; the class cannot be instantiated
    until they do.
    """

    def __init__(self):
        super(BackendModule, self).__init__()

    @abc.abstractmethod
    def forward(self, logits: torch.Tensor):
        """
        Transform the input tensor; concrete backends define how.
        """


In [8]:
class LinearBackendModule(BackendModule):
    """Backend made of a single fully connected layer."""

    def __init__(self, input_dims: int, output_dims: int):
        """
        Args:
            input_dims: number of input features.
            output_dims: number of output features.
        """
        super().__init__()
        self.input_dims, self.output_dims = input_dims, output_dims
        projection = nn.Linear(input_dims, output_dims)
        self.net = nn.Sequential(projection)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer and flatten everything after the first dimension."""
        leading = x.size(0)
        projected = self.net(x)
        return projected.view(leading, -1)


#### RSSM

In [60]:
class RepresentationModel(nn.Module):
    """
    Posterior Model

    Defined as:
        z_t ~ q(z_t | h_t, x_t)

        z_t: posterior latent state at timestep t

        q: posterior distribution to sample latent state from
        h_t: current hidden state
        x_t: current observation encoding
    """

    def __init__(
        self,
        hidden_dim: int,
        latent_dim: int,
        obs_encoding_dim: int,
        backend_module: BackendModule = LinearBackendModule,
        distribution_model: DistributionModel = NormalDistributionModel
    ):
        """
        Args:
            hidden_dim: size of the hidden state h_t.
            latent_dim: size of the latent state z_t.
            obs_encoding_dim: size of the observation encoding x_t.
            backend_module: callable invoked as ``backend_module(in_dims, out_dims)``
                to build the first layer (a class, not an instance).
            distribution_model: callable invoked on the assembled network to
                wrap it into a sampling module (a class, not an instance).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.obs_encoding_dim = obs_encoding_dim
        self.backend_module = backend_module
        self.distribution_model = distribution_model

        # The final layer emits 2 * latent_dim values, twice the latent size.
        logit_layers = [
            backend_module(obs_encoding_dim + hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 2 * self.latent_dim),
        ]
        self.net = distribution_model(nn.Sequential(*logit_layers))

    def forward(self, hidden_state: torch.Tensor, obs_encoding: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """Concatenate ``[h_t, x_t]`` along the last dimension and pass it to the wrapped network."""
        joint = torch.cat((hidden_state, obs_encoding), dim=-1)
        return self.net(joint)

class TransitionPredictor(nn.Module):
    """
    Prior Model

    Defined as:
        zhat_t ~ p(zhat_t | h_t)

        zhat_t: prior latent state at timestep t

        p: prior distribution to sample latent state from
        h_t: current hidden state
    """

    def __init__(
        self,
        hidden_dim: int,
        latent_dim: int,
        backend_module: BackendModule = LinearBackendModule,
        distribution_model: DistributionModel = NormalDistributionModel
    ):
        """
        Args:
            hidden_dim: size of the hidden state h_t.
            latent_dim: size of the latent state zhat_t.
            backend_module: callable invoked as ``backend_module(in_dims, out_dims)``
                to build the first layer (a class, not an instance).
            distribution_model: callable invoked on the assembled network to
                wrap it into a sampling module (a class, not an instance).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.backend_module = backend_module
        self.distribution_model = distribution_model

        # The final layer emits 2 * latent_dim values, twice the latent size.
        logit_layers = [
            backend_module(self.hidden_dim, 32),
            nn.Tanh(),
            nn.Linear(32, 2 * self.latent_dim),
        ]
        self.net = distribution_model(nn.Sequential(*logit_layers))

    def forward(self, hidden_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """Pass the hidden state to the wrapped network and return its result."""
        result = self.net(hidden_state)
        return result

class RSSM(nn.Module):
    """
    Recurrent state-space model: bundles the recurrent, posterior and prior models.

        h_t    = f(h_t-1, z_t-1, a_t-1)   (recurrent model)
        z_t    ~ q(z_t | h_t, x_t)        (representation model / posterior)
        zhat_t ~ p(zhat_t | h_t)          (transition predictor / prior)
    """

    def __init__(
        self,
        recurrent_model: RecurrentModel,
        representation_model: RepresentationModel,
        transition_predictor: TransitionPredictor,
    ):
        """
        Args:
            recurrent_model: provides ``action_dim``, ``latent_dim``,
                ``hidden_dim`` and the hidden-state update.
            representation_model: provides ``obs_encoding_dim`` and the posterior.
            transition_predictor: provides the prior.
        """
        super().__init__()

        # Dimensions are mirrored from the sub-models so callers can read them here.
        self.action_dim = recurrent_model.action_dim
        self.latent_dim = recurrent_model.latent_dim
        self.hidden_dim = recurrent_model.hidden_dim
        self.obs_encoding_dim = representation_model.obs_encoding_dim

        # h_t = f(h_t-1, z_t-1, a_t-1)
        self.recurrent_model: RecurrentModel = recurrent_model

        # posterior, z_t ~ q(z_t | h_t, x_t)
        self.representation_model: RepresentationModel = representation_model
        # prior, zhat_t ~ p(zhat_t | h_t)
        self.transition_predictor: TransitionPredictor = transition_predictor

    def initialize_hidden_state(self, batch_size: int = 1, *args, **kwargs) -> torch.Tensor:
        """
        Return the recurrent model's initial hidden state for ``batch_size`` items.

        Extra positional and keyword arguments are forwarded unchanged to
        ``recurrent_model.initialize_state``.
        """
        # The recurrent model already appends its own hidden_dim; passing it again
        # here would be forwarded as an extra size and add a spurious third dimension.
        return self.recurrent_model.initialize_state(batch_size, *args, **kwargs)

    def prior(self, hidden_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """zhat_t ~ p(zhat_t | h_t): delegate to the transition predictor."""
        return self.transition_predictor(hidden_state)

    def posterior(self, hidden_state: torch.Tensor, obs_encoding: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """z_t ~ q(z_t | h_t, x_t): delegate to the representation model."""
        return self.representation_model(hidden_state, obs_encoding)

    def recurrent(self, prev_hidden_state: torch.Tensor, prev_latent_state: torch.Tensor, prev_action: torch.Tensor) -> torch.Tensor:
        """h_t = f(h_t-1, z_t-1, a_t-1): delegate to the recurrent model."""
        return self.recurrent_model(prev_hidden_state, prev_latent_state, prev_action)


## World Model

In [10]:
class ObsDecoder(nn.Module):
    """
    Observation Decoder

    Defined as:
        xhat_t ~ p(xhat_t | h_t, z_t)

        xhat_t: decoded observation at timestep t

        p: distribution to decode the observation from
        h_t: current hidden state
        z_t: current latent state
    """

    def __init__(self, obs_decoder_model: nn.Module, hidden_dim: int, latent_dim: int):
        """
        Args:
            obs_decoder_model: network wrapped in a ``NormalDistributionModel``;
                it receives ``[h_t, z_t]`` concatenated along the last dimension.
            hidden_dim: size of the hidden state (stored only).
            latent_dim: size of the latent state (stored only).
        """
        super().__init__()
        self.hidden_dim, self.latent_dim = hidden_dim, latent_dim
        self.net = NormalDistributionModel(obs_decoder_model)

    def forward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """Concatenate ``[h_t, z_t]`` and return the wrapped model's ``(sample, distribution)``."""
        joint = torch.cat((hidden_state, latent_state), dim=-1)
        return self.net(joint)

In [21]:
class RewardPredictor(nn.Module):
    """
    Reward Predictor

    Defined as:
        rhat_t ~ p(rhat_t | h_t, z_t)

        rhat_t: reward prediction at timestep t

        p: distribution to predict rewards
        h_t: current hidden state
        z_t: current latent state
    """

    def __init__(
        self,
        hidden_dim: int,
        latent_dim: int,
        backend_module: BackendModule = LinearBackendModule,
        distribution_model: DistributionModel = NormalDistributionModel
    ):
        """
        Args:
            hidden_dim: size of the hidden state h_t.
            latent_dim: size of the latent state z_t.
            backend_module: callable invoked as ``backend_module(in_dims, out_dims)``
                to build the first layer (a class, not an instance).
            distribution_model: callable invoked on the assembled network to
                wrap it into a sampling module (a class, not an instance).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.backend_module = backend_module
        self.distribution_model = distribution_model
        self.net = distribution_model(
            nn.Sequential(
                backend_module(latent_dim + hidden_dim, 32),
                nn.Tanh(),
                # Two outputs (mean and raw std) for the scalar reward: the default
                # NormalDistributionModel splits the last dimension in half, and a
                # single output cannot be split into two chunks.
                nn.Linear(32, 2)
            )
        )

    def forward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """Concatenate ``[h_t, z_t]`` along the last dimension and pass it to the wrapped network."""
        inp = torch.cat([hidden_state, latent_state], dim=-1)
        return self.net(inp)

In [20]:
class GammaPredictor(nn.Module):
    """
    Discount (gamma) Predictor

    Defined as:
        gammahat_t ~ p(gammahat_t | h_t, z_t)

        gammahat_t: discount prediction at timestep t

        p: distribution to predict discounts
        h_t: current hidden state
        z_t: current latent state
    """

    def __init__(
        self,
        hidden_dim: int,
        latent_dim: int,
        backend_module: BackendModule = LinearBackendModule,
        distribution_model: DistributionModel = NormalDistributionModel
    ):
        """
        Args:
            hidden_dim: size of the hidden state h_t.
            latent_dim: size of the latent state z_t.
            backend_module: callable invoked as ``backend_module(in_dims, out_dims)``
                to build the first layer (a class, not an instance).
            distribution_model: callable invoked on the assembled network to
                wrap it into a sampling module (a class, not an instance).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.backend_module = backend_module
        self.distribution_model = distribution_model
        self.net = distribution_model(
            nn.Sequential(
                backend_module(latent_dim + hidden_dim, 32),
                nn.Tanh(),
                # Two outputs (mean and raw std) for the scalar discount: the default
                # NormalDistributionModel splits the last dimension in half, and a
                # single output cannot be split into two chunks.
                nn.Linear(32, 2)
            )
        )

    def forward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """Concatenate ``[h_t, z_t]`` along the last dimension and pass it to the wrapped network."""
        inp = torch.cat([hidden_state, latent_state], dim=-1)
        return self.net(inp)

In [52]:
class WorldModel(nn.Module):
    """
    Bundle of the RSSM, the observation encoder/decoder and the reward and
    discount predictors, exposing each as a thin delegating method.
    """

    def __init__(
        self,
        rssm: RSSM,
        obs_encoder: nn.Module,
        obs_decoder: nn.Module,
        reward_predictor: RewardPredictor,
        gamma_predictor: GammaPredictor
    ):
        """
        Args:
            rssm: recurrent state-space model; its ``action_dim``, ``latent_dim``,
                ``hidden_dim`` and ``obs_encoding_dim`` are mirrored here.
            obs_encoder: module called as ``obs_encoder(obs)``.
            obs_decoder: module called as ``obs_decoder(hidden_state, latent_state)``.
            reward_predictor: module called as ``reward_predictor(hidden_state, latent_state)``.
            gamma_predictor: module called as ``gamma_predictor(hidden_state, latent_state)``.
        """
        super().__init__()
        self.rssm = rssm
        self.action_dim = rssm.action_dim
        self.latent_dim = rssm.latent_dim
        self.hidden_dim = rssm.hidden_dim
        self.obs_encoding_dim = rssm.obs_encoding_dim
        # NOTE: the RSSM does not store a hidden state, so none is mirrored here;
        # hidden states are created via rssm.initialize_hidden_state and passed explicitly.

        self.obs_encoder = obs_encoder

        # xhat_t ~ p(xhat_t | h_t, z_t)
        self.obs_decoder = obs_decoder

        # rhat_t ~ p(rhat_t | h_t, z_t)
        self.reward_predictor = reward_predictor

        # gammahat_t ~ p(gammahat_t | h_t, z_t)
        self.gamma_predictor = gamma_predictor

    def encode_obs(self, obs: torch.Tensor) -> torch.Tensor:
        """
        x_t = enc(o_t)

        Args:
            obs (torch.Tensor): observation passed to the encoder.

        Returns:
            torch.Tensor: whatever the observation encoder returns.
        """
        encoded_obs = self.obs_encoder(obs)
        return encoded_obs

    def recurrent(self, prev_hidden_state: torch.Tensor, prev_latent_state: torch.Tensor, prev_action: torch.Tensor) -> torch.Tensor:
        """
        h_t = f(h_t-1, z_t-1, a_t-1)

        Args:
            prev_hidden_state (torch.Tensor): previous hidden state h_t-1.
            prev_latent_state (torch.Tensor): previous latent state z_t-1.
            prev_action (torch.Tensor): previous action a_t-1.

        Returns:
            torch.Tensor: the new hidden state h_t.
        """
        return self.rssm.recurrent(prev_hidden_state, prev_latent_state, prev_action)

    def posterior(self, hidden_state: torch.Tensor, obs_encoding: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """
        z_t ~ q(z_t | h_t, x_t)

        Args:
            hidden_state (torch.Tensor): current hidden state h_t.
            obs_encoding (torch.Tensor): current observation encoding x_t.

        Returns:
            Tuple[torch.Tensor, td.Distribution]: the result of ``rssm.posterior``.
        """
        # Argument order must match RSSM.posterior(hidden_state, obs_encoding).
        return self.rssm.posterior(hidden_state, obs_encoding)

    def prior(self, hidden_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """
        zhat_t ~ p(zhat_t | h_t)

        Args:
            hidden_state (torch.Tensor): current hidden state h_t.

        Returns:
            Tuple[torch.Tensor, td.Distribution]: the result of ``rssm.prior``.
        """
        return self.rssm.prior(hidden_state)

    def decode_obs(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """
        xhat_t ~ p(xhat_t | h_t, z_t)

        Args:
            hidden_state (torch.Tensor): current hidden state h_t.
            latent_state (torch.Tensor): current latent state z_t.

        Returns:
            Tuple[torch.Tensor, td.Distribution]: the result of the observation decoder.
        """
        return self.obs_decoder(hidden_state, latent_state)

    def predict_reward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """
        rhat_t ~ p(rhat_t | h_t, z_t)

        Args:
            hidden_state (torch.Tensor): current hidden state h_t.
            latent_state (torch.Tensor): current latent state z_t.

        Returns:
            Tuple[torch.Tensor, td.Distribution]: the result of the reward predictor.
        """
        return self.reward_predictor(hidden_state, latent_state)

    def predict_gamma(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """
        gammahat_t ~ p(gammahat_t | h_t, z_t)

        Args:
            hidden_state (torch.Tensor): current hidden state h_t.
            latent_state (torch.Tensor): current latent state z_t.

        Returns:
            Tuple[torch.Tensor, td.Distribution]: the result of the discount predictor.
        """
        # Same (hidden_state, latent_state) order as predict_reward and decode_obs.
        return self.gamma_predictor(hidden_state, latent_state)


### Dreamer Policy

In [54]:
from ezrl.policy import ACPolicy
from typing import Dict, Any, Optional

class DreamerPolicy(ACPolicy):
    """
    Actor-critic policy that acts on the latent states of a ``WorldModel``.

    The actor is a linear layer producing the mean of a Normal distribution
    with a learned, state-independent log standard deviation; the critic is a
    linear layer producing a scalar value. Both read the latent state only.
    """

    def __init__(
        self,
        world_model: WorldModel,
    ):
        """
        Args:
            world_model: provides the dimensions and all model calls used by
                ``unroll`` and ``unroll_with_posteriors``.
        """
        # Base-class initialisation must run before modules/parameters are assigned.
        # NOTE(review): assumes ACPolicy is an nn.Module whose __init__ takes no
        # arguments -- confirm against ezrl.policy.
        super().__init__()
        self.world_model = world_model

        self.action_dim = world_model.action_dim
        self.latent_dim = world_model.latent_dim
        self.hidden_dim = world_model.hidden_dim
        self.obs_encoding_dim = world_model.obs_encoding_dim

        self.policy_net = nn.Linear(world_model.latent_dim, self.action_dim)
        self.critic_net = nn.Linear(world_model.latent_dim, 1)

        # log(std) starts at -0.5 for every action dimension.
        log_std = -0.5 * np.ones(self.action_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))

        self.__device_param_dummy__ = nn.Parameter(
            torch.empty(0)
        )  # to keep track of device

    @property
    def device(self) -> torch.device:
        """Device of the dummy tracking parameter."""
        return self.__device_param_dummy__.device

    def initialize_hidden_state(self, batch_size: int = 1) -> torch.Tensor:
        """Return the RSSM's initial hidden state for ``batch_size`` items, moved to ``self.device``."""
        return self.world_model.rssm.initialize_hidden_state(batch_size).to(self.device)

    def log_prob(self, dist: td.Distribution, actions: torch.Tensor):
        """
        Log-probability of ``actions`` under ``dist``.

        For a Categorical the value is returned as-is; otherwise the
        per-dimension log-probabilities are summed over the last dimension.
        """
        if isinstance(dist, td.Categorical):
            return dist.log_prob(actions)
        return dist.log_prob(actions).sum(dim=-1)

    def dist(self, action_logits: torch.Tensor) -> td.Distribution:
        """Normal distribution with mean ``action_logits`` and std ``exp(log_std)``."""
        std = torch.exp(self.log_std)
        return td.normal.Normal(action_logits, std)

    def forward(self, latent_state: Any) -> Dict[str, Any]:
        """
        Sample an action and evaluate the critic for ``latent_state``.

        Returns:
            Dict with keys ``"action"``, ``"dist"``, ``"log_probs"`` and ``"value"``.
        """
        mu = self.policy_net(latent_state)
        dist = self.dist(mu)
        action = dist.sample()
        log_probs = self.log_prob(dist, action)
        # squeeze() drops every size-1 dimension, including a batch of one.
        value = self.critic_net(latent_state).squeeze()
        return {"action":action, "dist":dist, "log_probs":log_probs, "value":value}

    def critic(self, latent_state:Any):
        """Critic value for ``latent_state`` with all size-1 dimensions squeezed out."""
        return self.critic_net(latent_state).squeeze()

    def act(self, latent_state: Any):
        """Return ``(action as a squeezed numpy array, full forward() output dict)``."""
        out = self.forward(latent_state)
        return np.squeeze(out["action"].detach().cpu().numpy()), out

    def unroll(self, initial_observation: torch.Tensor, num_steps: int = 15) -> Tuple[torch.Tensor, ...]:
        """
        Imagine ``num_steps`` steps forward from a single observation.

        The first latent state is the posterior sample for
        ``initial_observation``; later latent states are prior samples.

        Args:
            initial_observation: batch of observations; the first dimension is
                the batch size (per the original note: B x C x H x W).
            num_steps: number of imagined steps. ``torch.stack`` raises on an
                empty list, so this must be at least 1.

        Returns:
            ``(hidden_states, latent_states, actions, rewards, values, discounts)``,
            each stacked along a new leading time dimension of length
            ``num_steps``. Entry ``t`` of every tensor refers to the same step:
            the state ``(h_t, z_t)`` and the action, value, sampled reward and
            sampled discount computed from it.
        """
        # observations size -- B x C x H x W
        batch_size = initial_observation.size(0)

        hidden_states = []
        latent_states = []
        actions = []
        rewards = []
        values = []
        discounts = []

        hidden_state = self.initialize_hidden_state(batch_size)
        encoded = self.world_model.encode_obs(initial_observation)
        # The world model returns (sample, distribution) pairs; only samples are tensors.
        latent_state, _ = self.world_model.posterior(hidden_state, encoded)

        for _ in range(num_steps):
            out = self.forward(latent_state)

            action = out["action"]
            value = out["value"]
            reward, _ = self.world_model.predict_reward(hidden_state, latent_state)
            discount, _ = self.world_model.predict_gamma(hidden_state, latent_state)

            # Record the step exactly once, before advancing the state.
            hidden_states.append(hidden_state)
            latent_states.append(latent_state)
            actions.append(action)
            rewards.append(reward)
            values.append(value)
            discounts.append(discount)

            hidden_state = self.world_model.recurrent(hidden_state, latent_state, action)
            latent_state, _ = self.world_model.prior(hidden_state)

        return torch.stack(hidden_states), torch.stack(latent_states), torch.stack(actions), torch.stack(rewards), torch.stack(values), torch.stack(discounts)

    def unroll_with_posteriors(self, observations: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Run the world model over a recorded sequence using posterior latent states.

        Args:
            observations: sequence of observation batches; dimension 0 is time
                and dimension 1 is the batch (per the original note:
                T x B x C x H x W).
            actions: actions indexed by time along dimension 0 (per the original
                note: T x B x action size).

        Returns:
            ``(hidden_states, encoded_observations, posteriors, priors,
            decoded_observations, rewards, discounts)``, each stacked along a
            new leading time dimension of length T. Posteriors, priors, decoded
            observations, rewards and discounts are the sampled tensors, not
            the distributions.
        """
        # observations size -- T x B x C x H x W
        # actions size -- T x B x A
        num_steps = observations.size(0)
        batch_size = observations.size(1)

        hidden_states = []
        encoded_observations = []
        posteriors = []
        priors = []
        decoded_observations = []
        rewards = []
        discounts = []

        # h_0
        hidden_state = self.initialize_hidden_state(batch_size)

        for i in range(num_steps):
            # t
            encoded = self.world_model.encode_obs(observations[i])
            # The world model returns (sample, distribution) pairs; only samples are tensors.
            posterior, _ = self.world_model.posterior(hidden_state, encoded)
            prior, _ = self.world_model.prior(hidden_state)
            decoded, _ = self.world_model.decode_obs(hidden_state, posterior)
            reward, _ = self.world_model.predict_reward(hidden_state, posterior)
            discount, _ = self.world_model.predict_gamma(hidden_state, posterior)

            # Record h_t once so every output has one entry per timestep.
            hidden_states.append(hidden_state)
            encoded_observations.append(encoded)
            posteriors.append(posterior)
            priors.append(prior)
            decoded_observations.append(decoded)
            rewards.append(reward)
            discounts.append(discount)

            # h_t+1
            hidden_state = self.world_model.recurrent(hidden_state, posterior, actions[i])

        return torch.stack(hidden_states), torch.stack(encoded_observations), torch.stack(posteriors), torch.stack(priors), torch.stack(decoded_observations), torch.stack(rewards), torch.stack(discounts)


In [55]:
from ezrl.optimizer import RLOptimizer

class DreamerOptimizer(RLOptimizer):
    """
    Holds the hyper-parameters and the three Adam optimizers (world model,
    actor, critic) for a ``DreamerPolicy``.
    """

    def __init__(
        self,
        policy: DreamerPolicy,
        horizon: int = 15,
        kl_scale: float = 0.1,
        world_model_lr: float = 2e-4,
        actor_lr: float = 2e-5,
        critic_lr: float = 1e-4
    ):
        """
        Args:
            policy: policy whose ``world_model``, ``policy_net`` and
                ``critic_net`` parameters are optimized.
            horizon: stored as ``self.horizon``; not used within this class.
            kl_scale: stored as ``self.kl_scale``; not used within this class.
            world_model_lr: Adam learning rate for the world model.
            actor_lr: Adam learning rate for the actor (``policy_net``).
            critic_lr: Adam learning rate for the critic (``critic_net``).
        """
        self.policy = policy
        self.horizon, self.kl_scale = horizon, kl_scale

        # Learning rates must be stored before setup_optimizer() reads them.
        self.world_model_lr = world_model_lr
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr

        self.setup_optimizer()

    def setup_optimizer(self):
        """Create one Adam optimizer per component, each with its own learning rate."""
        policy = self.policy
        self.world_model_optimizer = optim.Adam(
            policy.world_model.parameters(), lr=self.world_model_lr
        )
        self.actor_optimizer = optim.Adam(
            policy.policy_net.parameters(), lr=self.actor_lr
        )
        self.critic_optimizer = optim.Adam(
            policy.critic_net.parameters(), lr=self.critic_lr
        )

# RGB LunarLander

In [41]:
# All-zero dummy input of shape (1, 3, 64, 64) used to smoke-test the conv stacks below.
zeros = torch.zeros((1, 3, 64, 64))

In [42]:
# get_convs arguments: initial size 64, target size 1, 3 input channels, 12 output channels.
observation_encoder = get_convs(64, 1, 3, 12)

In [48]:
# Total number of elements across the encoder's parameters that require gradients.
params = sum(p.numel() for p in observation_encoder.parameters() if p.requires_grad)
params

41356

In [44]:
# Encode the dummy input and reshape the result to (1, 12).
o = observation_encoder(zeros).view(1, 12)

In [39]:
# get_deconvs arguments: initial size 1, target size 64, 12 input channels, 3 output channels.
observation_decoder = get_deconvs(1, 64, 12, 3)

In [46]:
# Reshape the code to (1, 12, 1, 1), decode it, and display the output size.
observation_decoder(o.view(1,12,1,1)).size()

torch.Size([1, 3, 64, 64])

# Setup 

In [58]:
# Sizes shared by the model constructors in the cells below.
obs_encoding_dim = 64  # size of the observation encoding x_t
hidden_dim = 64  # size of the recurrent hidden state h_t
latent_dim = 64  # size of the latent state z_t
action_dim = 4  # size of the action vector fed to the recurrent model

### RSSM

In [61]:
# Build the three RSSM components with their default backends, then bundle them.
recurrent_model = RecurrentModel(hidden_dim, latent_dim, action_dim)
representation_model = RepresentationModel(hidden_dim, latent_dim, obs_encoding_dim)
transition_predictor = TransitionPredictor(hidden_dim, latent_dim)

rssm = RSSM(recurrent_model, representation_model, transition_predictor)

### World Model

In [79]:
class LunarLanderRGBObsEncoder(nn.Module):
    """
    Convolutional encoder that maps a 3-channel image to a flat encoding.

    The network is built by ``get_convs`` to reduce a spatial size of
    ``obs_dim`` down to 1 while producing ``obs_encoding_dim`` channels.
    """

    def __init__(
        self,
        obs_encoding_dim: int,
        obs_dim: int = 8
    ):
        """
        Args:
            obs_encoding_dim: number of channels produced by the final convolution.
            obs_dim: spatial size handed to ``get_convs`` as the initial size.
        """
        super().__init__()
        self.obs_encoding_dim, self.obs_dim = obs_encoding_dim, obs_dim

        # Arguments: initial size, target size 1, 3 input channels, output channels.
        self.net = get_convs(obs_dim, 1, 3, self.obs_encoding_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode ``obs`` and flatten everything after the first dimension."""
        batch = obs.size(0)
        features = self.net(obs)
        return features.view(batch, -1)


class RGBObsDecoder(nn.Module):
    """
    Transposed-convolution decoder from ``(h_t, z_t)`` to a Normal over images.

    The network built by ``get_deconvs`` emits 6 channels, which are split in
    half along the channel dimension into a mean and a raw standard deviation.
    """

    def __init__(self, hidden_dim: int, latent_dim: int, obs_dim: int):
        """
        Args:
            hidden_dim: size of the hidden state h_t.
            latent_dim: size of the latent state z_t.
            obs_dim: spatial size handed to ``get_deconvs`` as the target size.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.obs_dim = obs_dim
        # The concatenated state is treated as the channels of a 1x1 feature map.
        self.input_channels = hidden_dim + latent_dim
        self.net = get_deconvs(1, obs_dim, self.input_channels, 6)

    def forward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """
        Decode the concatenated state.

        Returns:
            A ``(sample, distribution)`` pair; the sample is drawn with ``rsample``.
        """
        batch = hidden_state.size(0)
        joint = torch.cat((hidden_state, latent_state), dim=-1)
        feature_map = joint.view(batch, self.input_channels, 1, 1)
        raw = self.net(feature_map)
        loc, raw_scale = raw.chunk(2, dim=1)
        # The +0.1 offset keeps the standard deviation at or above 0.1.
        scale = F.softplus(raw_scale) + 0.1
        normal = td.Normal(loc, scale)
        sample = normal.rsample()
        return sample, normal

In [80]:
# Instantiate the world-model components; 64 is passed as the image size (obs_dim).
obs_encoder = LunarLanderRGBObsEncoder(obs_encoding_dim, 64)
obs_decoder = RGBObsDecoder(hidden_dim, latent_dim, 64)
reward_predictor = RewardPredictor(hidden_dim, latent_dim)
gamma_predictor = GammaPredictor(hidden_dim, latent_dim)

In [81]:
# Shape check: encode an all-zero (1, 3, 64, 64) input and display the output size.
obs_encoder(torch.zeros(1, 3, 64, 64)).size()

torch.Size([1, 64])

In [82]:
# Shape check: decode all-zero hidden and latent states and display the sample's size.
obs_decoder(torch.zeros(1,64), torch.zeros(1,64))[0].size()

torch.Size([1, 3, 64, 64])