# STORM実装



## 1. 準備  

In [None]:
!pip install gym==0.26.2 gym[atari]==0.26.2 gym[accept-rom-license]==0.26.2 autorom ale-py
# !pip install gym==0.26.2 'gym[atari]==0.26.2' 'gym[accept-rom-license]==0.26.2' 'gym[other]==0.26.2' autorom ale-py
# !pip install einops

### ライブラリインポート  

In [None]:
import time
import os
import gc
import random
from copy import deepcopy
import copy

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import gym
from gym.wrappers import ResizeObservation
import torch
import torch.distributions as distributions
from torch.distributions import Normal, Categorical, OneHotCategorical, OneHotCategoricalStraightThrough
from torch.distributions.kl import kl_divergence
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
import torch.autograd.profiler as prof

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from collections import deque
from tqdm import tqdm



%load_ext tensorboard

In [None]:
# %tensorboard --logdir './logs'

### GPUの確認

In [None]:
# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU. The chosen device is printed for reference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
!nvidia-smi

## 2. 環境の設定  

### Repeat Action    
- モデルによって変更する可能性があると想定している部分は以下のとおりです．
    - 画像のレンダリングサイズ(ResizeObservationクラスのshape)．
    - 同じ行動を繰り返す数（RepeatActionクラスのskip）

In [None]:
class RepeatAction(gym.Wrapper):
    """
    Wrapper that automatically repeats the same action ``skip`` times.

    The returned observation is the one produced by the last executed frame
    and the returned reward is the sum over the repeated frames.

    Parameters
    ----------
    env : gym.Env
        Environment to wrap (gym 0.26 API: ``step`` returns five values).
    skip : int
        Number of times each action is repeated. Must be at least 1.
    max_steps : int or None
        Global budget of ``step`` calls across all episodes. A falsy value
        disables the limit.
    """
    def __init__(self, env, skip=4, max_steps=100_000):
        gym.Wrapper.__init__(self, env)
        if skip < 1:
            # With skip < 1 the frame loop never runs and `step` has nothing to return.
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.max_steps = max_steps if max_steps else float("inf")  # limit on the number of iterations
        self.steps = 0  # number of iterations performed so far
        self.height = env.observation_space.shape[0]
        self.width = env.observation_space.shape[1]
        self._skip = skip
        self.lives_info = None

    def reset(self, **kwargs):
        """
        Reset the wrapped environment and remember the initial life count.

        Keyword arguments (e.g. ``seed``, ``options``) are forwarded to the
        wrapped environment.

        Returns
        -------
        tuple
            ``(obs, info)`` as returned by the wrapped environment.
        """
        obs, info = self.env.reset(**kwargs)
        # Not every environment reports lives; None disables life-loss detection.
        self.lives_info = info.get('lives')
        return obs, info

    def step(self, action):
        """
        Apply ``action`` for up to ``skip`` frames.

        Returns
        -------
        tuple or None
            ``(obs, total_reward, done, info)`` where ``info['life_loss']``
            tells whether the life count decreased during this call.
            ``None`` is returned once the global step budget is exhausted.
        """
        if self.steps >= self.max_steps:  # return nothing once the budget is used up
            print("Reached max iterations.")
            return None

        total_reward = 0.0
        self.steps += 1
        # The step that consumes the last unit of budget ends the episode.
        budget_exhausted = self.steps >= self.max_steps

        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward

            # A truncated episode must not be stepped any further, so it is
            # reported as done exactly like a terminated one.
            done = bool(terminated or truncated or budget_exhausted)
            if done:
                break

        # Check whether a life was lost during the repeated frames.
        current_lives_info = info.get('lives')
        previous_lives_info = self.lives_info
        info['life_loss'] = (
            current_lives_info is not None
            and previous_lives_info is not None
            and current_lives_info < previous_lives_info
        )
        # Always track the latest count so that a loss following a gained
        # life is still detected.
        self.lives_info = current_lives_info

        return obs, total_reward, done, info

## 3. 補助機能の実装  


### モデルの保存

In [None]:
# モデルパラメータをGoogleDriveに保存・後で読み込みするためのヘルパークラス
class TrainedModels:
    """Helper that saves model parameters to a directory and loads them back later."""

    def __init__(self, *models) -> None:
        """
        Constructor.

        Parameters
        ----------
        models : nn.Module
            Models whose parameters are saved. Several models can be passed.

        Example: trained_models = TrainedModels(encoder, rssm, value_model, action_model)
        """
        # isinstance also accepts indirect subclasses of nn.Module, which a
        # check against the direct base classes would wrongly reject.
        assert all(isinstance(model, nn.Module) for model in models), "Arguments for TrainedModels need to be nn models."

        self.models = models

    @staticmethod
    def _param_path(dir: str, index: int) -> str:
        """Return the file path of the ``index``-th model (01.pt, 02.pt, ...)."""
        return os.path.join(dir, f"{str(index + 1).zfill(2)}.pt")

    def save(self, dir: str) -> None:
        """
        Save the parameters of the models passed to ``__init__``.

        Files are numbered sequentially: 01.pt, 02.pt, ...
        The directory is created when it does not exist yet.

        Parameters
        ----------
        dir : str
            Destination directory.
        """
        if dir:
            os.makedirs(dir, exist_ok=True)
        for i, model in enumerate(self.models):
            torch.save(model.state_dict(), self._param_path(dir, i))

    def load(self, dir: str, device: str) -> None:
        """
        Load the parameters of the models passed to ``__init__``.

        Parameters
        ----------
        dir : str
            Directory the parameters were saved to.
        device : str
            Device (CPU or GPU) the loaded tensors are mapped to.
        """
        for i, model in enumerate(self.models):
            state_dict = torch.load(self._param_path(dir, i), map_location=device)
            model.load_state_dict(state_dict)

### set seed

