In [1]:
import gym
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from typing import Tuple, Type

import torch.distributions as td
import torch
import torch.nn.functional as F
import abc

from ezrl.optimizer import RLOptimizer
from ezrl.policy import GymPolicy

import abc

### Dreamer Policy

In [2]:
from ezrl.algorithms.dreamer.components import RSSM, GammaPredictor, RewardPredictor, RecurrentModel, RepresentationModel, TransitionPredictor
from ezrl.algorithms.dreamer.utils import get_convs, get_deconvs, NormalDistributionModel
from ezrl.algorithms.dreamer.world_model import WorldModel


In [3]:
from ezrl.policy import ACPolicy
from typing import Dict, Any, Optional



class DreamerPolicy(ACPolicy):
    """Actor-critic policy that acts from, and imagines inside, a Dreamer world model.

    The actor (``policy_net``) and critic (``critic_net``) are single linear
    layers applied to the world model's latent state only; the recurrent
    hidden state is not fed to them. Actions are sampled from a Gaussian whose
    mean is ``policy_net(latent)`` and whose state-independent standard
    deviation is ``exp(log_std)``, where ``log_std`` is a learnable parameter
    initialised to -0.5 per action dimension.

    Args:
        world_model: World model providing the observation encoder, the RSSM
            (posterior, prior, recurrent step) and the reward / discount
            heads. Its ``action_dim``, ``latent_dim``, ``hidden_dim`` and
            ``obs_encoding_dim`` attributes are copied onto the policy.
    """

    def __init__(
        self,
        world_model: WorldModel,
    ):
        super().__init__()
        self.world_model = world_model

        self.action_dim = world_model.action_dim
        self.latent_dim = world_model.latent_dim
        self.hidden_dim = world_model.hidden_dim
        self.obs_encoding_dim = world_model.obs_encoding_dim

        self.policy_net = nn.Linear(world_model.latent_dim, self.action_dim)
        self.critic_net = nn.Linear(world_model.latent_dim, 1)

        log_std = -0.5 * np.ones(self.action_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))

        # Zero-element parameter: it moves with ``.to()`` / ``.cuda()`` like any
        # other parameter, so its ``.device`` reports where the module lives.
        self.__device_param_dummy__ = nn.Parameter(torch.empty(0))

    @property
    def device(self) -> torch.device:
        """Device the policy's parameters currently live on."""
        return self.__device_param_dummy__.device

    def initialize_hidden_state(self, batch_size: int = 1) -> torch.Tensor:
        """Return the RSSM's initial hidden state for ``batch_size`` sequences, moved to ``self.device``."""
        return self.world_model.rssm.initialize_hidden_state(batch_size).to(self.device)

    def log_prob(self, dist: td.Distribution, actions: torch.Tensor) -> torch.Tensor:
        """Log-probability of ``actions`` under ``dist``.

        A Categorical distribution already yields one value per sample. For
        any other distribution the per-dimension log-probabilities are summed
        over the last axis, treating action dimensions as independent.
        """
        if isinstance(dist, td.Categorical):
            return dist.log_prob(actions)
        return dist.log_prob(actions).sum(dim=-1)

    def dist(self, action_logits: torch.Tensor) -> td.Distribution:
        """Gaussian action distribution with mean ``action_logits`` and std ``exp(log_std)``."""
        return td.Normal(action_logits, torch.exp(self.log_std))

    def forward(self, latent_state: Any) -> Dict[str, Any]:
        """Sample an action and evaluate the critic for ``latent_state``.

        Returns:
            Dict with ``"action"`` (drawn with ``sample()``, so it carries no
            reparameterisation gradient), ``"dist"`` (the action
            distribution), ``"log_probs"`` (log-prob of the sampled action)
            and ``"value"`` (critic output, see :meth:`critic`).
        """
        mean = self.policy_net(latent_state)
        dist = self.dist(mean)
        action = dist.sample()
        log_probs = self.log_prob(dist, action)
        value = self.critic(latent_state)
        return {"action": action, "dist": dist, "log_probs": log_probs, "value": value}

    def critic(self, latent_state: Any) -> torch.Tensor:
        """Critic value for ``latent_state``.

        Only the trailing size-1 output dimension is removed. A batch of one
        therefore keeps its batch axis (shape ``(1,)`` rather than a 0-d
        scalar), so value shapes do not depend on the batch size.
        """
        return self.critic_net(latent_state).squeeze(-1)

    def act(self, latent_state: Any):
        """Sample an action for stepping an environment.

        Returns:
            ``(action, out)``. ``action`` is a detached NumPy array with all
            size-1 dimensions squeezed; ``out`` is the dict from :meth:`forward`.
        """
        out = self.forward(latent_state)
        return np.squeeze(out["action"].detach().cpu().numpy()), out

    def unroll(self, initial_observation: torch.Tensor, num_steps: int = 15) -> Tuple[torch.Tensor, ...]:
        """Imagine ``num_steps`` steps in latent space starting from one observation.

        The starting latent is the posterior given the initial hidden state
        and the encoded observation. Each later latent is sampled from the
        prior given the next hidden state. At every step t, the policy acts on
        z_t, and reward and discount are predicted from (h_t, z_t).

        Args:
            initial_observation: Batch of observations; the first dimension
                is the batch size (B x C x H x W for image inputs).
            num_steps: Number of imagined steps; must be at least 1.

        Returns:
            ``(hidden_states, latent_states, actions, rewards, values,
            discounts)``. Each is stacked along a new leading time axis of
            length ``num_steps``, and entry t of every tensor refers to the
            same model state (h_t, z_t).

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        batch_size = initial_observation.size(0)

        hidden_states = []
        latent_states = []
        actions = []
        rewards = []
        values = []
        discounts = []

        hidden_state = self.initialize_hidden_state(batch_size)
        encoded = self.world_model.encode_obs(initial_observation)
        latent_state, _ = self.world_model.posterior(hidden_state, encoded)

        for _ in range(num_steps):
            # Record (h_t, z_t) *before* stepping so every returned sequence is
            # aligned on the same time index (this also matches
            # ``unroll_with_posteriors``).
            hidden_states.append(hidden_state)
            latent_states.append(latent_state)

            out = self.forward(latent_state)
            action = out["action"]
            reward, _ = self.world_model.predict_reward(hidden_state, latent_state)
            discount, _ = self.world_model.predict_gamma(hidden_state, latent_state)

            # Advance: h_{t+1} from (h_t, z_t, a_t), then z_{t+1} from the prior.
            hidden_state = self.world_model.recurrent(hidden_state, latent_state, action)
            latent_state, _ = self.world_model.prior(hidden_state)

            actions.append(action)
            rewards.append(reward)
            values.append(out["value"])
            discounts.append(discount)

        return (
            torch.stack(hidden_states),
            torch.stack(latent_states),
            torch.stack(actions),
            torch.stack(rewards),
            torch.stack(values),
            torch.stack(discounts),
        )

    def unroll_with_posteriors(self, observations: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Roll the world model over a recorded sequence using posterior latents.

        At each step t:
        1. Encode observation t.
        2. Compute the posterior from (h_t, encoding) and the prior from h_t.
        3. Decode the observation and predict reward and discount from
           (h_t, posterior).
        4. Advance the hidden state with (h_t, posterior, actions[t]).

        Args:
            observations: Time-major observations, T x B x C x H x W.
            actions: Time-major actions, T x B x L.

        Returns:
            ``(hidden_states, encoded_observations, posteriors, priors,
            decoded_observations, rewards, discounts)``, each stacked along a
            leading time axis of length T. Entry t of ``hidden_states`` is the
            h_t used for step t.

        Raises:
            ValueError: If ``observations`` has no time steps.
        """
        num_steps = observations.size(0)
        if num_steps < 1:
            raise ValueError("observations must contain at least one time step")
        batch_size = observations.size(1)

        hidden_states = []
        encoded_observations = []
        posteriors = []
        priors = []
        decoded_observations = []
        rewards = []
        discounts = []

        # h_0
        hidden_state = self.initialize_hidden_state(batch_size)

        for i in range(num_steps):
            hidden_states.append(hidden_state)

            # Quantities at time t, all conditioned on h_t.
            encoded = self.world_model.encode_obs(observations[i])
            posterior, _ = self.world_model.posterior(hidden_state, encoded)
            prior, _ = self.world_model.prior(hidden_state)
            decoded, _ = self.world_model.decode_obs(hidden_state, posterior)
            reward, _ = self.world_model.predict_reward(hidden_state, posterior)
            discount, _ = self.world_model.predict_gamma(hidden_state, posterior)

            # h_{t+1}
            hidden_state = self.world_model.recurrent(hidden_state, posterior, actions[i])

            encoded_observations.append(encoded)
            posteriors.append(posterior)
            priors.append(prior)
            decoded_observations.append(decoded)
            rewards.append(reward)
            discounts.append(discount)

        return (
            torch.stack(hidden_states),
            torch.stack(encoded_observations),
            torch.stack(posteriors),
            torch.stack(priors),
            torch.stack(decoded_observations),
            torch.stack(rewards),
            torch.stack(discounts),
        )


