<a href="https://colab.research.google.com/github/seki-shu/Group15_Multimodal/blob/Dreamer-V2/Dreamer_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 準備

## libraryのimport

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import torch.distributions as td
from torch.distributions import Normal, Categorical, OneHotCategorical, OneHotCategoricalStraightThrough
from torch.distributions import kl_divergence

## deviceの確認

In [None]:
# Pick the compute device once: the GPU when PyTorch can see CUDA, else the CPU.
# Models and data are moved onto this `device` later in the notebook.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# モデルの実装


## RSSM

In [None]:
class RSSM(nn.Module):
    """
    Recurrent State-Space Model with categorical stochastic latents.

    The latent state is the pair (h, z):
      * h : deterministic recurrent state, shape (B, h_dim)
      * z : stochastic state made of ``z_dim`` one-hot categorical variables
            with ``n_classes`` classes each, shape (B, z_dim, n_classes)

    Parameters
    ----------
    mlp_hidden_dim : int
        Width of the hidden layer used before the GRU cell and inside the
        prior / posterior heads.
    h_dim : int
        Size of the deterministic state h (hidden size of the GRU cell).
    z_dim : int
        Number of categorical variables in the stochastic state z.
    a_dim : int
        Size of the action vector.
    n_classes : int
        Number of classes of each categorical variable in z.
    embedding_dim : int
        Size of the observation embedding consumed by the posterior.
    """

    def __init__(
        self,
        mlp_hidden_dim: int,
        h_dim: int,
        z_dim: int,
        a_dim: int,
        n_classes: int,
        embedding_dim: int,
    ):
        super().__init__()

        self.h_dim = h_dim  # deterministic state size ("deter")
        self.z_dim = z_dim  # number of categorical variables in the stochastic state ("stoch")
        self.a_dim = a_dim  # action size
        self.n_classes = n_classes  # number of classes per categorical variable of z

        # Recurrent model: h_t = f(h_{t-1}, z_{t-1}, a_{t-1})
        self.z_a_hidden = nn.Linear(z_dim * n_classes + a_dim, mlp_hidden_dim)
        self.rnn = nn.GRUCell(mlp_hidden_dim, h_dim)

        # Prior: z_hat_t = f(h_t)
        self.prior_hidden = nn.Linear(h_dim, mlp_hidden_dim)
        self.prior_logits = nn.Linear(mlp_hidden_dim, z_dim * n_classes)

        # Posterior: z_t = f(h_t, o_t)
        # embedding_dim is a constructor argument here; the competition
        # baseline hard-coded it for some reason.
        self.posterior_hidden = nn.Linear(h_dim + embedding_dim, mlp_hidden_dim)
        self.posterior_logits = nn.Linear(mlp_hidden_dim, z_dim * n_classes)

    def _dist_from_logits(self, flat_logits: torch.Tensor, detach: bool):
        """
        Turn flat logits (B, z_dim * n_classes) into the latent distribution.

        The distribution is an ``Independent`` over ``z_dim`` straight-through
        one-hot categoricals, so ``batch_shape == (B,)`` and
        ``event_shape == (z_dim, n_classes)``. When ``detach`` is True, a second
        distribution built from gradient-detached logits is returned as well.
        """
        logits = flat_logits.reshape(flat_logits.shape[0], self.z_dim, self.n_classes)
        dist = td.Independent(OneHotCategoricalStraightThrough(logits=logits), 1)
        if detach:
            detached = td.Independent(
                OneHotCategoricalStraightThrough(logits=logits.detach()), 1
            )
            return dist, detached
        return dist

    def recurrent(
        self,
        h_prev: torch.Tensor,
        z_prev: torch.Tensor,
        a_prev: torch.Tensor,
        rnn_hidden: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """
        Advance the deterministic state: h_t = f(h_{t-1}, z_{t-1}, a_{t-1}).

        Parameters
        ----------
        h_prev : torch.Tensor (B, h_dim)
            Previous deterministic state.
        z_prev : torch.Tensor (B, z_dim * n_classes) or (B, z_dim, n_classes)
            Previous stochastic state; flattened per sample before use.
        a_prev : torch.Tensor (B, a_dim)
            Previous action.
        rnn_hidden : torch.Tensor (B, h_dim), optional
            Hidden state fed to the GRU cell. Defaults to ``h_prev``; kept as
            a separate argument for callers that already pass it explicitly.

        Returns
        -------
        h : torch.Tensor (B, h_dim)
            New deterministic state.
        """
        if rnn_hidden is None:
            rnn_hidden = h_prev
        # Samples from get_prior/get_posterior are (B, z_dim, n_classes); the
        # input layer expects them flattened. For already-flat z this is a no-op.
        z_flat = z_prev.reshape(z_prev.shape[0], -1)
        mlp_hidden = F.elu(self.z_a_hidden(torch.cat([z_flat, a_prev], dim=1)))
        return self.rnn(mlp_hidden, rnn_hidden)

    def get_prior(
        self,
        h: torch.Tensor,
        detach: bool = False,
    ):
        """
        Prior over the stochastic state: z_hat_t = f(h_t).

        Parameters
        ----------
        h : torch.Tensor (B, h_dim)
            Deterministic state.
        detach : bool
            If True, also return a copy of the distribution whose logits are
            detached from the computation graph.

        Returns
        -------
        prior, or (prior, detached_prior) when ``detach`` is True.
        """
        mlp_hidden = F.elu(self.prior_hidden(h))
        return self._dist_from_logits(self.prior_logits(mlp_hidden), detach)

    def get_posterior(
        self,
        h: torch.Tensor,
        embedded_obs: torch.Tensor,
        detach: bool = False,
    ):
        """
        Posterior over the stochastic state: z_t = f(h_t, o_t).

        Parameters
        ----------
        h : torch.Tensor (B, h_dim)
            Deterministic state.
        embedded_obs : torch.Tensor (B, embedding_dim)
            Embedded observation.
        detach : bool
            If True, also return a copy of the distribution whose logits are
            detached from the computation graph.

        Returns
        -------
        posterior, or (posterior, detached_posterior) when ``detach`` is True.
        """
        mlp_hidden = F.elu(self.posterior_hidden(torch.cat([h, embedded_obs], dim=1)))
        return self._dist_from_logits(self.posterior_logits(mlp_hidden), detach)

In [None]:
class Encoder(nn.Module):
    """
    Convolutional observation encoder: (B, 3, 64, 64) image -> (B, 1024) vector.

    Four unpadded 4x4 convolutions with stride 2 shrink the spatial size
    64 -> 31 -> 14 -> 6 -> 2 while widening the channels 3 -> 32 -> 64 -> 128 -> 256,
    so the flattened output has 256 * 2 * 2 = 1024 features.
    """

    def __init__(self):
        super().__init__()
        # Registered as individual attributes (conv1..conv4) so parameter /
        # state_dict names stay stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2)

    def forward(self, obs: torch.Tensor):
        """
        Embed a batch of observation images into flat feature vectors.

        Parameters
        ----------
        obs : torch.Tensor (batch size, 3, 64, 64)
            Observation images from the environment.

        Returns
        -------
        embedded_obs : torch.Tensor (batch size, 1024)
            ELU activations of the last conv layer, flattened per sample.
        """
        features = obs
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.elu(conv(features))
        return features.reshape(features.shape[0], -1)