In [None]:
def set_seed(seed=20010105) -> None:
    """
    Fix the seeds of Python, NumPy and PyTorch for reproducible training.

    Parameters
    ----------
    seed : int
        Seed value.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Seed every random number generator used by the notebook.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

## functions

### symlog

In [None]:
@torch.no_grad()
def symlog(x):
    """Symmetric log: sign(x) * log(1 + |x|). Runs without gradient tracking."""
    magnitude = torch.log(1 + torch.abs(x))
    return torch.sign(x) * magnitude

### symexp

In [None]:
@torch.no_grad()
def symexp(x):
    """Inverse of symlog: sign(x) * (exp(|x|) - 1). Runs without gradient tracking."""
    magnitude = torch.exp(torch.abs(x)) - 1
    return torch.sign(x) * magnitude

### SymLogLoss

In [None]:
class SymLogLoss(nn.Module):
    """Half mean-squared error between a prediction and the symlog of the target."""

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        # The prediction is compared against the target in symlog space.
        squashed_target = symlog(target)
        mse = F.mse_loss(output, squashed_target)
        return 0.5*mse

### SymLogTwoHotLoss

In [None]:
class SymLogTwoHotLoss(nn.Module):
    """
    Cross-entropy against a two-hot encoding of the symlog-transformed target.

    The interval [lower_bound, upper_bound] is split into ``num_classes``
    evenly spaced bins. A target is encoded as weights on the two bins that
    surround it, proportional to its distance from each.

    Parameters
    ----------
    num_classes : int
        Number of bins (size of the last axis of the logits).
    lower_bound, upper_bound : float
        Range covered by the bins, in symlog space.
    """
    def __init__(self, num_classes, lower_bound, upper_bound):
        super().__init__()
        self.num_classes = num_classes
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.bin_length = (upper_bound - lower_bound) / (num_classes-1)

        # use register buffer so that bins move with .cuda() automatically
        # The bins follow the configured bounds so that they always agree
        # with bin_length and with the range assertion in forward().
        self.bins: torch.Tensor
        self.register_buffer(
            'bins', torch.linspace(lower_bound, upper_bound, num_classes), persistent=False)

    def forward(self, output, target):
        """
        Parameters
        ----------
        output : torch.Tensor
            Logits whose last axis has ``num_classes`` entries.
        target : torch.Tensor
            Raw (not symlog-transformed) targets, one per logit vector.

        Returns
        -------
        torch.Tensor
            Scalar mean cross-entropy.
        """
        target = symlog(target) # (B, L)
        assert target.min() >= self.lower_bound and target.max() <= self.upper_bound

        # bucketize returns 0 for a target that sits exactly on the lowest
        # bin; clamping keeps both `index - 1` and `index` valid bin ids.
        index = torch.bucketize(target, self.bins)
        index = torch.clamp(index, min=1, max=self.num_classes - 1)
        lower_index = index - 1

        diff = target - self.bins[lower_index]  # distance from the lower bin
        weight = torch.clamp(diff / self.bin_length, 0, 1)
        weight = weight.unsqueeze(-1)  # broadcast over the class axis

        target_prob = (1-weight)*F.one_hot(lower_index, self.num_classes) + weight*F.one_hot(index, self.num_classes)

        loss = -target_prob * F.log_softmax(output, dim=-1)
        loss = loss.sum(dim=-1)
        return loss.mean()

    def decode(self, output):
        """Turn logits into a scalar: expected bin value mapped back with symexp."""
        return symexp(F.softmax(output, dim=-1) @ self.bins)


# if __name__ == "__main__":
#     loss_func = SymLogTwoHotLoss(255, -20, 20)
#     output = torch.randn(1, 1, 255).requires_grad_()
#     target = torch.ones(1).reshape(1, 1).float() * 0.1
#     print(target)
#     loss = loss_func(output, target)
#     print(loss)

    # prob = torch.ones(1, 1, 255)*0.5/255
    # prob[0, 0, 128] = 0.5
    # logits = torch.log(prob)
    # print(loss_func.decode(logits), loss_func.bins[128])

## Attention Blocks

### get subsequent mask

In [None]:
def get_subsequent_mask(seq):
    ''' Build a (1, L, L) boolean mask that hides subsequent positions.

    L is the size of the second axis of ``seq``; entry [0, i, j] is True
    when j <= i. The mask is created on the device of ``seq``.
    '''
    _, seq_len = seq.shape[:2]
    # The lower triangle (diagonal included) marks the visible positions.
    visible = torch.tril(torch.ones((1, seq_len, seq_len), device=seq.device))
    return visible.bool() # (1, L, L)

### get subsequent mask with batch length

In [None]:
def get_subsequent_mask_with_batch_length(batch_length, device):
    ''' Build a (1, L, L) boolean mask that hides subsequent positions.

    Entry [0, i, j] is True when j <= i, with L = ``batch_length``.
    '''
    shape = (1, batch_length, batch_length)
    visible = torch.tril(torch.ones(shape, device=device))
    return visible.bool()

### get vector mask

In [None]:
def get_vector_mask(batch_length, device):
    """Return an all-True boolean mask of shape (1, 1, batch_length) on ``device``."""
    return torch.ones((1, 1, batch_length), dtype=torch.bool, device=device)

### Scaled Dot Product Attention

In [None]:
class ScaledDotProductAttention(nn.Module):
    ''' Scaled dot-product attention with dropout on the attention weights.

    ``temperature`` divides the queries before the dot product and
    ``attn_dropout`` is the dropout probability applied to the weights.
    '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        ''' Return ``(output, attention_weights)``.

        Positions where ``mask == 0`` receive a score of -inf and therefore
        zero weight after the softmax.
        '''
        scaled_q = q / self.temperature
        scores = torch.matmul(scaled_q, k.transpose(2, 3))

        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        return torch.matmul(weights, v), weights

### Multi Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    ''' Multi-head attention with a residual connection and post layer norm.

    ``n_head`` heads are used, each with key/query width ``d_k`` and value
    width ``d_v``; inputs and outputs have width ``d_model``.
    '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Bias-free projections for queries, keys, values and the output.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        ''' Attend over full sequences; returns ``(output, attention_weights)``.

        q, k, v are (batch, length, d_model). ``mask``, when given, gets an
        extra axis at position 1 so that it broadcasts over the heads.
        '''
        n_head = self.n_head
        batch = q.size(0)
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        skip = q

        # Project, split into heads (b x l x n x d) and move the head axis
        # in front of the length axis (b x n x l x d).
        q = self.w_qs(q).view(batch, len_q, n_head, self.d_k).transpose(1, 2)
        k = self.w_ks(k).view(batch, len_k, n_head, self.d_k).transpose(1, 2)
        v = self.w_vs(v).view(batch, len_v, n_head, self.d_v).transpose(1, 2)

        head_mask = None if mask is None else mask.unsqueeze(1)

        context, attn = self.attention(q, k, v, mask=head_mask)

        # Move the head axis back and concatenate the heads: b x lq x (n*dv)
        context = context.transpose(1, 2).contiguous().view(batch, len_q, -1)
        context = self.dropout(self.fc(context))
        context = context + skip

        return self.layer_norm(context), attn # output is (B, L, d_model)

    def forward_with_kv_cache(self, q, k_cache, v_cache, idx):
        ''' Attend from a single new step over the cached keys and values.

        ``q`` is (batch, 1, d_model). The new key and value are appended
        along axis 2 to ``k_cache[idx]`` and ``v_cache[idx]``, which are
        updated in place in the given lists. Returns the output only.
        '''
        n_head = self.n_head
        batch = q.size(0)
        skip = q

        new_q = self.w_qs(skip).view(batch, 1, n_head, self.d_k).transpose(1, 2)
        new_k = self.w_ks(skip).view(batch, 1, n_head, self.d_k).transpose(1, 2)
        new_v = self.w_vs(skip).view(batch, 1, n_head, self.d_v).transpose(1, 2)

        # Grow this layer's cache by one step.
        k_cache[idx] = torch.cat([k_cache[idx], new_k], dim=2)
        v_cache[idx] = torch.cat([v_cache[idx], new_v], dim=2)

        context, _ = self.attention(new_q, k_cache[idx], v_cache[idx])
        context = context.transpose(1, 2).contiguous().view(batch, 1, -1)
        context = self.dropout(self.fc(context))
        context = context + skip
        return self.layer_norm(context)

### Positionwise Feed Forward

In [None]:
class PositionwiseFeedForward(nn.Module):
    ''' Two-layer feed-forward module with residual connection and layer norm.

    ``d_in`` is the input/output width, ``d_hid`` the hidden width and
    ``dropout`` the dropout probability applied before the residual sum.
    '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # expand
        self.w_2 = nn.Linear(d_hid, d_in)  # project back
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        projected = self.dropout(self.w_2(hidden))
        # Residual connection followed by layer normalisation.
        return self.layer_norm(projected + x)

### Attention Block KV Cache

In [None]:
class AttentionBlockKVCache(nn.Module):
    """Self-attention followed by a position-wise feed-forward layer.

    Supports both a regular pass over full sequences and a single-step pass
    that uses key/value caches.
    """

    def __init__(self, feat_dim, hidden_dim, num_heads, dropout):
        super().__init__()
        head_dim = feat_dim//num_heads
        self.slf_attn = MultiHeadAttention(num_heads, feat_dim, head_dim, head_dim, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(feat_dim, hidden_dim, dropout=dropout)

    def forward(self, q, k, v, idx=None, slf_attn_mask=None, kv_cache=False):
        """Run the block.

        With ``kv_cache=False`` returns ``(output, attention_weights)``.
        With ``kv_cache=True``, ``k`` and ``v`` are the cache lists, ``idx``
        selects this layer's entry and only the output is returned.
        """
        if not kv_cache:
            attended, attn = self.slf_attn(q, k, v, mask=slf_attn_mask)
            return self.pos_ffn(attended), attn

        attended = self.slf_attn.forward_with_kv_cache(q, k, v, idx)
        return self.pos_ffn(attended)

### Positional Encoding 1D

In [None]:
class PositionalEncoding1D(nn.Module):
    """
    Learned 1D positional embedding that is added to a feature sequence.

    Parameters
    ----------
    max_length : int
        Number of positions in the embedding table.
    embed_dim : int
        Width of each positional embedding (must match the feature width).
    """
    def __init__(
        self,
        max_length: int,
        embed_dim: int
    ):
        super().__init__()
        self.max_length = max_length
        self.embed_dim = embed_dim

        self.pos_emb = nn.Embedding(self.max_length, embed_dim)

    def forward(self, feat):
        """
        Add the embeddings of positions 0..L-1 to ``feat`` of shape (B, L, D).

        Raises
        ------
        ValueError
            If the sequence is longer than ``max_length``.
        """
        length = feat.shape[1]
        if length > self.max_length:
            raise ValueError(
                f"sequence length {length} exceeds max_length {self.max_length}")

        # Only the rows that are needed are looked up; the (L, D) result
        # broadcasts over the batch axis.
        positions = torch.arange(length, device=feat.device)
        return feat + self.pos_emb(positions)

    def forward_with_position(self, feat, position):
        """
        Add the embedding of a single ``position`` to ``feat`` of shape (B, 1, D).

        Raises
        ------
        ValueError
            If ``position`` is outside [0, max_length). Slicing the table
            there would yield an empty tensor and silently broadcast the
            result down to length 0.
        """
        assert feat.shape[1] == 1
        if not 0 <= position < self.max_length:
            raise ValueError(
                f"position {position} is outside [0, {self.max_length})")

        index = torch.arange(position, position + 1, device=feat.device)
        return feat + self.pos_emb(index)

## Transformer Blocks

### Transformer KV cache


In [None]:
class StochasticTransformerKVCache(nn.Module):
    """
    Transformer over (stochastic sample, action) pairs with an optional
    per-layer key/value cache for step-by-step inference.

    Parameters
    ----------
    stoch_dim : int
        Width of the flattened stochastic sample.
    action_dim : int
        Number of discrete actions (width of the one-hot action).
    feat_dim : int
        Transformer feature width.
    num_layers : int
        Number of attention blocks.
    num_heads : int
        Number of attention heads per block.
    max_length : int
        Number of positions in the positional embedding table.
    dropout : float
        Dropout probability used inside the attention blocks.
    """
    def __init__(self, stoch_dim, action_dim, feat_dim, num_layers, num_heads, max_length, dropout):
        super().__init__()
        self.action_dim = action_dim
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.head_dim = feat_dim // num_heads

        # mix image_embedding and action
        self.stem = nn.Sequential(
            nn.Linear(stoch_dim+action_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim)
        )
        self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim)
        self.layer_stack = nn.ModuleList([
            AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim*2, num_heads=num_heads, dropout=dropout) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6)

    def forward(self, samples, action, mask):
        '''
        Normal forward pass
        samples: flattened samples whose last axis is stoch_dim
        action: integer action ids, one per sample
        mask: (1, L, L)
        '''
        action = F.one_hot(action.long(), self.action_dim).float()
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding(feats)
        feats = self.layer_norm(feats)

        for layer in self.layer_stack:
            feats, attn = layer(feats, feats, feats, slf_attn_mask=mask)

        return feats

    def reset_kv_cache_list(self, batch_size, dtype):
        '''
        Reset self.key_cache_list and self.value_cache_list to empty caches,
        one (batch_size, num_heads, 0, head_dim) tensor per layer.
        '''
        # Allocate the caches next to the model's parameters instead of
        # assuming CUDA, so the module also works on CPU.
        device = self.layer_norm.weight.device
        empty_shape = (batch_size, self.num_heads, 0, self.head_dim)
        self.key_cache_list = [
            torch.zeros(size=empty_shape, dtype=dtype, device=device) for _ in self.layer_stack
        ]
        self.value_cache_list = [
            torch.zeros(size=empty_shape, dtype=dtype, device=device) for _ in self.layer_stack
        ]

    def forward_with_kv_cache(self, samples, action):
        '''
        Forward pass with kv_cache, cache stored in self.key_cache_list and
        self.value_cache_list (call reset_kv_cache_list first)
        samples: last_flattened_sample (B, 1, n_classes*stoch_dim)
        '''
        assert samples.shape[1] == 1

        action = F.one_hot(action.long(), self.action_dim).float()
        feats = self.stem(torch.cat([samples, action], dim=-1)) # mix image_embedding and action # (B, 1, feat_dim)

        # The caches are (B, num_heads, L, head_dim), so the number of steps
        # stored so far -- the position of the new step -- is axis 2.
        # Axis 1 is the constant number of heads.
        position = self.key_cache_list[0].shape[2]
        feats = self.position_encoding.forward_with_position(feats, position=position) # (B, 1, feat_dim)
        feats = self.layer_norm(feats) # (B, 1, feat_dim)

        for idx, layer in enumerate(self.layer_stack):
            feats = layer(feats, self.key_cache_list, self.value_cache_list, idx, kv_cache=True)

        return feats

## Agents

### percentile

In [None]:
def percentile(x, percentage):
    """
    Return the value at the given fraction of the sorted, flattened tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of any shape (must not be empty).
    percentage : float
        Fraction in [0, 1], e.g. 0.95 for the 95th percentile.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the k-th smallest value, k = int(percentage * N)
        clamped to [1, N].

    Raises
    ------
    ValueError
        If ``x`` has no elements.
    """
    flat_x = torch.flatten(x)
    count = flat_x.numel()
    if count == 0:
        raise ValueError("percentile() requires a non-empty tensor")

    # torch.kthvalue is 1-indexed, so a small input or fraction must not
    # produce k = 0 (and k must not exceed the number of elements).
    kth = min(max(int(percentage * count), 1), count)
    return torch.kthvalue(flat_x, kth).values

### calculate λ return

In [None]:
def calc_lambda_return(rewards, values, termination, gamma, lam, dtype=torch.float32):
    """
    Compute lambda-returns backwards in time, bootstrapping from the last value.

    Parameters
    ----------
    rewards : torch.Tensor
        Rewards, indexed as [:, t] for t in 0..L-1.
    values : torch.Tensor
        Value estimates; [:, t] is read for t in 0..L-1 and [:, -1] is the
        bootstrap value.
    termination : torch.Tensor
        1 where the episode ended at step t, 0 otherwise.
    gamma : float
        Discount factor.
    lam : float
        Lambda mixing coefficient.
    dtype : torch.dtype
        dtype of the returned tensor.

    Returns
    -------
    torch.Tensor
        Lambda-returns of shape (B, L), on the same device as ``rewards``.
    """
    # Invert termination to have 0 if the episode ended and 1 otherwise
    inv_termination = (termination * -1) + 1

    batch_size, batch_length = rewards.shape[:2]
    # Allocate on the device of the inputs instead of assuming CUDA.
    gamma_return = torch.zeros((batch_size, batch_length+1), dtype=dtype, device=rewards.device)
    gamma_return[:, -1] = values[:, -1]
    for t in reversed(range(batch_length)):  # with last bootstrap
        discount = gamma * inv_termination[:, t]
        gamma_return[:, t] = \
            rewards[:, t] + \
            discount * (1-lam) * values[:, t] + \
            discount * lam * gamma_return[:, t+1]
    return gamma_return[:, :-1]

### Actor Critic Agent

In [None]:
class ActorCriticAgent(nn.Module):
    """
    Actor-critic agent with an MLP actor, an MLP critic that outputs 255
    logits decoded through SymLogTwoHotLoss, and a slowly updated copy of
    the critic used as a regulariser.

    The module owns its own AdamW optimizer and AMP GradScaler; ``update``
    performs one complete optimisation step.
    """
    def __init__(self, feat_dim, num_layers, hidden_dim, action_dim, gamma, lambd, entropy_coef) -> None:
        """
        Parameters
        ----------
        feat_dim : int
            Width of the latent fed to the actor and the critic.
        num_layers : int
            Number of Linear/LayerNorm/ReLU blocks in each MLP.
        hidden_dim : int
            Hidden width of each MLP.
        action_dim : int
            Number of discrete actions (width of the actor logits).
        gamma : float
            Discount factor passed to calc_lambda_return.
        lambd : float
            Lambda coefficient passed to calc_lambda_return.
        entropy_coef : float
            Weight of the entropy bonus in the total loss.
        """
        super().__init__()
        self.gamma = gamma
        self.lambd = lambd
        self.entropy_coef = entropy_coef
        self.use_amp = True
        # NOTE(review): tensor_dtype is bfloat16 here while the autocast
        # contexts below use float16 -- confirm which one is intended.
        self.tensor_dtype = torch.bfloat16 if self.use_amp else torch.float32

        self.symlog_twohot_loss = SymLogTwoHotLoss(255, -20, 20)

        # Actor MLP: num_layers blocks followed by a linear layer to action logits.
        actor = [
            nn.Linear(feat_dim, hidden_dim, bias=False),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        ]
        for i in range(num_layers - 1):
            actor.extend([
                nn.Linear(hidden_dim, hidden_dim, bias=False),
                nn.LayerNorm(hidden_dim),
                nn.ReLU()
            ])
        self.actor = nn.Sequential(
            *actor,
            nn.Linear(hidden_dim, action_dim)
        )

        # Critic MLP: same trunk layout, ending in 255 two-hot logits.
        critic = [
            nn.Linear(feat_dim, hidden_dim, bias=False),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        ]
        for i in range(num_layers - 1):
            critic.extend([
                nn.Linear(hidden_dim, hidden_dim, bias=False),
                nn.LayerNorm(hidden_dim),
                nn.ReLU()
            ])

        self.critic = nn.Sequential(
            *critic,
            nn.Linear(hidden_dim, 255)
        )
        # Separate copy of the critic, moved towards it by update_slow_critic().
        self.slow_critic = copy.deepcopy(self.critic)

        # EMAScalar is defined outside this section -- presumably an
        # exponential moving average of a scalar; confirm against its definition.
        self.lowerbound_ema = EMAScalar(decay=0.99)
        self.upperbound_ema = EMAScalar(decay=0.99)

        # NOTE(review): self.parameters() also contains the slow_critic
        # parameters at this point; they are only ever evaluated under
        # no_grad below, so they presumably never receive gradients -- confirm.
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=4e-5, eps=1e-5)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    @torch.no_grad()
    def update_slow_critic(self, decay=0.98):
        """Move each slow-critic parameter towards the critic: slow = slow*decay + critic*(1-decay)."""
        for slow_param, param in zip(self.slow_critic.parameters(), self.critic.parameters()):
            slow_param.data.copy_(slow_param.data * decay + param.data * (1 - decay))

    def policy(self, x):
        """Return the actor logits for latent ``x``."""
        logits = self.actor(x)
        return logits

    def value(self, x):
        """Return the critic's value for latent ``x``, decoded from its two-hot logits."""
        value = self.critic(x)
        value = self.symlog_twohot_loss.decode(value)
        return value

    @torch.no_grad()
    def slow_value(self, x):
        """Return the slow critic's decoded value for latent ``x`` (no gradient)."""
        value = self.slow_critic(x)
        value = self.symlog_twohot_loss.decode(value)
        return value

    def get_logits_raw_value(self, x):
        """Return ``(actor logits, undecoded critic logits)`` for latent ``x``."""
        logits = self.actor(x)
        raw_value = self.critic(x)
        return logits, raw_value

    @torch.no_grad()
    def sample(self, latent, greedy=False):
        """
        Pick an action for ``latent``: the argmax of the policy when
        ``greedy`` is True, otherwise a sample from it.

        Side effect: switches the module to eval mode.
        """
        # latent is the concatenation of the prior and the transformer's h
        self.eval()
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp):
            logits = self.policy(latent)
            dist = distributions.Categorical(logits=logits)
            if greedy:
                action = dist.probs.argmax(dim=-1)
            else:
                action = dist.sample()
        return action

    def sample_as_env_action(self, latent, greedy=False):
        """Like ``sample`` but returns a plain Python number (via ``.item()``), so it expects a single action."""
        action = self.sample(latent, greedy)
        return action.detach().cpu().squeeze(-1).item()

    def update(self, latent, action, old_logprob, old_value, reward, termination, logger=None):
        '''
        Update policy and value model

        Runs one optimisation step on the combined policy, value,
        slow-critic regularisation and entropy loss, then updates the slow
        critic. ``old_logprob`` and ``old_value`` are not used in this
        method. When ``logger`` is given, the loss terms are reported
        through ``logger.log(name, value)``.
        '''
        self.train()
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp):
            logits, raw_value = self.get_logits_raw_value(latent) # (B, L+1, action_dim) # (B, L+1, 255)
            dist = distributions.Categorical(logits=logits[:, :-1]) # (B, L, action_dim)
            log_prob = dist.log_prob(action) # (B, L)
            entropy = dist.entropy() # (B, L)

            # decode value, calc lambda return
            slow_value = self.slow_value(latent) # (B, L+1) # no gradient
            slow_lambda_return = calc_lambda_return(reward, slow_value, termination, self.gamma, self.lambd) # (B, L)
            value = self.symlog_twohot_loss.decode(raw_value) # (B, L+1)
            lambda_return = calc_lambda_return(reward, value, termination, self.gamma, self.lambd) # (B, L)

            # update value function with slow critic regularization
            value_loss = self.symlog_twohot_loss(raw_value[:, :-1], lambda_return.detach()) #
            slow_value_regularization_loss = self.symlog_twohot_loss(raw_value[:, :-1], slow_lambda_return.detach())

            # Scale the advantage by the smoothed spread between the 5th and
            # 95th percentile of the returns, never by less than 1.
            lower_bound = self.lowerbound_ema(percentile(lambda_return, 0.05))
            upper_bound = self.upperbound_ema(percentile(lambda_return, 0.95))
            S = upper_bound-lower_bound
            # NOTE(review): `.cuda()` hard-codes the device; this fails without CUDA.
            norm_ratio = torch.max(torch.ones(1).cuda(), S)  # max(1, S) in the paper
            norm_advantage = (lambda_return-value[:, :-1]) / norm_ratio
            policy_loss = -(log_prob * norm_advantage.detach()).mean()

            entropy_loss = entropy.mean()

            # The entropy term is subtracted, i.e. higher entropy lowers the loss.
            loss = policy_loss + value_loss + slow_value_regularization_loss - self.entropy_coef * entropy_loss

        # gradient descent
        self.scaler.scale(loss).backward()
        self.scaler.unscale_(self.optimizer)  # for clip grad
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=100.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

        self.update_slow_critic()


        if logger is not None:
            logger.log('ActorCritic/policy_loss', policy_loss.item())
            logger.log('ActorCritic/value_loss', value_loss.item())
            logger.log('ActorCritic/entropy_loss', - entropy_loss.item())
            logger.log('ActorCritic/S', S.item())
            logger.log('ActorCritic/norm_ratio', norm_ratio.item())
            logger.log('ActorCritic/slow_value_reg', slow_value_regularization_loss.item())
            logger.log('ActorCritic/total_loss', loss.item())




