### Learning RSSM

The RSSM (recurrent state-space model) is the core of what makes Dreamer a world model. It compresses past experience into a hidden state and predicts future states given a sequence of actions. In short, it tells the agent what the world might look like after taking those actions.

It is a state-space model: it represents the world as a *latent state* s_t, a smaller, more abstract representation of the observation image. The encoder of the RSSM maps each image into this more compact encoding.

The recurrent structure updates this state over time, taking the history of past actions into account.

It also has a predictive ability. From the latent state, it predicts the next observation (e.g., by decoding the latent representation into an image), the reward, and the next latent state without needing an observation (this is where its *imagination* rollout happens).
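The imagination rollout can be sketched as a loop that never consults an observation: the recurrent state is updated from the previous latent and action, and the next latent is sampled from the prior alone. This is a minimal sketch under stated assumptions, not Dreamer's implementation: the sizes, the `imagine` helper, the single linear `prior_head`, and the random placeholder policy are all illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values taken from the Dreamer papers)
B, G, C = 2, 4, 4                    # batch, latent groups, classes per group
size_action, size_hidden = 3, 32

gru        = nn.GRUCell(G * C + size_action, size_hidden)
prior_head = nn.Linear(size_hidden, G * C)

def imagine(h, z, horizon):
    """Roll the latent state forward using only the prior (no observations)."""
    states = []
    for _ in range(horizon):
        # placeholder policy: a random one-hot action (a trained actor would go here)
        a = F.one_hot(torch.randint(size_action, (B,)), size_action).float()
        h = gru(torch.cat([z, a], dim=-1), h)          # deterministic memory update
        logits = prior_head(h).view(B, G, C)           # prior over the next latent
        z = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1).view(B, G * C)
        states.append((h, z))
    return states

h0 = torch.zeros(B, size_hidden)
z0 = torch.zeros(B, G * C)
trajectory = imagine(h0, z0, horizon=5)
print(len(trajectory), trajectory[-1][0].shape, trajectory[-1][1].shape)
```

Each element of `trajectory` is a `(hidden, stochastic)` pair; in a full agent, reward and value heads would read these imagined states.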

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# manual KL(posterior || prior) over categorical groups; worked out in KL_divergence.ipynb
def _kl(posterior_logits, prior_logits):
    q_log = F.log_softmax(posterior_logits, dim=-1)       # (B, G, C)
    p_log = F.log_softmax(prior_logits, dim=-1)           # (B, G, C)
    q     = q_log.exp()
    kl    = (q * (q_log - p_log)).sum(dim=-1).sum(dim=-1) # (B)
    return kl
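As a sanity check, the manual KL can be compared against `torch.distributions`. The sketch below repeats the same computation under the hypothetical name `manual_kl` so that it runs on its own; the random logits are placeholders, not data from the notebook.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, kl_divergence

torch.manual_seed(0)

def manual_kl(posterior_logits, prior_logits):
    # same computation as _kl, repeated so this cell is self-contained
    q_log = F.log_softmax(posterior_logits, dim=-1)
    p_log = F.log_softmax(prior_logits, dim=-1)
    return (q_log.exp() * (q_log - p_log)).sum(dim=-1).sum(dim=-1)

posterior_logits = torch.randn(4, 4, 4)   # (B, G, C)
prior_logits     = torch.randn(4, 4, 4)

# Categorical treats the last dim as classes, so this yields one KL per group: (B, G)
reference = kl_divergence(
    Categorical(logits=posterior_logits),
    Categorical(logits=prior_logits),
).sum(dim=-1)                             # sum the per-group KLs -> (B,)

print(torch.allclose(manual_kl(posterior_logits, prior_logits), reference, atol=1e-5))
```

Summing over groups treats the G categorical variables as independent, which is why the per-group KLs simply add.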

In [None]:
torch.manual_seed(44)

batch_size  = 4
num_groups  = 4
num_classes = 4
X = torch.randn(batch_size, num_groups, num_classes)
B, G, C = X.shape

# action size depends upon the kind of input
# for now, just set to 4
size_action        = 4

# (1) choose hyperparameters: obs embedding, memory size, stochastic size
size_obs_embed     = 64
size_hidden        = 128
size_stochastic    = num_groups * num_classes # one logit per class in each group, flattened to a single vector
# (2) initialize the previous action and previous state (deterministic and stochastic)
    # the RSSM cell takes 4 inputs when the encoder provides an obs embedding:
        # the previous action, the hidden state, the stochastic state, and the obs embedding
    # without an obs embedding, it relies only on the 'prior' network (imagination step)

# create the initial state for hidden and stochastic
# these are initialized at the beginning of the RSSM
prev_hidden     = torch.zeros(batch_size, size_hidden)
prev_stochastic = torch.zeros(batch_size, size_stochastic)

# need the prev action and an obs embedding
# the action depends on what we're modelling and the obs embedding depends on the encoder
# the previous action is initialized to zeros: a neutral value that doesn't bias the model
prev_action     = torch.zeros(batch_size, size_action)
# this normally comes from the encoder; without one for now, just use random input
obs_embed       = torch.randn(batch_size, size_obs_embed)

# (3) feed the inputs into the cell:
    # (a) update the GRU memory by feeding in the prev action, prev hidden state, and the prev stochastic
        # calculates the update gate (z_t), reset gate (r_t), and candidate (potential) hidden state (h_t_tilde)
        # uses these calculations to return the current hidden state h_t
            # the hidden state is a weighted sum of how much info from the past and the current states is passed through
            # z_t adds the past and current info and passes the sum through a nonlinearity (sigmoid)
                # in the convention described here, a z_t closer to 1 says "pass through all the current info and forget the past"
                # a z_t closer to 0 says "keep all the info from the past and don't pass through the current info"
                # note: nn.GRUCell uses the flipped convention, h_t = (1 - z_t) * h_t_tilde + z_t * h_prev, so there a z_t near 1 keeps the past
            # h_t_tilde multiplies r_t with the past info, adds the current info, and applies tanh
                # r_t being 1 says "use all the past info when calculating the new candidate state"
                    # the reset gate determines how much of the past to take into consideration for the new candidate state
                    # when it is 0, no past info is taken into consideration, only the current input is passed through
            # h_t multiplies z_t and 1 - z_t against the candidate hidden state and the previous state, respectively
                # goal: decide how much of the memory and current states are a part of the new hidden state
        # so in this context, the GRU uses the prev action, the prev uncertainty (the stochastic), and the prev memory to estimate what the next outcome is likely to be.
            # i.e., it means "given what I just did and what I think I am seeing, how should I update my understanding of the situation so I can predict what comes next?"

# the previous action and previous stochastic, concatenated along the last dimension 
gru       = nn.GRUCell(size_stochastic + size_action, size_hidden) 
gru_input = torch.cat([prev_stochastic, prev_action], dim=-1)
hidden_t  = gru(gru_input, prev_hidden)

    # (b) calculate the new stochastic with the PRIOR network
        # uses the current hidden state calculated by the GRU to produce the prior logits (discrete representation) as (B, G, C)

prior_head   = nn.Linear(size_hidden, size_stochastic)
prior_logits = prior_head(hidden_t)
prior_logits = prior_logits.view(B, G, C)

    # (c) if there is an obs, calc the new stochastic with the POSTERIOR network
            # calculate the logits from the posterior net (using hidden state and the observation)
            # get the one hot vectors for each group using gumbel_softmax, which is the new stochastic
                # note: the randomness is injected into the logits in F.gumbel_softmax by adding gumbel noise to the logits
            # calculate KL for world model loss
                # measures how much the prior network's distribution differs from the posterior network's distribution
                # goal: want to adjust the prior network's weights to align more closely with posterior 
                # reason: posterior utilizes the observation in training but the prior doesn't
                # so we want to adjust the prior to be closer to the posterior
        # if no obs, then use the stochastic from the PRIOR network (imagination run)
            # get logits and one hot vectors only using prior
            # return the new stochastic from the onehots 



postr_head = nn.Linear(size_hidden + size_obs_embed, size_stochastic)
if obs_embed is not None:
    # get postr network logits
    joined       = torch.cat([hidden_t, obs_embed], dim=-1)
    postr_logits = postr_head(joined)
    postr_logits = postr_logits.view(B, G, C)

    # get stochastic using gumbel
    # reshape to (B * G, C) so that each row (one group) is sampled independently
    postr_logits_gumbel = postr_logits.view(B * G, C)
    onehots = F.gumbel_softmax(postr_logits_gumbel, tau=0.8, dim=-1, hard=True) # tau is gradually decayed per training step, but leave it for now
    stochastic_t = onehots.view(B, G * C)

    # get kl for world model loss
    kl = _kl(postr_logits, prior_logits)

else:
    prior_logits_gumbel = prior_logits.view(B * G, C)
    onehots = F.gumbel_softmax(prior_logits_gumbel, tau=0.8, dim=-1, hard=True)
    stochastic_t = onehots.view(B, G * C)

    # (d) return the new hidden state and the sampled stochastic state

new_state = (hidden_t, stochastic_t)


In [18]:
new_state[0].shape, new_state[1].shape

(torch.Size([4, 128]), torch.Size([4, 16]))
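The gate arithmetic outlined in step (3a) can be checked by hand against `nn.GRUCell`. The sketch below rebuilds the update from the cell's own weights, following the formulas in the PyTorch documentation; the sizes are arbitrary. Note that in PyTorch's convention the update gate multiplies the previous state, so a `z` near 1 preserves the past, whereas some texts write the roles the other way around.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

size_input, size_hidden = 6, 5          # arbitrary illustrative sizes
cell = nn.GRUCell(size_input, size_hidden)

x      = torch.randn(3, size_input)     # current input (e.g. [stochastic, action])
h_prev = torch.randn(3, size_hidden)    # previous hidden state

# PyTorch stacks the reset, update, and candidate weights along dim 0
gi = x @ cell.weight_ih.T + cell.bias_ih         # input contributions,  (3, 3 * hidden)
gh = h_prev @ cell.weight_hh.T + cell.bias_hh    # hidden contributions, (3, 3 * hidden)
i_r, i_z, i_n = gi.chunk(3, dim=-1)
h_r, h_z, h_n = gh.chunk(3, dim=-1)

r = torch.sigmoid(i_r + h_r)            # reset gate
z = torch.sigmoid(i_z + h_z)            # update gate
n = torch.tanh(i_n + r * h_n)           # candidate state; r scales the past contribution
h_manual = (1 - z) * n + z * h_prev     # blend of candidate and previous state

print(torch.allclose(h_manual, cell(x, h_prev), atol=1e-5))
```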

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# DreamerV3 separates the RSSM into three parts: (1) Encoder, (2) Sequence Model, (3) Dynamics Predictor
# Sequence Model: 
    # "a sequence model with recurrent state ht predicts the sequence of these representations given past actions at−1."
    # h_t = f_ϕ(h_{t-1}, z_{t-1}, a_{t-1})
    # inputs the prev hidden state, prev stochastic state, prev action
    # outputs the hidden state
    # gated recurrent unit does this operation of producing the hidden state
    # "The sequence model is a GRU with block-diagonal recurrent weights of 8 blocks to allow for a large number of memory units without quadratic increase in parameters and FLOPs"
    # "The input to the GRU at each time step is a linear embedding of the sampled latent zt, of the action at, and of the recurrent state to allow mixing between blocks."

class SequenceModel(nn.Module):
    
    def __init__(self, size_stochastic, size_action, size_hidden, embed_dim, num_blocks=8):
        super().__init__()

        assert size_hidden % num_blocks == 0


        self.embed_dim       = embed_dim
        self.block_size      = size_hidden // num_blocks
        self.size_stochastic = size_stochastic
        self.size_action     = size_action
        self.size_hidden     = size_hidden


        self.embed_z = nn.Linear(size_stochastic, embed_dim)
        self.embed_a = nn.Linear(size_action,     embed_dim)
        self.embed_h = nn.Linear(size_hidden,     embed_dim)

        self.gru_blocks = nn.ModuleList([
            nn.GRUCell(embed_dim, self.block_size) for _ in range(num_blocks)
        ])


    def forward(self, prev_hidden, prev_stochastic, prev_action):
        x_t = self.embed_a(prev_action) + self.embed_h(prev_hidden) + self.embed_z(prev_stochastic)
        x_t = F.elu(x_t) # nonlinearity on the summed embeddings

        h_slices = []
        for i, cell in enumerate(self.gru_blocks):
            start    = i * self.block_size
            end      = (i + 1) * self.block_size
            h_prev_i = prev_hidden[:, start:end]
            h_i      = cell(x_t, h_prev_i)
            h_slices.append(h_i)
        h_t = torch.cat(h_slices, dim=-1)
        
        return h_t


# Dynamics Predictor:
    # ẑ_t ∼ p_ϕ(ẑ_t | h_t)
    # given the current hidden state, outputs the distribution of the stochastic latent
        # prior network gives the prediction based on no observational input

class DynamicsPredictor(nn.Module):
    
    def __init__(self, size_hidden, num_groups, num_classes, tau_init=1.0):
        super().__init__()

        self.G = num_groups
        self.C = num_classes
        self.size_stochastic = num_groups * num_classes
        self.tau = tau_init

        self.prior = nn.Linear(size_hidden, self.size_stochastic)

    def dist(self, h_t): 
        B = h_t.size(0)
        logits = self.prior(h_t)
        logits = logits.view(B, self.G, self.C)
        return logits
    
    def sample(self, logits, hard=True):
        B = logits.size(0) # take the batch size from the input, not from a notebook global
        onehots = F.gumbel_softmax(logits, hard=hard, tau=self.tau, dim=-1)
        z_t = onehots.view(B, self.G * self.C)
        return logits, z_t

# Encoder:
    # "an encoder maps sensory inputs xt to stochastic representations zt."
    # z_t ∼ q_ϕ(z_t | h_t, x_t)
    # given the hidden state and observational input, output the distribution of the stochastic latent
        # posterior network gives this output

class Encoder(nn.Module):
    
    def __init__(self, size_obs_embed, size_hidden, num_groups, num_classes, tau_init=1.0):
        super().__init__()
        self.G = num_groups
        self.C = num_classes
        self.size_stochastic = num_groups * num_classes
        self.tau = tau_init

        self.postr = nn.Linear(size_hidden + size_obs_embed, self.size_stochastic)

    def dist(self, x_t, h_t):
        B = x_t.size(0)
        joined = torch.cat([h_t, x_t], dim=-1)
        logits = self.postr(joined)
        logits = logits.view(B, self.G, self.C)
        return logits

    def forward(self, logits, hard=True):
        B = logits.size(0) # take the batch size from the input, not from a notebook global
        onehots = F.gumbel_softmax(logits, tau=self.tau, hard=hard, dim=-1)
        z_t = onehots.view(B, self.G * self.C)
        return z_t
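The quoted motivation for the block-diagonal GRU (many memory units without a quadratic increase in parameters) can be illustrated by counting parameters. This sketch compares one full `nn.GRUCell` with a list of smaller cells that each own one slice of the hidden state; `count_params` is a hypothetical helper and the sizes are illustrative assumptions.

```python
import torch.nn as nn

def count_params(module):
    """Total number of learnable scalars in a module."""
    return sum(p.numel() for p in module.parameters())

embed_dim, size_hidden, num_blocks = 64, 512, 8   # illustrative sizes (assumptions)
block_size = size_hidden // num_blocks

# one dense GRU over the whole hidden state
full = nn.GRUCell(embed_dim, size_hidden)
# eight independent GRUs, each updating one slice of the hidden state
blocks = nn.ModuleList([nn.GRUCell(embed_dim, block_size) for _ in range(num_blocks)])

print("full GRU   :", count_params(full))
print("block GRUs :", count_params(blocks))
```

The recurrent weight matrix of the full cell has `3 * size_hidden * size_hidden` entries, while the blocks together have `3 * size_hidden * block_size`, a reduction by a factor of `num_blocks`. The cost is that blocks cannot exchange information through the recurrent weights, which is why the hidden state is also embedded into the shared GRU input.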

