In [2]:
# default_exp models

# Models

> Encoder networks for `H` and `X`, and the log-space message-passing routines ($l_n$, $\psi_n$, $m_n$).

In [3]:
#hide
from nbdev.showdoc import *

In [4]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F

In [5]:
# Sanity check: is a CUDA device visible to PyTorch in this kernel?
torch.cuda.is_available()

True

# H 

### Encoder

In [14]:
#export
class encoder_h(nn.Module):
    """Gaussian encoder: maps an input ``y`` to a diagonal ``Normal`` over ``H_dim`` values.

    A two-layer tanh MLP produces a hidden state that feeds two heads, one
    for the mean and one for the standard deviation.

    Args:
        Y_dim: number of input features (last dimension of ``y``).
        H_dim: number of output dimensions of the returned distribution.
        hidden_dim: width of the two hidden layers.
    """

    # Floor added to the predicted standard deviation. Softplus is positive in
    # exact arithmetic but can underflow to 0 in floating point, and
    # ``dist.Normal`` needs a strictly positive scale.
    MIN_STD = 1e-6

    def __init__(self, Y_dim, H_dim, hidden_dim = 16):
        super().__init__()
        self.make_encoder(Y_dim, H_dim, hidden_dim)

    def make_encoder(self, Y_dim, H_dim, hidden_dim):
        """Create the shared trunk (``net``) and the ``mu`` / ``std`` heads."""
        self.net = nn.Sequential(
            nn.Linear(Y_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        )
        self.mu = nn.Linear(hidden_dim, H_dim)
        # Softplus instead of ReLU: ReLU emits exact zeros (and zero gradient)
        # for negative pre-activations, which is not a valid Normal scale.
        # The Sequential layout is kept so parameter names stay `std.0.*`.
        self.std = nn.Sequential(nn.Linear(hidden_dim, H_dim), nn.Softplus())

    def forward(self, y):
        """Return ``dist.Normal(mu, std)`` computed from ``y``.

        ``y`` must have ``Y_dim`` as its last dimension; the distribution's
        parameters have the same leading dimensions and ``H_dim`` as the last.
        """
        hidden_state = self.net(y)
        mu = self.mu(hidden_state)
        std = self.std(hidden_state) + self.MIN_STD
        return dist.Normal(mu, std)

# X

### Encoder

In [13]:
#export
class encoder_x(nn.Module):
    """Sequence encoder: a bidirectional LSTM followed by Gaussian heads.

    The scalar sequence ``i_seq`` and the vector sequence ``h_seq`` are
    concatenated per time step, passed through a bi-LSTM, and the LSTM outputs
    are mapped to the mean and standard deviation of a diagonal ``Normal``
    with ``X_dim`` values per (retained) time step.

    Args:
        I_dim: feature size contributed by ``i_seq`` to the LSTM input.
            ``forward`` appends exactly one feature per step for ``i_seq``,
            so the LSTM ``input_size`` (``I_dim + H_dim``) only matches when
            ``I_dim == 1``.
        H_dim: feature size of ``h_seq``.
        X_dim: number of output dimensions per time step.
        hidden_size: LSTM hidden size per direction.
        inducing_point_stride: if not ``None``, only every
            ``inducing_point_stride``-th time step (starting at 0) is kept
            before the output heads are applied.
    """

    # Floor added to the predicted standard deviation. Softplus is positive in
    # exact arithmetic but can underflow to 0 in floating point, and
    # ``dist.Normal`` needs a strictly positive scale.
    MIN_SIGMA = 1e-6

    def __init__(self, I_dim, H_dim, X_dim, hidden_size=16, inducing_point_stride=None):
        super().__init__()
        self.inducing_point_stride = inducing_point_stride
        self.hidden_size = hidden_size
        self.make_network(I_dim, H_dim, X_dim, hidden_size)

    def make_network(self, I_dim, H_dim, X_dim, hidden_size):
        """Create the bi-LSTM and the ``mu`` / ``sigma`` output heads."""
        self.bilstm = nn.LSTM(input_size=I_dim+H_dim, hidden_size=hidden_size, batch_first=True, bidirectional=True)
        # Both directions are concatenated, hence 2*hidden_size input features.
        self.mu = nn.Linear(2*hidden_size, X_dim)
        # Softplus instead of ReLU: ReLU emits exact zeros (and zero gradient)
        # for negative pre-activations, which is not a valid Normal scale.
        # The Sequential layout is kept so parameter names stay `sigma.0.*`.
        self.sigma = nn.Sequential(nn.Linear(2*hidden_size, X_dim), nn.Softplus())

    def forward(self, i_seq, h_seq):
        """
        i_seq: shape = (BS, T)
        h_seq: shape = (BS, T, H_dim)

        Returns ``dist.Normal(mu, sigma)`` whose parameters have shape
        (BS, T', X_dim), where T' = T without a stride and
        ceil(T / inducing_point_stride) with one.
        """
        # Check the rank before anything else so a wrongly shaped input fails
        # here with a clear message.
        assert i_seq.dim() == 2, "i_seq must have shape (BS, T)"
        lstm_input = torch.cat([i_seq.unsqueeze(-1), h_seq], dim=-1)
        hidden, _ = self.bilstm(lstm_input) #shape(hidden) = (BS, T, 2*hidden_size)

        if self.inducing_point_stride is not None:
            # Subsample time steps once and share the result between both heads.
            hidden = hidden[:, ::self.inducing_point_stride]

        mu = self.mu(hidden)
        sigma = self.sigma(hidden) + self.MIN_SIGMA
        return dist.Normal(mu, sigma)

# Message passing routines

### $l_n$

In [7]:
#export
def l_n_vectorized(I, X, H):
    """
    Combine the log-likelihood terms of I and H for every value z in 0..B-1.

    I: shape=(BS,S)
    X: shape=(BS,S)
    H: shape=(BS,S,H_dim)

    Output:
    out: shape=(BS,S,B,1)

    Relies on the module-level names `B`, `_decoder_i` and `_decoder_h`,
    which are not defined in this cell and must exist before the call.
    """
    # Z enumerates 0..B-1 along a new trailing axis. It is created on X's
    # device so the function also works when the inputs live on the GPU.
    Z = torch.arange(B, device=X.device).expand(*X.shape, B)
    # Repeat X and I along the same trailing axis so all three line up.
    X = X.unsqueeze(-1).expand(*X.shape, B)
    I = I.unsqueeze(-1).expand(*I.shape, B)
    assert X.shape == I.shape == Z.shape
    ll_i = _decoder_i._log_likelihood(I, X, Z) #shape = (BS,S,B)
    ll_h = _decoder_h._log_likelihood(H) #shape=(BS,S,B,H_dim)
    # Sum the H term over its last axis, then add a trailing singleton axis.
    out = (ll_i + ll_h.sum(-1)).unsqueeze(-1) #shape = (BS,S,B,1)
    return out

### $\psi_n$

In [10]:
#export
def psi_n_vectorized(X):
    """
    Row-normalised log-potentials built from `_decoder_z.P` and the transformed X.

    X: shape=(BS,S)

    Output: shape=(BS,S,B,B); logsumexp over the last axis is 0 for every row.

    Relies on the module-level `_decoder_z`, which is not defined in this cell.
    """
    # (BS,S,B) -> (BS,S,1,B): a 'row vector' that is added to every row of P.
    x_row = _decoder_z.transform_x(X).unsqueeze(-2)

    base = _decoder_z.P
    # (1,1,B,B) + (BS,S,1,B) broadcasts to (BS,S,B,B).
    logits = base.view(1, 1, *base.shape) + x_row

    # Subtract the per-row log-normaliser, shape (BS,S,B,1).
    return logits - torch.logsumexp(logits, dim=-1, keepdim=True)

### $m_n$

In [11]:
#export
def compute_message(psi_matrix, l_vector, prev_message_vector):
    """
    One log-space message update.

    psi_matrix: shape=(..., B,B)
    l_vector: shape=(..., B,1)
    prev_message_vector: shape=(..., B,1)

    Output:
    next_message_vector: shape=(..., B,1)
    """
    # The two column vectors broadcast across the columns of psi_matrix.
    summed = psi_matrix + l_vector + prev_message_vector #shape=(..., B, B)
    # Reduce over the row axis, leaving (..., B), then restore the trailing
    # singleton axis so the result is again a 'column vector'.
    reduced = torch.logsumexp(summed, dim=-2, keepdim=False)
    return reduced.unsqueeze(-1)

In [12]:
# Convert the project's notebooks into their library modules (the
# "Converted ..." lines below are this call's output).
# Only the one name used here is imported, rather than star-importing nbdev.export.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_simulations.ipynb.
Converted 02_models.ipynb.
Converted index.ipynb.
