In [1]:
from __future__ import annotations
import torch
from torch import Tensor
import torch.nn as nn
import torch.distributions as dist
from typing import List, NamedTuple


class MLP(nn.Module):
    """Feed-forward stack of linear layers with ReLU between them.

    Args:
        dims: Layer widths. Each consecutive pair ``(dims[i], dims[i+1])``
            becomes one ``nn.Linear``; a ReLU is inserted before every
            linear layer except the first, so the last layer has no
            activation after it. With fewer than two entries the resulting
            ``nn.Sequential`` is empty.
    """

    def __init__(self, dims: List[int]):
        super().__init__()

        modules = []
        for width_in, width_out in zip(dims[:-1], dims[1:]):
            # An activation separates this layer from the previous one.
            if modules:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(width_in, width_out, bias=True))

        self.main = nn.Sequential(*modules)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer stack to ``x``."""
        out = self.main(x)
        return out

class VisualEncoder(nn.Module):
    """Convolutional encoder that flattens an image into a feature vector.

    Four Conv2d -> ReLU -> MaxPool2d(2) stages double the channel count
    each time (``h``, ``2h``, ``4h``, ``8h``) and halve both spatial sizes,
    followed by ``nn.Flatten``.

    Args:
        obs_shape: Three-element shape; the first entry is the number of
            input channels and the other two are the spatial sizes, both of
            which must be divisible by 16.
        kernel_size: Odd convolution kernel size; padding is
            ``kernel_size // 2`` so convolutions keep the spatial size.
        hidden_dim: Channel count ``h`` of the first convolution.

    Attributes:
        out_features: Length of the flattened output per sample.
    """

    def __init__(self, obs_shape: torch.Size, kernel_size=3, hidden_dim=32):
        super().__init__()
        self.obs_shape = obs_shape
        self.kernel_size = kernel_size
        self.hidden_dim = hidden_dim

        in_channels, size_a, size_b = self.obs_shape
        assert size_a % 16 == 0 and size_b % 16 == 0, \
            "image resolution should be divisible by 16"
        assert kernel_size % 2 == 1, \
            "kernel_size should be an odd number"
        padding = kernel_size // 2

        # Four 2x poolings shrink each spatial size by 16.
        flat_shape = torch.Size([size_b // 16, size_a // 16, 8 * hidden_dim])
        self.out_features = flat_shape.numel()

        widths = [in_channels] + [mult * hidden_dim for mult in (1, 2, 4, 8)]
        stages = []
        for ch_in, ch_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(ch_in, ch_out, kernel_size, 1, padding))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
        stages.append(nn.Flatten())

        self.main = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into a flat feature tensor."""
        features = self.main(x)
        return features

class NormalLinear(nn.Module):
    """Linear layer that parameterises a diagonal Normal distribution.

    A single ``nn.Linear`` produces ``2 * out_features`` values per input
    row; the first half is the mean and the second half is mapped through
    softplus to obtain a strictly positive standard deviation.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Dimensionality of the returned distribution's event.
    """

    # Lower bound added to the standard deviation so that the scale handed
    # to ``dist.Normal`` is always strictly positive.
    MIN_STD = 1e-4

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.out_features = out_features
        self.fc = nn.Linear(in_features, 2*out_features)
    
    def forward(self, x: Tensor) -> dist.Distribution:
        """Return ``Independent(Normal(mean, std), 1)`` computed from ``x``.

        The last dimension of the result is treated as the event dimension,
        all leading dimensions of ``x`` become batch dimensions.
        """
        params = self.fc(x)
        # Split along the feature axis (the default ``dim=0`` would cut the
        # batch into chunks instead of separating mean from std).
        mean, raw_std = params.split(self.out_features, dim=-1)
        # The linear output is unconstrained; Normal requires scale > 0.
        std = nn.functional.softplus(raw_std) + self.MIN_STD
        res_dist = dist.Normal(mean, std)
        res_dist = dist.Independent(res_dist, 1)
        return res_dist

