<a href="https://colab.research.google.com/github/seki-shu/Group15_Multimodal/blob/Dreamer-V2/Dreamer_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 準備

## libraryのimport

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

## deviceの確認

In [None]:
# Define the torch.device once; this variable is reused later when moving
# models and data onto the GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# モデルの実装


## RSSM

In [None]:
class RSSM(nn.Module):
    """Layers of a recurrent state-space model with a categorical latent state.

    The module holds three groups of layers:

    * Recurrent model:  h_t     = f(h_{t-1}, z_{t-1}, a_{t-1})
    * Prior:            z_hat_t = f(h_t)
    * Posterior:        z_t     = f(h_t, o_t)

    The stochastic state is handled as a flat vector of size
    ``z_dim * n_classes`` wherever it enters or leaves a linear layer.

    Args:
        mlp_hidden_dim: Width of the hidden linear layers.
        h_dim: Size of the deterministic state ``h`` (GRU hidden size).
        z_dim: Size of the stochastic (categorical) state ``z``; multiplied
            by ``n_classes`` to get the flattened latent size.
        a_dim: Size of the action vector.
        n_classes: Number of categories of ``z``.
        embedding_dim: Size of the observation embedding that is combined
            with ``h`` in the posterior (presumably the encoder output --
            confirm against the caller).
    """

    def __init__(
        self,
        mlp_hidden_dim: int,
        h_dim: int,
        z_dim: int,
        a_dim: int,
        n_classes: int,
        embedding_dim: int,
    ) -> None:
        super().__init__()

        self.h_dim = h_dim  # deterministic state ("deter")
        self.z_dim = z_dim  # stochastic state, categorical variables ("stoch")
        self.a_dim = a_dim  # action
        self.n_classes = n_classes  # number of categories of z

        # Flattened size of the categorical latent, shared by all heads below.
        latent_dim = z_dim * n_classes

        # Recurrent model
        # h_t = f(h_t-1, z_t-1, a_t-1)
        self.z_a_hidden = nn.Linear(latent_dim + a_dim, mlp_hidden_dim)
        self.rnn = nn.GRUCell(mlp_hidden_dim, h_dim)

        # Prior prediction
        # z_t+1_hat = f(h_t+1)
        self.prior_hidden = nn.Linear(h_dim, mlp_hidden_dim)
        self.prior_logits = nn.Linear(mlp_hidden_dim, latent_dim)

        # Posterior
        # z_t+1 = f(h_t+1, o_t+1)
        # embedding_dim is a constructor argument here; the competition_baseline
        # version had it hard-coded for some reason.
        self.posterior_hidden = nn.Linear(h_dim + embedding_dim, mlp_hidden_dim)
        self.posterior_logits = nn.Linear(mlp_hidden_dim, latent_dim)