NameError: name 'WorldModel' is not defined

In [312]:
from ezrl.optimizer import RLOptimizer

class DreamerOptimizer(RLOptimizer):
    """Holds the three Adam optimizers used to train a ``DreamerPolicy``.

    Args:
        policy: Policy whose world model, actor and critic are optimised.
        horizon: Imagination horizon; stored on the instance (not read by the
            methods defined here).
        kl_scale: KL-term weight; stored on the instance (not read by the
            methods defined here).
        world_model_lr: Adam learning rate for ``policy.world_model``.
        actor_lr: Adam learning rate for the actor, i.e. ``policy.policy_net``
            plus the action log-std parameter ``policy.log_std``.
        critic_lr: Adam learning rate for ``policy.critic_net``.
    """

    def __init__(
        self,
        policy: DreamerPolicy,
        horizon: int = 15,
        kl_scale: float = 0.1,
        world_model_lr: float = 2e-4,
        actor_lr: float = 2e-5,
        critic_lr: float = 1e-4
    ):
        self.policy = policy
        self.horizon = horizon
        self.kl_scale = kl_scale

        self.world_model_lr = world_model_lr
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr

        self.setup_optimizer()

    def setup_optimizer(self):
        """(Re)create the world-model, actor and critic optimizers from the stored learning rates."""
        self.world_model_optimizer = optim.Adam(self.policy.world_model.parameters(), lr=self.world_model_lr)
        # ``log_std`` is a bare Parameter on the policy, not part of
        # ``policy_net``; without adding it here the action noise is never learned.
        actor_params = list(self.policy.policy_net.parameters()) + [self.policy.log_std]
        self.actor_optimizer = optim.Adam(actor_params, lr=self.actor_lr)
        self.critic_optimizer = optim.Adam(self.policy.critic_net.parameters(), lr=self.critic_lr)