class VisualDecoder(nn.Module):
    """Decode a flat feature vector into a distribution over observations.

    The input is first projected by a linear layer to a
    ``(8h, obs_shape[1] // 16, obs_shape[2] // 16)`` feature map, then four
    nearest-neighbour Upsample(2) -> Conv2d stages (ReLU after all but the
    last) bring it back to ``obs_shape``.

    Args:
        in_features: Number of features per sample after flattening the
            input. Previously this argument was accepted but ignored, which
            only worked when the input already had exactly
            ``8h * (obs_shape[1] // 16) * (obs_shape[2] // 16)`` features.
        obs_shape: Three-element shape of the decoded observation; the first
            entry is the channel count, the other two must be divisible
            by 16.
        kernel_size: Odd convolution kernel size; padding is
            ``kernel_size // 2``.
        hidden_dim: Base channel count ``h``.
    """

    def __init__(self, in_features: int, obs_shape: torch.Size, kernel_size=3, hidden_dim=32):
        super().__init__()
        self.in_features = in_features
        self.obs_shape = obs_shape
        self.kernel_size = k = kernel_size
        self.hidden_dim = h = hidden_dim

        # NOTE: the attribute names W/H are kept for compatibility; they are
        # simply obs_shape[1] and obs_shape[2] and are used in that order.
        c, self.W, self.H = self.obs_shape
        # Same preconditions as VisualEncoder: otherwise four 2x upsamplings
        # cannot reproduce the requested resolution.
        assert self.W % 16 == 0 and self.H % 16 == 0, \
            "image resolution should be divisible by 16"
        assert k % 2 == 1, \
            "kernel_size should be an odd number"
        p = k // 2

        # Shape of the feature map fed to the first upsampling stage.
        self.conv_shape = torch.Size([8*h, self.W // 16, self.H // 16])
        self.fc = nn.Linear(in_features, self.conv_shape.numel())

        self.main = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(8*h, 4*h, k, 1, p),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(4*h, 2*h, k, 1, p),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(2*h, h, k, 1, p),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(h, c, k, 1, p),
        )
    
    def forward(self, x: Tensor) -> dist.Distribution:
        """Return a unit-variance Normal over observations, one per row of ``x``.

        ``x`` is flattened to ``(len(x), -1)``; the flattened width must
        equal ``in_features``.
        """
        batch = len(x)
        x = self.fc(x.reshape(batch, -1))
        x = x.reshape(batch, *self.conv_shape)
        obs_dist = dist.Normal(self.main(x), 1.0)
        # By default, we'd get a single normal distribution over the whole
        # 4-D tensor, whereas we actually want one independent normal
        # distribution per batch element over the three trailing dims.
        # This is what Independent is for.
        obs_dist = dist.Independent(obs_dist, 3) 
        return obs_dist

class RSSMState(NamedTuple):
    """Pair of a deterministic and a stochastic state tensor.

    Indexing is overridden: ``state[idx]`` applies ``idx`` to both tensors
    and returns a new ``RSSMState`` rather than selecting a tuple field.
    Fields are still reachable by name (``state.deter``, ``state.stoch``).
    """
    deter: Tensor
    stoch: Tensor

    def __getitem__(self, idx):
        """Index both tensors with ``idx`` and wrap the results."""
        sub_deter = self.deter[idx]
        sub_stoch = self.stoch[idx]
        return RSSMState(sub_deter, sub_stoch)
    
    def __setitem__(self, idx, value: RSSMState):
        """Write ``value``'s tensors into both tensors at ``idx`` (in place)."""
        pairs = ((self.deter, value.deter), (self.stoch, value.stoch))
        for target, source in pairs:
            target[idx] = source

    @staticmethod
    def zeros(*size: int):
        """Build a state whose two tensors are both ``torch.zeros(*size)``."""
        zero_deter = torch.zeros(*size)
        zero_stoch = torch.zeros(*size)
        return RSSMState(deter=zero_deter, stoch=zero_stoch)

    def clone(self):
        """Return a state holding clones of both tensors."""
        copies = [tensor.clone() for tensor in (self.deter, self.stoch)]
        return RSSMState(*copies)

class RSSMStateDist(NamedTuple):
    """Deterministic state tensor paired with a distribution over the stochastic state.

    Sampling draws from ``stoch_dist`` and broadcasts ``deter`` to the same
    leading sample dimensions, yielding an ``RSSMState``.
    """
    deter: Tensor
    stoch_dist: dist.Distribution

    def sample_n(self, n: int) -> RSSMState:
        """Draw ``n`` samples; equivalent to ``sample((n,))``."""
        return self.sample((n,))

    def sample(self, sample_size: torch.Size = torch.Size()) -> RSSMState:
        """Draw (non-reparameterised) samples with leading shape ``sample_size``."""
        deter = self.deter.expand(*sample_size, *self.deter.shape)
        stoch = self.stoch_dist.sample(sample_size)
        return RSSMState(deter, stoch)
    
    def rsample(self, sample_size: torch.Size = torch.Size()) -> RSSMState:
        """Draw reparameterised samples with leading shape ``sample_size``."""
        deter = self.deter.expand(*sample_size, *self.deter.shape)
        stoch = self.stoch_dist.rsample(sample_size)
        return RSSMState(deter, stoch)
    
    def log_prob(self, state: RSSMState):
        """Log-probability of ``state.stoch`` under ``stoch_dist``."""
        return self.stoch_dist.log_prob(state.stoch)

    def detach(self) -> RSSMStateDist:
        """Return a copy whose tensors are detached from the autograd graph.

        Supports ``Normal`` and any nesting of ``Independent`` around a
        ``Normal`` (the form produced by ``NormalLinear`` in this file).

        Raises:
            NotImplementedError: for any other distribution type.
        """
        def detach_dist(d: dist.Distribution) -> dist.Distribution:
            if isinstance(d, dist.Independent):
                # Rebuild the wrapper around the detached base distribution,
                # preserving how many batch dims are treated as event dims.
                return dist.Independent(
                    detach_dist(d.base_dist), d.reinterpreted_batch_ndims
                )
            if isinstance(d, dist.Normal):
                return dist.Normal(d.loc.detach(), d.scale.detach())
            raise NotImplementedError(f"Cannot detach {type(d)}")

        return RSSMStateDist(self.deter.detach(), detach_dist(self.stoch_dist))

@dist.register_kl(RSSMStateDist, RSSMStateDist)
def kl_divergence(p: RSSMStateDist, q: RSSMStateDist):
    """KL divergence between two state distributions.

    Only the stochastic parts are compared; the deterministic tensors are
    ignored. Registered so that ``dist.kl_divergence`` dispatches here for a
    pair of ``RSSMStateDist`` arguments.
    """
    stoch_p = p.stoch_dist
    stoch_q = q.stoch_dist
    return dist.kl_divergence(stoch_p, stoch_q)

class EnvStateBatch(NamedTuple):
    """Bundle of observation, action and reward tensors.

    Iteration is overridden: instead of yielding the three fields it zips
    the tensors along their first dimension and yields one
    ``(obs, act, reward)`` triple per index.
    """
    obs: Tensor
    act: Tensor
    reward: Tensor

    @property
    def shape(self):
        """Shape of the observation tensor."""
        observations = self.obs
        return observations.shape

    def __iter__(self):
        """Yield ``(obs[i], act[i], reward[i])`` for each leading index ``i``."""
        for step in zip(self.obs, self.act, self.reward):
            yield step

class RSSMCell(nn.Module):
    """Single state-update step producing an ``RSSMStateDist``.

    A GRU cell maps the concatenation of the previous deterministic state,
    the previous stochastic state and the input features to a new
    deterministic state; a small network then maps that state to a
    distribution over the new stochastic state.

    Args:
        in_features: Width of the extra input ``x`` given to ``forward``.
        deter_dim: Width of the deterministic state.
        stoch_dim: Width of the stochastic state.
    """

    def __init__(self, in_features: int, deter_dim: int, stoch_dim: int):
        super().__init__()
        self.deter_dim = deter_dim
        self.stoch_dim = stoch_dim

        self.deter_state_model = nn.GRUCell(
            deter_dim + stoch_dim + in_features,
            deter_dim,
        )
        stoch_layers = [
            MLP([deter_dim, 256]),
            nn.ReLU(),
            NormalLinear(256, stoch_dim),
        ]
        self.stoch_state_model = nn.Sequential(*stoch_layers)
    
    def forward(self, h: RSSMState, x: Tensor) -> RSSMStateDist:
        """Compute the next state distribution from state ``h`` and input ``x``."""
        gru_input = torch.cat((h.deter, h.stoch, x), dim=1)
        # NOTE(review): the GRU cell is called without an explicit hidden
        # state, so the previous deterministic state only enters through
        # the concatenated input. Confirm this is intended.
        next_deter: Tensor = self.deter_state_model(gru_input)
        next_stoch_dist = self.stoch_state_model(next_deter)
        return RSSMStateDist(next_deter, next_stoch_dist)


class PlaNetRSSM(nn.Module):
    """Recurrent state-space model with observation, reward and transition heads.

    Components:
        * ``vis_encoder`` / ``repr_cell``: state distribution informed by the
          current observation and action.
        * ``trans_cell``: state distribution predicted from the action only.
        * ``vis_decoder`` / ``reward_net``: distributions over observation
          and reward given a state.

    Args:
        obs_shape: Shape of a single observation, passed to the visual
            encoder and decoder.
        act_shape: Shape of a single action; actions are flattened to
            ``act_shape.numel()`` features.
        deter_dim: Width of the deterministic state.
        stoch_dim: Width of the stochastic state.
    """

    def __init__(self, obs_shape: torch.Size, act_shape: torch.Size, deter_dim: int, stoch_dim: int):
        super().__init__()
        self.deter_dim = deter_dim
        self.stoch_dim = stoch_dim
        
        self.vis_encoder = VisualEncoder(obs_shape)
        repr_model_input = self.vis_encoder.out_features + act_shape.numel()
        self.repr_cell = RSSMCell(repr_model_input, deter_dim, stoch_dim)
        trans_model_input = act_shape.numel()
        self.trans_cell = RSSMCell(trans_model_input, deter_dim, stoch_dim)
        self.vis_decoder = VisualDecoder(deter_dim + stoch_dim, obs_shape)
        self.reward_net = nn.Sequential(
            MLP([deter_dim + stoch_dim, 256]),
            nn.ReLU(),
            NormalLinear(256, 1),
        )
    
    def repr_model(self, h: RSSMState, obs: Tensor, act: Tensor) -> RSSMStateDist:
        """Next-state distribution given state ``h``, observation and action."""
        obs_z = self.vis_encoder(obs)
        act_z = act.reshape(len(act), -1)
        repr_x = torch.cat([obs_z, act_z], dim=1)
        return self.repr_cell(h, repr_x)

    def trans_model(self, h: RSSMState, act: Tensor) -> RSSMStateDist:
        """Next-state distribution given state ``h`` and action only."""
        act_z = act.reshape(len(act), -1)
        trans_x = act_z
        return self.trans_cell(h, trans_x)

    def obs_model(self, h: RSSMState) -> dist.Distribution:
        """Distribution over observations given state ``h``."""
        # Concatenate (not stack) so that deter_dim and stoch_dim may
        # differ; this matches reward_model and the decoder's declared
        # input width of deter_dim + stoch_dim.
        state = torch.cat([h.deter, h.stoch], dim=1)
        return self.vis_decoder(state)

    def reward_model(self, h: RSSMState) -> dist.Distribution:
        """Distribution over rewards given state ``h``."""
        state = torch.cat([h.deter, h.stoch], 1)
        return self.reward_net(state)

    def loss(self, batch: EnvStateBatch):
        """Reconstruction plus latent-overshooting loss for a batch of sequences.

        The first two dimensions of ``batch.obs`` are taken as
        ``(seq_len, batch_size)`` and the batch is iterated along the first.

        Returns:
            ``recon_loss + trans_loss``. Both terms are accumulated from
            per-sample ``log_prob`` / KL values and no reduction over the
            batch dimension is applied here.
        """
        seq_len, batch_size = batch.shape[:2]

        # Initial all-zero state. The two parts need their own widths and a
        # batch dimension, because RSSMCell concatenates them along dim=1.
        device = batch.obs.device
        cur_state = RSSMState(
            deter=torch.zeros(batch_size, self.deter_dim, device=device),
            stoch=torch.zeros(batch_size, self.stoch_dim, device=device),
        )
        repr_dists: List[RSSMStateDist | None] = [None]
        states: List[RSSMState] = [cur_state]

        # Predicting states with observations + Reconstruction loss
        recon_loss = 0.0
        for obs, act, reward in batch:
            # Compute the informed state distribution
            repr_dist: RSSMStateDist = self.repr_model(cur_state, obs, act)

            # NOTE: Not sure whether to detach prior or not
            # Dreamer v2, in "KL balancing" section, mixes both
            # In PlaNet paper, it is suggested in "Latent overshooting"
            # section that the repr_dist is detached for overshooting dist > 1
            # whereas here we stop it altogether from the get-go.
            repr_dists.append(repr_dist.detach())
            # repr_dists.append(repr_dist)

            # Sample a state from the distribution
            # NOTE: sample() or rsample() ?
            cur_state = repr_dist.rsample()
            states.append(cur_state)

            # Evaluate observation and reward recon losses
            obs_dist = self.obs_model(cur_state)
            obs_loss = -obs_dist.log_prob(obs)
            reward_dist = self.reward_model(cur_state)
            # The reward head has event shape (1,); give the target a
            # matching trailing dimension so log_prob does not broadcast a
            # (B,) target against (B, 1) parameters.
            reward_loss = -reward_dist.log_prob(reward.reshape(batch_size, -1))
            recon_loss += obs_loss + reward_loss
        
        # Overshooting
        # trans_states[i] starts as the sampled state i; after pass k,
        # entries k.. hold states predicted k transition steps ahead.
        trans_states = [x.clone() for x in states]
        trans_loss = 0.0
        for steps_ahead in range(1, seq_len):
            next_trans_states = []
            for step_idx in range(steps_ahead-1, seq_len):
                # Compute transition-model next state distribution
                trans_h: RSSMState = trans_states[step_idx]
                act = batch.act[step_idx]
                trans_dist: RSSMStateDist = self.trans_model(trans_h, act)

                # Sample a state from said distribution
                trans_state = trans_dist.rsample()
                next_trans_states.append(trans_state)

                # "Latent overshooting"
                # NOTE(review): the KL is taken from the transition
                # distribution to the (detached) informed distribution;
                # confirm the argument order against the intended bound.
                repr_dist: RSSMStateDist = repr_dists[step_idx+1] # type: ignore
                trans_loss += dist.kl_divergence(trans_dist, repr_dist)

                # # "Observation overshooting"
                # obs_dist = self.obs_model(trans_state)
                # obs_loss = -obs_dist.log_prob(batch.obs[step_idx])
                # reward_dist = self.reward_model(trans_state)
                # reward_loss = -reward_dist.log_prob(batch.reward[step_idx])
                # trans_loss += obs_loss + reward_loss

            trans_states[steps_ahead:] = next_trans_states
        
        # # "Standard variational bound"
        # trans_loss = 0.0
        # for step_idx in range(seq_len):
        #     trans_h = states[step_idx]
        #     trans_x = batch.act[step_idx].reshape(batch_size, -1)
        #     trans_dist: RSSMStateDist = self.trans_model(trans_h, trans_x)
        #     repr_dist: RSSMStateDist = repr_dists[step_idx+1] # type: ignore
        #     trans_loss += dist.kl_divergence(trans_dist, repr_dist)

        total_loss = recon_loss + trans_loss
        return total_loss

SyntaxError: expected ':' (3375891183.py, line 188)

In [None]:
class CEMPlanner:
    """Cross-entropy-method planner over action sequences.

    Each iteration samples ``iter_pop`` action sequences from a diagonal
    Normal, scores them by summing predicted mean rewards along a model
    rollout, and refits the Normal to the ``iter_top_k`` best sequences.

    Args:
        horizon: Number of actions in each candidate sequence.
        optim_iters: Number of sample/refit iterations.
        iter_pop: Number of candidate sequences per iteration.
        iter_top_k: Number of best candidates used for the refit.
        act_shape: Shape of a single action.

    Raises:
        ValueError: if ``iter_top_k`` is not in ``[1, iter_pop]``.
    """

    # Floor for the refitted standard deviation so the Normal stays valid
    # even when all elite samples coincide (or iter_top_k == 1).
    MIN_STD = 1e-6

    def __init__(self, horizon: int, optim_iters: int, iter_pop: int, iter_top_k: int, act_shape: torch.Size):
        if not 1 <= iter_top_k <= iter_pop:
            raise ValueError("iter_top_k must be between 1 and iter_pop")
        self.horizon = self.H = horizon
        self.optim_iters = self.I = optim_iters
        self.iter_pop = self.J = iter_pop
        self.iter_top_k = self.K = iter_top_k
        self.act_shape = act_shape
    
    def next_action(self, cur_state_dist: RSSMStateDist, rssm: PlaNetRSSM):
        """Plan from ``cur_state_dist`` and return the mean first action.

        Args:
            cur_state_dist: Current state distribution; must provide
                ``sample_n`` and a ``deter`` tensor (used for the device).
            rssm: Model providing ``trans_model(state, action)`` and
                ``reward_model(state)``.

        Returns:
            Tensor of shape ``act_shape``: the first step of the final mean
            action sequence.
        """
        device = cur_state_dist.deter.device

        # Initialize action sequence distribution
        mean = torch.zeros((self.H, *self.act_shape), device=device)
        std = torch.ones((self.H, *self.act_shape), device=device)

        # Planning only evaluates the model; no gradients are needed.
        with torch.no_grad():
            for _ in range(self.I):
                # Sample actions from the distribution: (J, H, *act_shape)
                actions = dist.Normal(mean, std).sample((self.J,))

                # Execute actions and predict the rewards
                state = cur_state_dist.sample_n(self.J)
                # Flatten the per-candidate mean reward to shape (J,).
                rewards = rssm.reward_model(state).mean.reshape(self.J)
                for step_idx in range(self.H):
                    state = rssm.trans_model(state, actions[:, step_idx]).sample()
                    rewards = rewards + rssm.reward_model(state).mean.reshape(self.J)

                # Select the best sequences and update the distribution
                best_idxes = torch.topk(rewards, self.K).indices
                best_actions = actions[best_idxes]
                mean = best_actions.mean(0)
                # Normal takes a standard deviation (not a variance); the
                # population estimate avoids NaN when K == 1.
                std = best_actions.std(0, unbiased=False).clamp_min(self.MIN_STD)

        # Return mean first action value
        return mean[0]

In [None]:
from typing import Generic, TypeVar, NamedTuple

T = TypeVar("T")

class stack(Generic[T]):
    """List-backed container with ``push``, ``top`` and ``pop``.

    ``push`` appends at the back, ``top`` returns the most recently pushed
    item, and ``pop`` removes and returns the *oldest* live item (the one at
    the front). Front removal is done lazily by advancing ``_start``; the
    dead prefix is dropped once it makes up at least half of the backing
    list, which keeps ``pop`` amortised O(1) without retaining popped items
    indefinitely.

    NOTE(review): despite the class name, the removal order is
    first-in-first-out; that order is kept as-is here.
    """

    def __init__(self):
        self._data: list[T] = []
        # Index of the oldest live item; everything before it was popped.
        self._start = 0
    
    def _is_empty(self) -> bool:
        """True when no live item remains."""
        return self._start >= len(self._data)

    def push(self, x: T):
        """Append ``x`` and compact the backing list when worthwhile."""
        self._data.append(x)
        # Drop the already-popped prefix once it is at least half the list.
        if self._start > 0 and 2 * self._start >= len(self._data):
            del self._data[:self._start]
            self._start = 0
    
    def top(self) -> T:
        """Return the most recently pushed live item.

        Raises:
            IndexError: if the container holds no live item.
        """
        if self._is_empty():
            raise IndexError("top from empty stack")
        return self._data[-1]
    
    def pop(self) -> T:
        """Remove and return the oldest live item.

        Raises:
            IndexError: if the container holds no live item.
        """
        if self._is_empty():
            raise IndexError("pop from empty stack")
        item = self._data[self._start]
        self._start += 1
        if self._is_empty():
            # Everything was consumed: release all references at once.
            self._data.clear()
            self._start = 0
        return item

class TransitionBatch(NamedTuple):
    """Named bundle of five tensors describing a batch of transitions.

    NOTE(review): nothing in this file constructs or consumes this type, so
    shapes, dtypes and the exact meaning of each field are not established
    here; the per-field notes below are inferred from the names only.
    """
    obs: Tensor  # presumably the observation before the step — TODO confirm
    act: Tensor  # presumably the action taken — TODO confirm
    next_obs: Tensor  # presumably the observation after the step — TODO confirm
    reward: Tensor  # presumably the reward received — TODO confirm
    done: Tensor  # presumably an episode-termination flag — TODO confirm


In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing a diagonal Normal over the latent.

    Four Conv2d -> ReLU -> MaxPool2d(2) stages, a flatten and an MLP map an
    image to ``2 * latent_dim`` values: the latent mean and an unconstrained
    value that is turned into a positive standard deviation via softplus.

    Args:
        input_size: Shape whose last three entries are read as
            ``(C, H, W)``; ``H`` and ``W`` must be divisible by 16.
    """

    # Lower bound added to the standard deviation so the Normal scale is
    # always strictly positive.
    MIN_STD = 1e-4

    def __init__(self, input_size: torch.Size):
        super().__init__()
        self.input_size = input_size
        C, H, W = input_size[-3:]
        assert H % 16 == 0 and W % 16 == 0, \
            "image resolution should be divisible by 16"
        self.kernel_size = k = 3
        self.hidden_dim = h = 64
        self.latent_dim = z_dim = 64
        p = k // 2

        # Feature-map shape after the four pooling stages.
        conv_size = torch.Size((8*h, H//16, W//16))

        self.main = nn.Sequential(
            nn.Conv2d(C, h, k, 1, p),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(h, 2*h, k, 1, p),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(2*h, 4*h, k, 1, p),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(4*h, 8*h, k, 1, p),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            # Two values per latent dimension: mean and raw std.
            MLP([conv_size.numel(), 256, 2 * z_dim]),
        )
    
    def forward(self, x: Tensor) -> dist.Distribution:
        """Return ``Independent(Normal(mean, std), 1)`` for each input image."""
        params = self.main(x)
        # Split along the feature axis, not along the batch axis.
        mean, raw_std = torch.split(params, self.latent_dim, dim=-1)
        # The network output is unconstrained; Normal requires scale > 0.
        std = nn.functional.softplus(raw_std) + self.MIN_STD
        z_dist = dist.Normal(mean, std)
        z_dist = dist.Independent(z_dist, 1)
        return z_dist

class Decoder(nn.Module):
    """Decode latent vectors into a unit-variance Normal over images.

    An MLP maps the latent to a flat vector of ``conv_size.numel()`` values,
    which is reshaped to ``conv_size = (8h, H // 16, W // 16)`` and passed
    through four Upsample(2) -> Conv2d stages (ReLU after all but the last).

    Args:
        output_size: ``(C, H, W)`` shape of the decoded image; ``H`` and
            ``W`` must be divisible by 16.
    """

    def __init__(self, output_size: torch.Size):
        super().__init__()
        self.output_size = output_size
        C, H, W = output_size
        assert H % 16 == 0 and W % 16 == 0, \
            "image resolution should be divisible by 16"
        self.hidden_dim = h = 64
        self.latent_dim = z_dim = 64
        self.kernel_size = k = 3
        self.conv_size = torch.Size((8*h, H // 16, W // 16))

        p = k // 2

        self.z_map = MLP([z_dim, 64, self.conv_size.numel()])

        self.main = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(8*h, 4*h, k, 1, p),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(4*h, 2*h, k, 1, p),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(2*h, h, k, 1, p),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(h, C, k, 1, p),
        )
    
    def forward(self, z: Tensor) -> dist.Distribution:
        """Return ``Independent(Normal(mean, 1), 3)`` for latents ``z``.

        ``z`` may have any number of leading dimensions before the latent
        dimension; they become the batch shape of the result.
        """
        lead_shape = z.shape[:-1]
        # The MLP output is flat; the conv stack needs a 4-D feature map.
        feature_map = self.z_map(z).reshape(-1, *self.conv_size)
        mean = self.main(feature_map)
        # Restore the caller's leading dimensions.
        mean = mean.reshape(*lead_shape, *mean.shape[1:])
        x_dist = dist.Normal(mean, 1.0)
        x_dist = dist.Independent(x_dist, 3)
        return x_dist

class VAE(nn.Module):
    """Variational auto-encoder built from ``Encoder`` and ``Decoder``.

    Args:
        input_size: Image shape passed to both the encoder and the decoder.
    """

    def __init__(self, input_size: torch.Size):
        super().__init__()
        self.enc = Encoder(input_size)
        self.dec = Decoder(input_size)
        
        # Prior parameters are buffers so they follow ``.to(device)``;
        # non-persistent so the state_dict layout is unchanged.
        self.register_buffer(
            "_prior_loc", torch.zeros(self.enc.latent_dim), persistent=False
        )
        self.register_buffer(
            "_prior_scale", torch.ones(self.enc.latent_dim), persistent=False
        )

    @property
    def z_prior(self) -> dist.Distribution:
        """Standard Normal prior over the latent.

        Wrapped in ``Independent(..., 1)`` so it has the same event shape as
        the encoder's output and a KL between the two is defined.
        """
        base = dist.Normal(loc=self._prior_loc, scale=self._prior_scale)
        return dist.Independent(base, 1)
    
    def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        """Draw latents from the prior and sample images from the decoder."""
        zs = self.z_prior.sample(sample_shape)
        x_dists: dist.Distribution = self.dec(zs)
        return x_dists.sample()
    
    def loss(self, x: Tensor) -> Tensor:
        """Negative evidence lower bound for ``x``.

        Returns:
            ``KL(q(z|x) || p(z)) - log p(x|z)`` with ``z`` a reparameterised
            sample from the encoder distribution; no reduction over the
            batch dimension is applied here.
        """
        z_dist: dist.Distribution = self.enc(x)
        z = z_dist.rsample()
        x_hat: dist.Distribution = self.dec(z)
        # KL from the approximate posterior to the prior.
        prior_loss = dist.kl_divergence(z_dist, self.z_prior)
        # Reconstruction term is the *negative* log-likelihood, so that
        # minimising the loss maximises the likelihood of x.
        recon_loss = -x_hat.log_prob(x)
        total_loss = prior_loss + recon_loss
        return total_loss

In [None]:
from gym.spaces import Space, Discrete, Box

class Actor(nn.Module):
    """Placeholder actor module.

    NOTE(review): no layers or ``forward`` are defined yet, and
    ``num_actions`` is accepted but not stored or used.
    """
    def __init__(self, num_actions: int):
        super().__init__()


class Critic(nn.Module):
    """Critic skeleton: a visual encoder followed by an MLP with one output.

    Args:
        obs_shape: Observation shape passed to ``VisualEncoder``.
    """
    def __init__(self, obs_shape: torch.Size):
        super().__init__()
        self.vis_enc = VisualEncoder(obs_shape)
        # Maps the flattened encoder features to a single value.
        self.fc = MLP([self.vis_enc.out_features, 256, 1])
    
    def forward(self):
        # NOTE(review): unimplemented stub — takes no input and returns
        # None; ``vis_enc`` and ``fc`` are not used yet.
        ...

class ObsEncoder(nn.Module):
    """Encode observations into vectors of width ``enc_dim``.

    Only ``Box`` observation spaces are supported: the observation is passed
    through a ``VisualEncoder`` and a linear projection.

    Args:
        obs_space: Observation space; its ``shape`` is given to
            ``VisualEncoder``.
        enc_dim: Width of the output encoding.

    Raises:
        NotImplementedError: for any space that is not a ``Box``.
    """

    def __init__(self, obs_space: Space, enc_dim: int):
        super().__init__()
        if not isinstance(obs_space, Box):
            raise NotImplementedError()

        visual = VisualEncoder(torch.Size(obs_space.shape))
        projection = nn.Linear(visual.out_features, enc_dim)
        self.enc = nn.Sequential(visual, projection)

    def forward(self, obs: Tensor) -> Tensor:
        """Return the encoding of ``obs``."""
        encoded = self.enc(obs)
        return encoded

class ActionEncoder(nn.Module):
    """Encode actions into vectors of width ``enc_dim``.

    ``Discrete`` spaces use an embedding table with ``act_space.n`` rows;
    ``Box`` spaces use an MLP whose input width is the number of elements in
    the space's shape.

    Args:
        act_space: Action space (``Discrete`` or ``Box``).
        enc_dim: Width of the output encoding.

    Raises:
        NotImplementedError: for any other kind of space.
    """

    def __init__(self, act_space: Space, enc_dim: int):
        super().__init__()
        if isinstance(act_space, Box):
            num_inputs = torch.Size(act_space.shape).numel()
            encoder = MLP([num_inputs, 256, enc_dim])
        elif isinstance(act_space, Discrete):
            encoder = nn.Embedding(act_space.n, enc_dim)
        else:
            raise NotImplementedError()
        self.enc = encoder
    
    def forward(self, act: Tensor) -> Tensor:
        """Return the encoding of ``act``.

        NOTE(review): ``act`` is passed to the encoder unchanged, so for a
        ``Box`` space it presumably has to arrive already flattened to the
        MLP's input width — confirm against callers.
        """
        encoded = self.enc(act)
        return encoded

class ActorCritic(nn.Module):
    """Shared visual encoder with an actor head and a critic head.

    Args:
        obs_shape: Observation shape passed to ``VisualEncoder``.
        num_actions: Width of the actor output; also the number of action
            features appended to the encoder features for the critic.

    NOTE(review): no ``forward`` is defined yet.
    """

    def __init__(self, obs_shape: torch.Size, num_actions: int):
        super().__init__()
        self.obs_enc = VisualEncoder(obs_shape)
        self.actor = MLP([self.obs_enc.out_features, 256, num_actions])
        # A one-element dims list would build an empty nn.Sequential (an
        # identity with no parameters). Use the same 256-unit hidden layer
        # and single output as the standalone Critic in this notebook.
        # NOTE(review): hidden width and scalar output are assumed — confirm.
        self.critic = MLP([self.obs_enc.out_features + num_actions, 256, 1])