## World Model


### Encoder

In [None]:
class EncoderBN(nn.Module):
    """
    Convolutional image encoder built from stride-2 Conv/BatchNorm/ReLU stages.

    The stem halves the spatial width once (the width bookkeeping assumes a
    64-pixel-wide input). Every further stage halves the width and doubles
    the channels until ``final_feature_width`` is reached.

    Parameters
    ----------
    in_channels : int
        Channels of the input images.
    stem_channels : int
        Output channels of the stem convolution.
    final_feature_width : int
        Spatial width of the final feature map. Must be one of
        16, 8, 4, 2, 1.

    Raises
    ------
    ValueError
        If ``final_feature_width`` cannot be reached by halving.
    """
    def __init__(self, in_channels, stem_channels, final_feature_width) -> None:
        super().__init__()

        feature_width = 64//2  # width after the stem

        # At least one stage always follows the stem, so the reachable
        # widths are 16, 8, 4, 2, 1. Anything else would make the stage
        # loop below run forever.
        reachable = []
        width = feature_width // 2
        while width >= 1:
            reachable.append(width)
            width //= 2
        if final_feature_width not in reachable:
            raise ValueError(
                f"final_feature_width must be one of {reachable}, got {final_feature_width}")

        # stem
        backbone = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=stem_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(inplace=True),
        ]
        channels = stem_channels

        # layers
        while feature_width != final_feature_width:
            backbone.append(
                nn.Conv2d(
                    in_channels=channels,
                    out_channels=channels*2,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False
                )
            )
            channels *= 2
            feature_width //= 2
            backbone.append(nn.BatchNorm2d(channels))
            backbone.append(nn.ReLU(inplace=True))

        self.backbone = nn.Sequential(*backbone)
        self.last_channels = channels

    def forward(self, x):
        """Encode ``x`` of shape (B, L, C, H, W) into flat features (B, L, C'*H'*W')."""
        batch_size = x.shape[0]
        # Fold time into the batch axis for the 2D convolutions, then unfold.
        x = rearrange(x, "B L C H W -> (B L) C H W")
        x = self.backbone(x)
        x = rearrange(x, "(B L) C H W -> B L (C H W)", B=batch_size)
        return x