# RGB LunarLander

## Setup

In [58]:
# Component sizes shared by the RSSM, the encoder/decoder and the policy below.
obs_encoding_dim, hidden_dim, latent_dim = 64, 64, 64
# NOTE(review): 4 matches the discrete LunarLander action count, but
# DreamerPolicy samples continuous Gaussian actions (LunarLanderContinuous
# has a 2-dim action space) — confirm which environment is intended.
action_dim = 4

### RSSM

In [339]:
# Build the three RSSM sub-modules from the sizes defined above, then combine
# them. Keep the construction order fixed: each constructor presumably draws
# random initial weights, so reordering would change the initialisation.
recurrent_model = RecurrentModel(hidden_dim, latent_dim, action_dim)
representation_model = RepresentationModel(hidden_dim, latent_dim, obs_encoding_dim)
transition_predictor = TransitionPredictor(hidden_dim, latent_dim)

rssm = RSSM(recurrent_model, representation_model, transition_predictor)

### World Model

In [340]:
class LunarLanderRGBObsEncoder(nn.Module):
    """Convolutional encoder that maps a batch of RGB frames to flat embeddings.

    Args:
        obs_encoding_dim: Size of the embedding produced for each frame
            (forwarded to ``get_convs``).
        obs_dim: First argument forwarded to ``get_convs``; in this notebook
            it is set to the frame side length (64). NOTE(review): the default
            of 8 looks like a leftover from the 8-dim state-vector
            LunarLander — confirm.
    """

    def __init__(self, obs_encoding_dim: int, obs_dim: int = 8):
        super().__init__()
        self.obs_encoding_dim, self.obs_dim = obs_encoding_dim, obs_dim
        self.net = get_convs(obs_dim, 1, 3, obs_encoding_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode ``obs`` and flatten everything after the batch dimension."""
        batch = obs.size(0)
        features = self.net(obs)
        return features.view(batch, -1)


class RGBObsDecoder(nn.Module):
    """Deconvolutional decoder from model state (h, z) to a Gaussian over images.

    The concatenated ``[hidden_state, latent_state]`` vector is reshaped to a
    1x1 feature map and passed through ``net``. The network emits
    ``2 * out_channels`` channels. The first half is the per-pixel mean; the
    second half goes through ``softplus`` plus ``min_std`` to give a strictly
    positive standard deviation.

    Args:
        hidden_dim: Size of the recurrent hidden state.
        latent_dim: Size of the latent state.
        obs_dim: Forwarded to ``get_deconvs``; in this notebook it is the
            output frame side length (64).
        out_channels: Number of image channels to reconstruct (3 for RGB).
        min_std: Lower bound added to the predicted standard deviation.
    """

    def __init__(
        self,
        hidden_dim: int,
        latent_dim: int,
        obs_dim: int,
        out_channels: int = 3,
        min_std: float = 0.1,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.obs_dim = obs_dim
        self.out_channels = out_channels
        self.min_std = min_std
        self.input_channels = hidden_dim + latent_dim
        # Twice the image channels: one set for the mean, one for the std.
        self.net = get_deconvs(1, obs_dim, self.input_channels, 2 * out_channels)

    def forward(self, hidden_state: torch.Tensor, latent_state: torch.Tensor) -> Tuple[torch.Tensor, td.Distribution]:
        """Decode a (reparameterised) image sample and its distribution.

        ``hidden_state`` and ``latent_state`` may carry any matching leading
        dimensions, e.g. (B, ·) or (T, B, ·). Those leading dimensions are
        flattened for the deconvolution and restored on the output.

        Returns:
            ``(sample, dist)``. ``sample`` is ``dist.rsample()``, and ``dist``
            is a ``Normal`` whose shape is the leading dimensions followed by
            the decoder's (channels, height, width).
        """
        features = torch.cat([hidden_state, latent_state], dim=-1)
        batch_shape = features.shape[:-1]
        # Treat the feature vector as a 1x1 map with ``input_channels`` channels.
        logits = self.net(features.reshape(batch_shape.numel(), self.input_channels, 1, 1))
        mean, std = torch.chunk(logits, 2, dim=1)
        mean = mean.reshape(*batch_shape, *mean.shape[1:])
        std = std.reshape(*batch_shape, *std.shape[1:])
        std = F.softplus(std) + self.min_std
        dist = td.Normal(mean, std)
        return dist.rsample(), dist

In [341]:
# Instantiate the world-model heads. The literal 64 passed as obs_dim
# presumably matches the 64x64 frames used in the checks below — confirm if
# the frame size changes. Order is kept fixed because construction presumably
# consumes the RNG for weight initialisation.
obs_encoder = LunarLanderRGBObsEncoder(obs_encoding_dim, 64)
obs_decoder = RGBObsDecoder(hidden_dim, latent_dim, 64)
reward_predictor = RewardPredictor(hidden_dim, latent_dim)
gamma_predictor = GammaPredictor(hidden_dim, latent_dim)

In [342]:
# Sanity check: one blank 64x64 RGB frame must encode to a single
# obs_encoding_dim-sized vector. no_grad: this is a shape probe, not training.
with torch.no_grad():
    encoded_probe = obs_encoder(torch.zeros(1, 3, 64, 64))
assert encoded_probe.shape == (1, obs_encoding_dim), encoded_probe.shape
encoded_probe.size()

torch.Size([1, 64])

In [343]:
# Sanity check: decoding one all-zero (h, z) state must yield one 3x64x64
# image. Use the configured sizes rather than a hard-coded 64 so the check
# keeps working if hidden_dim / latent_dim change.
with torch.no_grad():
    decoded_probe, _ = obs_decoder(torch.zeros(1, hidden_dim), torch.zeros(1, latent_dim))
assert decoded_probe.shape == (1, 3, 64, 64), decoded_probe.shape
decoded_probe.size()

torch.Size([1, 3, 64, 64])

In [344]:
# Bundle the RSSM with the encoder, decoder, reward and discount heads.
world_model = WorldModel(rssm, obs_encoder, obs_decoder, reward_predictor, gamma_predictor)

In [345]:
# Actor-critic policy operating on the world model's latent states.
dreamer = DreamerPolicy(world_model)

In [346]:
# Imagine 15 steps from one blank frame. The results get their own names so
# the posterior rollout in the next cell does not silently overwrite
# hidden_states / rewards / discounts.
with torch.no_grad():
    (
        imagined_hidden_states,
        imagined_latent_states,
        imagined_actions,
        imagined_rewards,
        imagined_values,
        imagined_discounts,
    ) = dreamer.unroll(torch.zeros(1, 3, 64, 64), num_steps=15)
assert imagined_hidden_states.shape == (15, 1, hidden_dim), imagined_hidden_states.shape
assert imagined_latent_states.shape == (15, 1, latent_dim), imagined_latent_states.shape
assert imagined_actions.shape == (15, 1, action_dim), imagined_actions.shape

In [347]:
# Posterior rollout over a time-major (T=15, B=32) batch of blank 64x64 RGB
# frames paired with random 4-dim actions.
hidden_states, encoded_observations, posteriors, priors, decoded, rewards, discounts = dreamer.unroll_with_posteriors(
    observations=torch.zeros(15, 32, 3, 64, 64),
    actions=torch.randn(15, 32, 4),
)

In [353]:
# One shape per line, in the order returned by unroll_with_posteriors.
print(
    *(
        tensor.size()
        for tensor in (hidden_states, encoded_observations, posteriors, priors, decoded, rewards, discounts)
    ),
    sep="\n",
)

torch.Size([15, 32, 64])
torch.Size([15, 32, 64])
torch.Size([15, 32, 64])
torch.Size([15, 32, 64])
torch.Size([15, 32, 3, 64, 64])
torch.Size([15, 32, 1])
torch.Size([15, 32, 1])