### Decoder

In [None]:
class DecoderBN(nn.Module):
    """
    Image decoder: a linear stem followed by stride-2 transposed convolutions.

    The stem maps a flat sample to ``last_channels`` feature maps of width
    ``final_feature_width``. Each following stage halves the channels until
    ``stem_channels`` is reached, and a last transposed convolution
    produces ``original_in_channels`` output channels.
    """
    def __init__(self, stoch_dim, last_channels, original_in_channels, stem_channels, final_feature_width) -> None:
        super().__init__()

        # stem
        stages = [
            nn.Linear(stoch_dim, last_channels*final_feature_width*final_feature_width, bias=False),
            Rearrange('B L (C H W) -> (B L) C H W', C=last_channels, H=final_feature_width),
            nn.BatchNorm2d(last_channels),
            nn.ReLU(inplace=True),
        ]

        # upsampling layers
        channels = last_channels
        while channels != stem_channels:
            halved = channels//2
            stages.extend([
                nn.ConvTranspose2d(
                    in_channels=channels,
                    out_channels=halved,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False
                ),
                nn.BatchNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved

        # output layer back to image channels
        stages.append(
            nn.ConvTranspose2d(
                in_channels=channels,
                out_channels=original_in_channels,
                kernel_size=4,
                stride=2,
                padding=1
            )
        )
        self.backbone = nn.Sequential(*stages)

    def forward(self, sample):
        """Decode ``sample`` of shape (B, L, stoch_dim) into images (B, L, C, H, W)."""
        batch_size = sample.shape[0]
        decoded = self.backbone(sample)
        return rearrange(decoded, "(B L) C H W -> B L C H W", B=batch_size)

### Dist Head

In [None]:
class DistHead(nn.Module):
    '''
    Distribution head: maps features to the logits of ``stoch_dim``
    categorical variables with ``stoch_dim`` classes each.

    ``post_head`` is applied to image features and ``prior_head`` to
    transformer hidden states.
    '''
    def __init__(self, image_feat_dim, transformer_hidden_dim, stoch_dim) -> None:
        super().__init__()
        self.stoch_dim = stoch_dim
        self.post_head = nn.Linear(image_feat_dim, stoch_dim*stoch_dim)
        self.prior_head = nn.Linear(transformer_hidden_dim, stoch_dim*stoch_dim)

    def unimix(self, logits, mixing_ratio=0.01):
        '''Mix the softmax of ``logits`` with a uniform distribution and return log-probabilities.'''
        probs = F.softmax(logits, dim=-1)
        uniform_part = mixing_ratio * torch.ones_like(probs) / self.stoch_dim
        learned_part = (1-mixing_ratio) * probs
        return torch.log(uniform_part + learned_part)

    def _mixed_logits(self, head, x):
        '''Apply ``head``, split the output into (K, C) categoricals and apply unimix.'''
        flat_logits = head(x)
        grouped = rearrange(flat_logits, "B L (K C) -> B L K C", K=self.stoch_dim)
        return self.unimix(grouped)

    def forward_post(self, x):
        '''Posterior logits (B, L, K, C) from image features.'''
        return self._mixed_logits(self.post_head, x)

    def forward_prior(self, x):
        '''Prior logits (B, L, K, C) from transformer hidden states.'''
        return self._mixed_logits(self.prior_head, x)

### Reward Decoder

In [None]:
class RewardDecoder(nn.Module):
    """MLP that maps transformer features to ``num_classes`` reward logits.

    ``embedding_size`` is accepted but not used by this module.
    """
    def __init__(self, num_classes, embedding_size, transformer_hidden_dim) -> None:
        super().__init__()
        width = transformer_hidden_dim

        # Two identical Linear -> LayerNorm -> ReLU blocks.
        trunk = []
        for _ in range(2):
            trunk.append(nn.Linear(width, width, bias=False))
            trunk.append(nn.LayerNorm(width))
            trunk.append(nn.ReLU(inplace=True))
        self.backbone = nn.Sequential(*trunk)
        self.head = nn.Linear(width, num_classes)

    def forward(self, feat):
        hidden = self.backbone(feat)
        return self.head(hidden)

### Termination Decoder

In [None]:
class TerminationDecoder(nn.Module):
    """MLP that maps transformer features to one termination output per position.

    No sigmoid is applied, so the output is an unnormalised score.
    ``embedding_size`` is accepted but not used by this module.
    """
    def __init__(self,  embedding_size, transformer_hidden_dim) -> None:
        super().__init__()
        width = transformer_hidden_dim

        # Two identical Linear -> LayerNorm -> ReLU blocks.
        trunk = []
        for _ in range(2):
            trunk.append(nn.Linear(width, width, bias=False))
            trunk.append(nn.LayerNorm(width))
            trunk.append(nn.ReLU(inplace=True))
        self.backbone = nn.Sequential(*trunk)
        self.head = nn.Sequential(
            nn.Linear(width, 1),
        )

    def forward(self, feat):
        hidden = self.backbone(feat)
        score = self.head(hidden)
        return score.squeeze(-1)  # drop the trailing singleton axis

### MSELoss

In [None]:
class MSELoss(nn.Module):
    """Squared error summed over (C, H, W) and averaged over batch and time."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, obs_hat, obs):
        squared_error = (obs_hat - obs)**2
        # One summed error per (batch, time) entry.
        per_frame = reduce(squared_error, "B L C H W -> B L", "sum")
        return per_frame.mean()

### CategoricalKLDivLossWithFreeBits


In [None]:
class CategoricalKLDivLossWithFreeBits(nn.Module):
    """KL divergence between two sets of categorical logits, floored at ``free_bits``."""

    def __init__(self, free_bits) -> None:
        super().__init__()
        self.free_bits = free_bits

    def forward(self, p_logits, q_logits):
        """Return ``(floored KL, unfloored KL)``.

        The KL is summed over the categorical variables and averaged over
        batch and time before the floor is applied.
        """
        per_variable_kl = torch.distributions.kl.kl_divergence(
            OneHotCategorical(logits=p_logits),
            OneHotCategorical(logits=q_logits),
        )
        real_kl_div = reduce(per_variable_kl, "B L D -> B L", "sum").mean()

        floor = torch.ones_like(real_kl_div)*self.free_bits
        kl_div = torch.max(floor, real_kl_div)
        return kl_div, real_kl_div

### World Model

In [None]:
class WorldModel(nn.Module):
    def __init__(self, in_channels, action_dim,
                 transformer_max_length, transformer_hidden_dim, transformer_num_layers, transformer_num_heads):
        """Build the encoder, sequence model, decoders, losses and optimizer.

        Args:
            in_channels: channel count of the input images.
            action_dim: size of the action space passed to the transformer.
            transformer_max_length: ``max_length`` of the transformer.
            transformer_hidden_dim: feature width of the transformer output.
            transformer_num_layers: number of transformer layers.
            transformer_num_heads: number of attention heads.
        """
        super().__init__()
        # Fixed settings of this implementation.
        self.transformer_hidden_dim = transformer_hidden_dim
        self.final_feature_width = 4
        self.stoch_dim = 32
        self.stoch_flattened_dim = self.stoch_dim*self.stoch_dim
        self.use_amp = True
        # dtype of the imagination buffers and KV cache.
        # NOTE(review): the autocast blocks in this class use float16 - confirm
        # the mismatch with bfloat16 here is intended.
        self.tensor_dtype = torch.bfloat16 if self.use_amp else torch.float32
        # Sentinels: no valid size equals -1, so the first
        # init_imagine_buffer() call always allocates.
        self.imagine_batch_size = -1
        self.imagine_batch_length = -1

        stem_channels = 32
        reward_bins = 255
        feat_width = self.final_feature_width
        latent_dim = self.stoch_flattened_dim

        # Sub-modules are created in a fixed order (encoder, transformer,
        # heads, decoders) so parameter ordering stays stable.
        self.encoder = EncoderBN(
            in_channels=in_channels,
            stem_channels=stem_channels,
            final_feature_width=feat_width
        )
        self.storm_transformer = StochasticTransformerKVCache(
            stoch_dim=latent_dim,
            action_dim=action_dim,
            feat_dim=transformer_hidden_dim,
            num_layers=transformer_num_layers,
            num_heads=transformer_num_heads,
            max_length=transformer_max_length,
            dropout=0.1
        )
        encoder_channels = self.encoder.last_channels
        self.dist_head = DistHead(
            image_feat_dim=encoder_channels*feat_width*feat_width,
            transformer_hidden_dim=transformer_hidden_dim,
            stoch_dim=self.stoch_dim
        )
        self.image_decoder = DecoderBN(
            stoch_dim=latent_dim,
            last_channels=encoder_channels,
            original_in_channels=in_channels,
            stem_channels=stem_channels,
            final_feature_width=feat_width
        )
        self.reward_decoder = RewardDecoder(
            num_classes=reward_bins,
            embedding_size=latent_dim,
            transformer_hidden_dim=transformer_hidden_dim
        )
        self.termination_decoder = TerminationDecoder(
            embedding_size=latent_dim,
            transformer_hidden_dim=transformer_hidden_dim
        )

        # Loss functions.
        self.mse_loss_func = MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.bce_with_logits_loss_func = nn.BCEWithLogitsLoss()
        self.symlog_twohot_loss_func = SymLogTwoHotLoss(num_classes=reward_bins, lower_bound=-20, upper_bound=20)
        self.categorical_kl_div_loss = CategoricalKLDivLossWithFreeBits(free_bits=1)

        # Optimizer over every parameter registered above, plus the AMP scaler.
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def encode_obs(self, obs):
        """Encode observations into flattened one-hot posterior samples.

        Runs the image encoder and the posterior head under autocast, draws a
        straight-through categorical sample and flattens its last two axes.
        Per the original shape annotations the result is
        (B, L, n_classes*stoch_dim).
        """
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp):
            # (B, L, image_feat_dim) -> (B, L, n_classes, stoch_dim)
            post_logits = self.dist_head.forward_post(self.encoder(obs))
            one_hot = self.stright_throught_gradient(post_logits, sample_mode="random_sample")
            return self.flatten_sample(one_hot)

    def calc_last_dist_feat(self, latent, action):
        """Return a prior sample and the transformer feature for the last step.

        The full sequence is passed through the transformer with the mask
        built by ``get_subsequent_mask``; only the final time step is kept.

        Returns:
            ``(prior_flattened_sample, last_dist_feat)``; per the original
            annotations these are (B, 1, n_classes*stoch_dim) and
            (B, 1, n_heads*d_v).
        """
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp):
            causal_mask = get_subsequent_mask(latent)
            # Keep the time axis (slice, not index) so downstream shapes stay 3-D.
            last_dist_feat = self.storm_transformer(latent, action, causal_mask)[:, -1:]
            prior_logits = self.dist_head.forward_prior(last_dist_feat)
            prior_flattened_sample = self.flatten_sample(
                self.stright_throught_gradient(prior_logits, sample_mode="random_sample")
            )
        return prior_flattened_sample, last_dist_feat

    def predict_next(self, last_flattened_sample, action, log_video=True):
        '''
        Advance the model by one step using the transformer's KV cache.

        last_flattened_sample: (B, 1, n_classes*stoch_dim)
        action: (B, 1)
        log_video: when False the image decoder is skipped and obs_hat is None.

        Returns (obs_hat, reward_hat, termination_hat, prior_flattened_sample,
        dist_feat); termination_hat is boolean (logit > 0).
        '''
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp):
            # Transformer output for the newly appended step.
            dist_feat = self.storm_transformer.forward_with_kv_cache(last_flattened_sample, action)
            prior_logits = self.dist_head.forward_prior(dist_feat)

            # Sample the next latent from the prior and flatten it.
            prior_flattened_sample = self.flatten_sample(
                self.stright_throught_gradient(prior_logits, sample_mode="random_sample")
            )
            obs_hat = self.image_decoder(prior_flattened_sample) if log_video else None

            # Reward: class logits decoded back to a scalar by the two-hot loss helper.
            reward_hat = self.symlog_twohot_loss_func.decode(self.reward_decoder(dist_feat))
            # Termination: threshold the raw logit at zero.
            termination_hat = self.termination_decoder(dist_feat) > 0
        return obs_hat, reward_hat, termination_hat, prior_flattened_sample, dist_feat

    def stright_throught_gradient(self, logits, sample_mode="random_sample"):
        """Turn categorical logits into a sample, mode or probability tensor.

        Args:
            logits: unnormalised log-probabilities; the last axis indexes classes.
            sample_mode: one of
                ``"random_sample"`` - one-hot sample whose gradient flows
                through the class probabilities (straight-through estimator),
                ``"mode"`` - the distribution's mode,
                ``"probs"`` - the class probabilities themselves.

        Returns:
            A tensor with the same shape as ``logits``.

        Raises:
            ValueError: if ``sample_mode`` is not one of the values above
                (previously this surfaced as an ``UnboundLocalError``).
        """
        dist = OneHotCategorical(logits=logits)
        if sample_mode == "random_sample":
            # Forward value is the sample; backward gradient is that of probs.
            return dist.sample() + dist.probs - dist.probs.detach()
        if sample_mode == "mode":
            return dist.mode
        if sample_mode == "probs":
            return dist.probs
        raise ValueError(
            f"Unknown sample_mode {sample_mode!r}; expected 'random_sample', 'mode' or 'probs'"
        )

    def flatten_sample(self, sample):
        """Merge the last two axes of a 4-D sample: (B, L, K, C) -> (B, L, K*C)."""
        flattened = rearrange(sample, "B L K C -> B L (K C)")
        return flattened

    def init_imagine_buffer(self, imagine_batch_size, imagine_batch_length, dtype):
        '''
        This can slightly improve the efficiency of imagine_data
        But may vary across different machines
        '''
        if self.imagine_batch_size != imagine_batch_size or self.imagine_batch_length != imagine_batch_length:
            print(f"init_imagine_buffer: {imagine_batch_size}x{imagine_batch_length}@{dtype}") # 1024 * 16
            self.imagine_batch_size = imagine_batch_size
            self.imagine_batch_length = imagine_batch_length
            latent_size = (imagine_batch_size, imagine_batch_length+1, self.stoch_flattened_dim)
            hidden_size = (imagine_batch_size, imagine_batch_length+1, self.transformer_hidden_dim)
            scalar_size = (imagine_batch_size, imagine_batch_length)
            self.latent_buffer = torch.zeros(latent_size, dtype=dtype, device="cuda")
            self.hidden_buffer = torch.zeros(hidden_size, dtype=dtype, device="cuda")
            self.action_buffer = torch.zeros(scalar_size, dtype=dtype, device="cuda")
            self.reward_hat_buffer = torch.zeros(scalar_size, dtype=dtype, device="cuda")
            self.termination_hat_buffer = torch.zeros(scalar_size, dtype=dtype, device="cuda")

    def imagine_data(self, agent: ActorCriticAgent, sample_obs, sample_action,
                     imagine_batch_size, imagine_batch_length, log_video, logger):
        """Roll the world model forward under the agent's policy.

        The context ``(sample_obs, sample_action)`` is first replayed through
        the KV-cached transformer; afterwards ``imagine_batch_length`` steps
        are imagined, each using an action drawn by ``agent.sample``.

        Args:
            agent: policy providing ``sample(latent_and_hidden)``.
            sample_obs: context observations; axis 1 is the context length.
            sample_action: context actions aligned with ``sample_obs``.
            imagine_batch_size: number of trajectories (must match axis 0 of
                ``sample_obs``).
            imagine_batch_length: number of imagined steps.
            log_video: when True, decoded frames are logged under
                ``"Imagine/predict_video"``.
            logger: object with ``log(tag, value)``; only used when
                ``log_video`` is True.

        Returns:
            ``(latent_and_hidden, action, reward_hat, termination_hat)``.
            The last three are the model's internal buffers and are
            overwritten by the next call.

        Raises:
            ValueError: if the context is empty (axis 1 of ``sample_obs`` is 0).
        """
        context_length = sample_obs.shape[1]
        if context_length == 0:
            # Without context there is no latent to start the rollout from.
            raise ValueError("imagine_data requires at least one context step")

        self.init_imagine_buffer(imagine_batch_size, imagine_batch_length, dtype=self.tensor_dtype)
        obs_hat_list = []
        # Roughly 16 evenly spaced trajectories are logged; max() avoids a
        # zero slice step when fewer than 16 trajectories are imagined.
        video_stride = max(1, imagine_batch_size//16)

        self.storm_transformer.reset_kv_cache_list(imagine_batch_size, dtype=self.tensor_dtype)

        # Replay the context to fill the KV cache; only the outputs of the
        # final context step are needed afterwards.
        context_latent = self.encode_obs(sample_obs) # (B, L, n_classes*stoch_dim)
        for i in range(context_length):
            _, _, _, last_latent, last_dist_feat = self.predict_next(
                context_latent[:, i:i+1],
                sample_action[:, i:i+1],
                log_video=log_video
            )
        self.latent_buffer[:, 0:1] = last_latent # flattened one-hot prior sample
        self.hidden_buffer[:, 0:1] = last_dist_feat # transformer output h

        # Imagination rollout.
        for i in range(imagine_batch_length):
            action = agent.sample(torch.cat([self.latent_buffer[:, i:i+1], self.hidden_buffer[:, i:i+1]], dim=-1))
            self.action_buffer[:, i:i+1] = action # action: (B, 1)

            last_obs_hat, last_reward_hat, last_termination_hat, last_latent, last_dist_feat = self.predict_next(
                self.latent_buffer[:, i:i+1], self.action_buffer[:, i:i+1], log_video=log_video)

            self.latent_buffer[:, i+1:i+2] = last_latent
            self.hidden_buffer[:, i+1:i+2] = last_dist_feat
            self.reward_hat_buffer[:, i:i+1] = last_reward_hat
            self.termination_hat_buffer[:, i:i+1] = last_termination_hat
            if log_video:
                obs_hat_list.append(last_obs_hat[::video_stride])  # uniform sample vec_env

        if log_video:
            logger.log("Imagine/predict_video", torch.clamp(torch.cat(obs_hat_list, dim=1), 0, 1).cpu().float().detach().numpy())

        return torch.cat([self.latent_buffer, self.hidden_buffer], dim=-1), self.action_buffer, self.reward_hat_buffer, self.termination_hat_buffer

    def update(self, obs, action, reward, termination, logger=None):
        """Run one optimisation step of the world model on a sampled batch.

        Args:
            obs: observations; the first two axes are (batch, time).
            action, reward, termination: tensors aligned with ``obs``.
            logger: optional object with ``log(tag, value)``; when given, every
                loss term is logged under ``"WorldModel/..."``.
        """
        self.train()
        _, batch_length = obs.shape[:2]

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp):
            # Posterior: encode images, then sample and flatten the latent.
            post_logits = self.dist_head.forward_post(self.encoder(obs)) # (B, L, n_classes, stoch_dim)
            flattened_sample = self.flatten_sample(
                self.stright_throught_gradient(post_logits, sample_mode="random_sample")
            ) # (B, L, n_classes*stoch_dim)

            # Reconstruct the images from the posterior sample.
            obs_hat = self.image_decoder(flattened_sample) # (B, L, C, H, W)

            # Sequence model and the heads that read its features.
            temporal_mask = get_subsequent_mask_with_batch_length(batch_length, flattened_sample.device) # (1, L, L)
            dist_feat = self.storm_transformer(flattened_sample, action, temporal_mask)
            prior_logits = self.dist_head.forward_prior(dist_feat) # (B, L, n_classes, stoch_dim)
            reward_hat = self.reward_decoder(dist_feat) # (B, L, num_classes)
            termination_hat = self.termination_decoder(dist_feat)

            # Environment losses.
            reconstruction_loss = self.mse_loss_func(obs_hat, obs)
            reward_loss = self.symlog_twohot_loss_func(reward_hat, reward)
            termination_loss = self.bce_with_logits_loss_func(termination_hat, termination)

            # The prior at step t is matched against the posterior at step t+1;
            # each KL term stops the gradient on one side.
            next_post = post_logits[:, 1:]
            current_prior = prior_logits[:, :-1]
            dynamics_loss, dynamics_real_kl_div = self.categorical_kl_div_loss(next_post.detach(), current_prior)
            representation_loss, representation_real_kl_div = self.categorical_kl_div_loss(next_post, current_prior.detach())

            total_loss = reconstruction_loss + reward_loss + termination_loss + 0.5*dynamics_loss + 0.1*representation_loss

        # Scaled backward pass; gradients are unscaled before clipping.
        self.scaler.scale(total_loss).backward()
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1000.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

        if logger is None:
            return

        metrics = {
            "WorldModel/reconstruction_loss": reconstruction_loss,
            "WorldModel/reward_loss": reward_loss,
            "WorldModel/termination_loss": termination_loss,
            "WorldModel/dynamics_loss": dynamics_loss,
            "WorldModel/dynamics_real_kl_div": dynamics_real_kl_div,
            "WorldModel/representation_loss": representation_loss,
            "WorldModel/representation_real_kl_div": representation_real_kl_div,
            "WorldModel/total_loss": total_loss,
        }
        for tag, value in metrics.items():
            logger.log(tag, value.item())

## その他の機能

### ReplayBuffer

In [None]:
class ReplayBuffer():
    def __init__(self, obs_shape, num_envs, max_length=int(1E6), warmup_length=50000, store_on_gpu=False) -> None:
        self.store_on_gpu = store_on_gpu
        if store_on_gpu:
            self.obs_buffer = torch.empty((max_length//num_envs, num_envs, *obs_shape), dtype=torch.uint8, device="cuda", requires_grad=False)
            self.action_buffer = torch.empty((max_length//num_envs, num_envs), dtype=torch.uint8, device="cuda", requires_grad=False)
            self.reward_buffer = torch.empty((max_length//num_envs, num_envs), dtype=torch.float32, device="cuda", requires_grad=False)
            self.termination_buffer = torch.empty((max_length//num_envs, num_envs), dtype=torch.float32, device="cuda", requires_grad=False)
        else:
            self.obs_buffer = np.empty((max_length//num_envs, num_envs, *obs_shape), dtype=np.uint8)
            self.action_buffer = np.empty((max_length//num_envs, num_envs), dtype=np.float32)
            self.reward_buffer = np.empty((max_length//num_envs, num_envs), dtype=np.float32)
            self.termination_buffer = np.empty((max_length//num_envs, num_envs), dtype=np.float32)

        self.length = 0
        self.num_envs = num_envs
        self.last_pointer = -1
        self.max_length = max_length
        self.warmup_length = warmup_length
        self.external_buffer_length = None

    def load_trajectory(self, path):
        buffer = pickle.load(open(path, "rb"))
        if self.store_on_gpu:
            self.external_buffer = {name: torch.from_numpy(buffer[name]).to("cuda") for name in buffer}
        else:
            self.external_buffer = buffer
        self.external_buffer_length = self.external_buffer["obs"].shape[0]

    def sample_external(self, batch_size, batch_length, to_device="cuda"):
        indexes = np.random.randint(0, self.external_buffer_length+1-batch_length, size=batch_size)
        if self.store_on_gpu:
            obs = torch.stack([self.external_buffer["obs"][idx:idx+batch_length] for idx in indexes])
            action = torch.stack([self.external_buffer["action"][idx:idx+batch_length] for idx in indexes])
            reward = torch.stack([self.external_buffer["reward"][idx:idx+batch_length] for idx in indexes])
            termination = torch.stack([self.external_buffer["done"][idx:idx+batch_length] for idx in indexes])
        else:
            obs = np.stack([self.external_buffer["obs"][idx:idx+batch_length] for idx in indexes])
            action = np.stack([self.external_buffer["action"][idx:idx+batch_length] for idx in indexes])
            reward = np.stack([self.external_buffer["reward"][idx:idx+batch_length] for idx in indexes])
            termination = np.stack([self.external_buffer["done"][idx:idx+batch_length] for idx in indexes])
        return obs, action, reward, termination

    def ready(self):
        return self.length * self.num_envs > self.warmup_length

    @torch.no_grad()
    def sample(self, batch_size, external_batch_size, batch_length, to_device="cuda"):
        """Sample a batch of fixed-length windows as tensors on ``to_device``.

        ``batch_size // num_envs`` windows are drawn per environment from the
        internal storage; if an external trajectory is loaded and
        ``external_batch_size > 0``, that many additional windows are drawn
        from it and appended.

        Args:
            batch_size: number of windows taken from the internal storage.
            external_batch_size: number of windows taken from the external buffer.
            batch_length: number of consecutive steps per window.
            to_device: device of the returned tensors (default ``"cuda"``).

        Returns:
            ``(obs, action, reward, termination)``; ``obs`` is float, scaled
            by 1/255 and rearranged from ``B T H W C`` to ``B T C H W``.

        Raises:
            ValueError: if neither source contributes any window.
        """
        # The GPU layout stores torch tensors, the CPU layout NumPy arrays.
        stack = torch.stack if self.store_on_gpu else np.stack

        obs, action, reward, termination = [], [], [], []
        if batch_size > 0:
            for i in range(self.num_envs):
                indexes = np.random.randint(0, self.length+1-batch_length, size=batch_size//self.num_envs)
                obs.append(stack([self.obs_buffer[idx:idx+batch_length, i] for idx in indexes]))
                action.append(stack([self.action_buffer[idx:idx+batch_length, i] for idx in indexes]))
                reward.append(stack([self.reward_buffer[idx:idx+batch_length, i] for idx in indexes]))
                termination.append(stack([self.termination_buffer[idx:idx+batch_length, i] for idx in indexes]))

        if self.external_buffer_length is not None and external_batch_size > 0:
            external_obs, external_action, external_reward, external_termination = self.sample_external(
                external_batch_size, batch_length, to_device)
            obs.append(external_obs)
            action.append(external_action)
            reward.append(external_reward)
            termination.append(external_termination)

        if not obs:
            raise ValueError("sample() was asked for an empty batch (no internal or external windows)")

        def merge(chunks):
            # Concatenate along the batch axis and move to the requested device.
            if self.store_on_gpu:
                merged = torch.cat(chunks, dim=0)
            else:
                merged = torch.from_numpy(np.concatenate(chunks, axis=0))
            return merged.to(to_device)

        # Observations are moved as uint8 and converted to float on the target
        # device, which copies four times less data than converting first.
        obs = merge(obs).float() / 255
        obs = rearrange(obs, "B T H W C -> B T C H W")
        action = merge(action)
        reward = merge(reward)
        termination = merge(termination)

        return obs, action, reward, termination

    def append(self, obs, action, reward, termination):
        # obs/nex_obs: torch Tensor
        # action/reward/termination: int or float or bool
        self.last_pointer = (self.last_pointer + 1) % (self.max_length//self.num_envs)
        if self.store_on_gpu:
            self.obs_buffer[self.last_pointer] = torch.from_numpy(obs)
            self.action_buffer[self.last_pointer] = torch.tensor(action, dtype=torch.uint8)
            self.reward_buffer[self.last_pointer] = torch.tensor(reward, dtype=torch.float32)
            self.termination_buffer[self.last_pointer] = torch.from_numpy(np.array(termination))
        else:
            self.obs_buffer[self.last_pointer] = obs
            self.action_buffer[self.last_pointer] = action
            self.reward_buffer[self.last_pointer] = reward
            self.termination_buffer[self.last_pointer] = termination

        if len(self) < self.max_length:
            self.length += 1

    def __len__(self):
        return self.length * self.num_envs

## Utils

### EMA Scalar

In [None]:
class EMAScalar():
    """Exponential moving average of a scalar, starting from 0.0."""

    def __init__(self, decay) -> None:
        self.scalar = 0.0
        self.decay = decay

    def __call__(self, value):
        """Fold ``value`` into the average and return the new average."""
        self.update(value)
        return self.get()

    def update(self, value):
        """Blend ``value`` in with weight ``1 - decay``."""
        kept = self.decay * self.scalar
        added = (1 - self.decay) * value
        self.scalar = kept + added

    def get(self):
        """Return the current average."""
        return self.scalar

### Logger

In [None]:
class Logger():
    """Thin TensorBoard wrapper that keeps a separate step counter per tag."""

    def __init__(self, path) -> None:
        self.writer = SummaryWriter(path, flush_secs=1)
        self.tag_step = {}

    def log(self, tag, value):
        """Write ``value`` under ``tag`` at that tag's next step.

        The first write of a tag uses step 0. The writer method is chosen by
        substring, checked in this order: ``"video"``, ``"images"``,
        ``"hist"``; anything else is logged as a scalar.
        """
        step = self.tag_step.get(tag, -1) + 1
        self.tag_step[tag] = step

        if "video" in tag:
            self.writer.add_video(tag, value, step, fps=15)
            return
        if "images" in tag:
            self.writer.add_images(tag, value, step)
            return
        if "hist" in tag:
            self.writer.add_histogram(tag, value, step)
            return
        self.writer.add_scalar(tag, value, step)

## 5. 学習

### Config

In [None]:
class Config:
    """Hyper-parameters for joint world-model / agent training.

    Every default can be overridden by keyword, e.g. ``Config(BatchSize=32)``.
    """

    def __init__(self, **kwargs):
        """Set the defaults, then apply keyword overrides.

        Raises:
            TypeError: if a keyword does not name an existing setting
                (previously all keywords were silently discarded).
        """
        # basic settings
        self.eval_freq = 1000
        self.seed = 1234
        self.ImageSize = 64
        self.ReplayBufferOnGPU = True

        # joint train agent
        self.SampleMaxSteps = 100_000
        self.BufferMaxLength = 100000
        self.BufferWarmUp = 1024
        self.NumEnvs = 1
        self.BatchSize = 16
        self.DemonstrationBatchSize = 4
        self.BatchLength = 64
        self.ImagineBatchSize = 1024
        self.ImagineDemonstrationBatchSize = 256
        self.ImagineContextLength = 8
        self.ImagineBatchLength = 16
        self.TrainDynamicsEverySteps = 1
        self.TrainAgentEverySteps = 1
        self.UseDemonstration = False
        self.SaveEverySteps = 2500

        # world models
        self.WorldModel_InChannels = 3
        self.WorldModel_TransformerMaxLength = 64
        self.WorldModel_TransformerHiddenDim = 512
        self.WorldModel_TransformerNumLayers = 2
        self.WorldModel_TransformerNumHeads = 8

        # Agent
        self.Agent_NumLayers = 2
        self.Agent_HiddenDim =  512
        self.Agent_Gamma = 0.985
        self.Agent_Lambda = 0.95
        self.Agent_EntropyCoef = 3E-4

        # Apply overrides; unknown names are rejected so typos do not pass silently.
        unknown = sorted(name for name in kwargs if not hasattr(self, name))
        if unknown:
            raise TypeError(f"Unknown Config option(s): {', '.join(unknown)}")
        for name, value in kwargs.items():
            setattr(self, name, value)

### make env

In [None]:
def make_env(seed=None, img_size=64, max_steps=100_000):
    """Create the seeded MsPacman environment with resizing and action repeat.

    Args:
        seed: seed applied to the environment and to both of its spaces.
        img_size: observations are resized to ``(img_size, img_size)``.
        max_steps: forwarded to ``RepeatAction``.

    Returns:
        The environment wrapped in ``ResizeObservation`` and ``RepeatAction``
        (``skip=4``).
    """
    base_env = gym.make("ALE/MsPacman-v5")

    # Fix the seeds: the environment first, then the action and observation spaces.
    base_env.seed(seed)
    for space in (base_env.action_space, base_env.observation_space):
        space.seed(seed)

    resized_env = ResizeObservation(base_env, (img_size, img_size))
    return RepeatAction(env=resized_env, skip=4, max_steps=max_steps)

### Build vec env

In [None]:
# def build_vec_env(env_name, image_size, num_envs, seed):
#     # lambda pitfall refs to: https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7
#     def lambda_generator(env_name, image_size):
#         return lambda: make_env(seed=seed)
#     env_fns = []
#     env_fns = [lambda_generator(env_name, image_size) for i in range(num_envs)]
#     vec_env = gym.vector.AsyncVectorEnv(env_fns=env_fns)
#     return vec_env

### Train World Models Step

In [None]:
def train_world_model_step(replay_buffer: ReplayBuffer, world_model: WorldModel, batch_size, demonstration_batch_size, batch_length, logger):
    """Sample one batch from the replay buffer and train the world model on it."""
    batch_obs, batch_action, batch_reward, batch_termination = replay_buffer.sample(
        batch_size, demonstration_batch_size, batch_length
    )
    world_model.update(batch_obs, batch_action, batch_reward, batch_termination, logger=logger)

### World Model Imagine Data

In [None]:
def world_model_imagine_data(replay_buffer: ReplayBuffer,
                             world_model: WorldModel, agent: ActorCriticAgent,
                             imagine_batch_size, imagine_demonstration_batch_size,
                             imagine_context_length, imagine_batch_length,
                             log_video, logger):
    '''
    Sample context from replay buffer, then imagine data with world model and agent

    Both models are switched to eval mode. The batch size handed to
    ``world_model.imagine_data`` is taken from the sampled context itself, so
    it stays correct when the buffer returns fewer trajectories than
    ``imagine_batch_size + imagine_demonstration_batch_size`` (for example
    when no demonstration buffer is loaded).

    Returns:
        ``(latent, action, None, None, reward_hat, termination_hat)``; the two
        ``None`` entries fill the old-logprob and old-value slots.
    '''
    world_model.eval()
    agent.eval()

    sample_obs, sample_action, sample_reward, sample_termination = replay_buffer.sample(
        imagine_batch_size, imagine_demonstration_batch_size, imagine_context_length)
    # sample_obs: (B=1024, context=8, C, W, H)

    # Use the number of trajectories actually sampled, not the requested sum.
    actual_batch_size = sample_obs.shape[0]

    latent, action, reward_hat, termination_hat = world_model.imagine_data(
        agent, sample_obs, sample_action,
        imagine_batch_size=actual_batch_size,
        imagine_batch_length=imagine_batch_length,
        log_video=log_video,
        logger=logger
    )
    # latent: (B, imagine_length+1, 32*32 + transformer hidden dim)
    # action / reward_hat / termination_hat: (B, imagine_length)
    return latent, action, None, None, reward_hat, termination_hat

### Joint Train World Model Agent

In [None]:
def joint_train_world_model_agent(env_name, max_steps, num_envs, image_size,
                                  replay_buffer: ReplayBuffer,
                                  world_model: WorldModel, agent: ActorCriticAgent,
                                  train_dynamics_every_steps, train_agent_every_steps,
                                  batch_size, demonstration_batch_size, batch_length,
                                  imagine_batch_size, imagine_demonstration_batch_size,
                                  imagine_context_length, imagine_batch_length,
                                  save_every_steps, seed, logger, eval_freq):
    """Interleave environment sampling with world-model and agent training.

    Each of the ``max_steps // num_envs`` iterations
      1. takes one environment step (random until the replay buffer is ready,
         afterwards chosen by the agent from up to 16 steps of context),
      2. stores the transition in the replay buffer,
      3. trains the world model every ``train_dynamics_every_steps`` steps,
      4. trains the agent on imagined rollouts every ``train_agent_every_steps`` steps,
      5. runs ``evaluate`` every ``eval_freq`` steps.

    ``env_name`` is only used in the log tags. ``image_size`` and
    ``save_every_steps`` are accepted but currently unused (video logging and
    checkpoint saving are disabled).
    """
    # create ckpt dir
    # os.makedirs(f"ckpt", exist_ok=True)

    # A single environment is used; the vectorised variant is not needed in
    # the Atari100k setting.
    env = make_env(seed=seed, max_steps=100_000)

    # reset envs and variables
    sum_reward = 0
    current_obs, current_info = env.reset() # (H, W, C)
    context_obs = deque(maxlen=16)
    context_action = deque(maxlen=16)

    # sample and train
    for total_steps in tqdm(range(max_steps//num_envs)):
        # sample part >>>
        if replay_buffer.ready():
            world_model.eval()
            agent.eval()
            with torch.no_grad():
                if not context_action:
                    action = env.action_space.sample()
                else:
                    context_latent = world_model.encode_obs(torch.cat(list(context_obs), dim=1)) # (B, L, n_classes*stoch_dim)
                    model_context_action = np.stack(list(context_action), axis=1) # (B, L)
                    model_context_action = torch.Tensor(model_context_action).cuda()
                    prior_flattened_sample, last_dist_feat = world_model.calc_last_dist_feat(context_latent, model_context_action)
                    action = agent.sample_as_env_action(
                        torch.cat([prior_flattened_sample, last_dist_feat], dim=-1),
                        greedy=False
                    )
            context_obs.append(rearrange(torch.Tensor(current_obs).cuda(), "H W C -> 1 1 C H W")/255)
            context_action.append([action]) # the list adds the batch dimension
        else:
            action = env.action_space.sample()

        obs, reward, done, info = env.step(action)

        # A lost life is stored as a termination as well.
        replay_buffer.append(current_obs, action, reward, np.logical_or(done, info["life_loss"]))

        # Count this step's reward before the episode total is logged and
        # reset, so the terminal reward belongs to the episode that earned it.
        sum_reward += reward

        if done:
            logger.log(f"sample/{env_name}_reward", sum_reward)
            logger.log(f"sample/{env_name}_episode_steps", current_info["episode_frame_number"]//4)  # framskip=4
            logger.log("replay_buffer/length", len(replay_buffer))
            sum_reward = 0
            obs, info = env.reset()

        # update current_obs and current_info
        current_obs = obs
        current_info = info
        # <<< sample part

        # train world model part >>>
        if replay_buffer.ready() and total_steps % (train_dynamics_every_steps//num_envs) == 0:
            train_world_model_step(
                replay_buffer=replay_buffer,
                world_model=world_model,
                batch_size=batch_size,
                demonstration_batch_size=demonstration_batch_size,
                batch_length=batch_length,
                logger=logger
            )
        # <<< train world model part

        # train agent part >>>
        if replay_buffer.ready() and total_steps % (train_agent_every_steps//num_envs) == 0:
            # Video logging of imagined rollouts is disabled.
            log_video = False

            with torch.no_grad():
                imagine_latent, agent_action, agent_logprob, agent_value, imagine_reward, imagine_termination = world_model_imagine_data(
                    replay_buffer=replay_buffer,
                    world_model=world_model,
                    agent=agent,
                    imagine_batch_size=imagine_batch_size,
                    imagine_demonstration_batch_size=imagine_demonstration_batch_size,
                    imagine_context_length=imagine_context_length,
                    imagine_batch_length=imagine_batch_length,
                    log_video=log_video,
                    logger=logger
                )
            # imagine_latent: (B, imagine_length+1, 32*32 + transformer hidden dim)
            # agent_action / imagine_reward / imagine_termination: (B, imagine_length)
            # agent_logprob and agent_value are None

            agent.update(
                latent=imagine_latent,
                action=agent_action,
                old_logprob=agent_logprob,
                old_value=agent_value,
                reward=imagine_reward,
                termination=imagine_termination,
                logger=logger
            )
        # <<< train agent part

        # >>> eval part
        if replay_buffer.ready() and (total_steps+1) % eval_freq == 0:
            evaluate(world_model, agent, logger)
        # <<< eval part

        # save model per episode
        # if total_steps % (save_every_steps//num_envs) == 0:
        #     print(f"Saving model at total steps {total_steps}")
        #     torch.save(world_model.state_dict(), f"ckpt/world_model_{total_steps}.pth")
        #     torch.save(agent.state_dict(), f"ckpt/agent_{total_steps}.pth")


### Build World Model

In [None]:
def build_world_model(config, action_dim):
    """Construct a ``WorldModel`` from ``config`` and move it to CUDA."""
    transformer_kwargs = {
        "transformer_max_length": config.WorldModel_TransformerMaxLength,
        "transformer_hidden_dim": config.WorldModel_TransformerHiddenDim,
        "transformer_num_layers": config.WorldModel_TransformerNumLayers,
        "transformer_num_heads": config.WorldModel_TransformerNumHeads,
    }
    model = WorldModel(
        in_channels=config.WorldModel_InChannels,
        action_dim=action_dim,
        **transformer_kwargs
    )
    return model.cuda()

### Build Agent

In [None]:
def build_agent(conf, action_dim):
    """Construct an ``ActorCriticAgent`` from ``conf`` and move it to CUDA.

    Args:
        conf: configuration object providing the ``Agent_*`` settings and
            ``WorldModel_TransformerHiddenDim``.
        action_dim: number of discrete actions.

    The agent's input width is the flattened latent size (32*32) plus the
    transformer hidden size.
    """
    # Read from the `conf` argument; the body previously referenced the
    # global `config` and ignored the object that was passed in.
    return ActorCriticAgent(
        feat_dim=32*32+conf.WorldModel_TransformerHiddenDim,
        num_layers=conf.Agent_NumLayers,
        hidden_dim=conf.Agent_HiddenDim,
        action_dim=action_dim,
        gamma=conf.Agent_Gamma,
        lambd=conf.Agent_Lambda,
        entropy_coef=conf.Agent_EntropyCoef,
    ).cuda()

### evaluate

In [None]:
def evaluate(world_model, agent, logger):
    """Play one evaluation episode with the greedy policy and log the return.

    A fresh environment is created for the episode. The first action is
    sampled at random because no context exists yet; every later action is
    chosen greedily by ``agent`` from features computed by ``world_model``
    over a sliding window of the most recent observations and actions.

    Both models are switched to eval mode for the rollout and are put back
    into train mode afterwards, even if the rollout raises. The environment
    is closed before returning.

    Args:
        world_model: Model exposing ``encode_obs`` and ``calc_last_dist_feat``.
        agent: Policy exposing ``sample_as_env_action``.
        logger: Object whose ``log(tag, value)`` records the episode return
            under ``"evaluation/reward"``.
    """
    env = make_env(seed=1234, img_size=64, max_steps=None)

    world_model.eval()
    agent.eval()

    try:
        obs, info = env.reset()
        done = False
        total_reward = 0
        # Sliding windows holding at most the 16 latest observations/actions.
        context_obs = deque(maxlen=16)
        context_action = deque(maxlen=16)

        while not done:
            with torch.no_grad():
                if len(context_action) == 0:
                    # No history yet: fall back to a random first action.
                    action = env.action_space.sample()
                else:
                    # Each stored observation is (1, 1, C, H, W); joining on
                    # dim 1 yields one sequence of frames.
                    context_latent = world_model.encode_obs(torch.cat(list(context_obs), dim=1))
                    # Each stored action is a one-element list; flatten the
                    # history into a single row of shape (1, T).
                    model_context_action = torch.Tensor(list(context_action)).cuda().view(1, -1)
                    prior_flattened_sample, last_dist_feat = world_model.calc_last_dist_feat(context_latent, model_context_action)
                    action = agent.sample_as_env_action(
                        torch.cat([prior_flattened_sample, last_dist_feat], dim=-1),
                        greedy=True
                    )

            # Store the observation the action was chosen for, scaled by 1/255.
            context_obs.append(rearrange(torch.Tensor(obs).cuda(), "H W C -> 1 1 C H W")/255)
            context_action.append([action])

            obs, reward, done, info = env.step(action)

            total_reward += reward
    finally:
        # Always hand the models back in training mode and release the env.
        world_model.train()
        agent.train()
        env.close()

    print('Total Reward:', total_reward)
    logger.log("evaluation/reward", total_reward)

### 実際の学習

In [None]:
# Training entry point: builds the config, logger, models and replay buffer,
# then runs joint world-model / agent training.

# Silence all Python warnings for the rest of the session.
import warnings
warnings.filterwarnings('ignore')
# Allow TF32 for CUDA matmuls and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

config = Config()

# Seed the random number generators.
# NOTE(review): set_seed() is called without arguments while config.seed is
# passed to the training function below; confirm both use the same seed.
set_seed()
# TensorBoard writer; the magic below displays the dashboard for the same
# log directory.
logger = Logger(path="./logs")
%tensorboard --logdir './logs'

# Create a throwaway environment only to read the size of the action space.
dummy_env = make_env(seed=0)
action_dim = dummy_env.action_space.n

# Build the world model and the agent.
world_model = build_world_model(config, action_dim)
agent = build_agent(config, action_dim)

# Build the replay buffer.
replay_buffer = ReplayBuffer(
    obs_shape=(config.ImageSize, config.ImageSize, 3),
    num_envs=config.NumEnvs,
    max_length=config.BufferMaxLength,
    warmup_length=config.BufferWarmUp,
    store_on_gpu=config.ReplayBufferOnGPU
)

# Disabled: optional loading of a demonstration trajectory. The snippet refers
# to `colorama` and `args`, which this cell does not define.
# # judge whether to load demonstration trajectory
# if config.UseDemonstration:
#     print(colorama.Fore.MAGENTA + f"loading demonstration trajectory from {args.trajectory_path}" + colorama.Style.RESET_ALL)
#     replay_buffer.load_trajectory(path=args.trajectory_path)

# Run training. The demonstration batch sizes are forced to 0 unless
# config.UseDemonstration is set.
joint_train_world_model_agent(
    env_name='PacMan',
    num_envs=config.NumEnvs,
    max_steps=config.SampleMaxSteps,
    image_size=config.ImageSize,
    replay_buffer=replay_buffer,
    world_model=world_model,
    agent=agent,
    train_dynamics_every_steps=config.TrainDynamicsEverySteps,
    train_agent_every_steps=config.TrainAgentEverySteps,
    batch_size=config.BatchSize,
    demonstration_batch_size=config.DemonstrationBatchSize if config.UseDemonstration else 0,
    batch_length=config.BatchLength,
    imagine_batch_size=config.ImagineBatchSize,
    imagine_demonstration_batch_size=config.ImagineDemonstrationBatchSize if config.UseDemonstration else 0,
    imagine_context_length=config.ImagineContextLength,
    imagine_batch_length=config.ImagineBatchLength,
    save_every_steps=config.SaveEverySteps,
    seed=config.seed,
    logger=logger,
    eval_freq=config.eval_freq
)


## 6. モデルの保存

In [None]:
# # モデルの保存(Google Driveの場合）
# from google.colab import drive
# drive.mount('/content/drive')

# trained_models.save("drive/MyDrive/Colab Notebooks/")

In [None]:
# Save the trained world model and agent under ./SavedModels
# (the directory is created if it does not exist yet).
os.makedirs("./SavedModels", exist_ok=True)
trained_models = TrainedModels(world_model, agent)
trained_models.save("./SavedModels")

## 7. 学習済みパラメータで評価  


In [None]:
# Create the environment (used here to read the size of the action space).
env = make_env()
# Device string handed to trained_models.load() below.
# NOTE(review): build_world_model / build_agent call .cuda() unconditionally,
# so the "cpu" fallback only affects the load call — confirm this is intended.
device = "cuda" if torch.cuda.is_available() else "cpu"
config = Config()

# Load the trained models.
# The commented-out constructors below are disabled and refer to names
# (RSSM, Encoder, Decoder, cfg, ...) that the active code in this cell does
# not use; they appear to be leftovers — TODO confirm and remove.
# rssm = RSSM(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes, action_dim).to(device)
# encoder = Encoder().to(device)
# decoder = Decoder(cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
# reward_model =  RewardModel(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
# discount_model = DiscountModel(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
# actor = Actor(action_dim, cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
# critic = Critic(cfg.mlp_hidden_dim, cfg.rnn_hidden_dim, cfg.state_dim, cfg.num_classes).to(device)
action_dim = env.action_space.n
# Build freshly initialised models, then wrap them so the saved parameters
# can be loaded into them.
world_model = build_world_model(config, action_dim)
agent = build_agent(config, action_dim)

trained_models = TrainedModels(
    world_model,
    agent
)

# Load the parameters from the "SavedModels" directory.
trained_models.load("SavedModels", device)

In [None]:
# 結果を動画で観てみるための関数
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML


def display_video(frames):
    """Show a sequence of frames as an inline JavaScript animation.

    Args:
        frames: Indexable, non-empty sequence of images accepted by
            ``plt.imshow``. The first frame initialises the image; each
            animation step swaps in the frame with the matching index.
    """
    fig = plt.figure(figsize=(8, 8), dpi=50)
    image = plt.imshow(frames[0], cmap="gray")
    plt.axis('off')

    def _draw(step):
        # Replace the displayed pixels and show the step index as the title.
        image.set_data(frames[step])
        plt.title("Step %d" % (step))

    player = animation.FuncAnimation(fig, _draw, frames=len(frames), interval=50)
    display(HTML(player.to_jshtml(default_mode='once')))
    # Close the figure so the static image is not rendered a second time.
    plt.close()

In [None]:
# Play one episode with the trained models using the greedy policy.
# `frames` and `actions` record the rollout; `env` is kept as a global too.
env = make_env(seed=1234, img_size=64, max_steps=None)

world_model.eval()
agent.eval()

obs, info = env.reset()
done = False
total_reward = 0
frames = []
actions = []
# Sliding windows holding at most the 16 latest observations/actions.
context_obs = deque(maxlen=16)
context_action = deque(maxlen=16)

while not done:
    with torch.no_grad():
        if len(context_action) == 0:
            # No history yet: fall back to a random first action.
            action = env.action_space.sample()
        else:
            # Each stored observation is (1, 1, C, H, W); joining on dim 1
            # yields one sequence of frames.
            context_latent = world_model.encode_obs(torch.cat(list(context_obs), dim=1))
            # Each stored action is a one-element list; flatten the history
            # into a single row of shape (1, T).
            model_context_action = torch.Tensor(list(context_action)).cuda().view(1, -1)
            prior_flattened_sample, last_dist_feat = world_model.calc_last_dist_feat(context_latent, model_context_action)
            action = agent.sample_as_env_action(
                torch.cat([prior_flattened_sample, last_dist_feat], dim=-1),
                greedy=True
            )

    # Store the observation the action was chosen for, scaled by 1/255.
    context_obs.append(rearrange(torch.Tensor(obs).cuda(), "H W C -> 1 1 C H W")/255)
    context_action.append([action])

    obs, reward, done, info = env.step(action)

    total_reward += reward
    frames.append(obs)
    actions.append(action)

print('Total Reward:', total_reward)

In [None]:
# Play back the recorded frames as an inline animation.
display_video(frames)

今回，評価を行う際のrepeat actionは1に設定しています．  
そのため，repeat actionをそれ以外に設定している場合，repeat actionの分だけ繰り返した行動を提出する形にしています．

In [None]:
# Expand the recorded actions to match the action-repeat setting: every
# action is written `env._skip` times in a row.
repeat = env._skip
recorded = np.array(actions)
submission_actions = np.zeros(len(actions) * repeat)
for offset in range(repeat):
    # Every `repeat`-th slot starting at `offset` receives the full sequence.
    submission_actions[offset::repeat] = recorded

# np.save("drive/MyDrive/submission", submission_actions)
np.save("./submission", submission_actions)