In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Normal
from torch.distributions import kl_divergence
import numpy as np
import time

In [2]:
torch.__version__

'1.0.1'

In [3]:
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

True


In [4]:
def create_synthetic_data(time_step = 10, num_sample = 16, marker_dim=3):
    marker = torch.randn(time_step, num_sample, marker_dim).to(device)
#     points_ = np.random.rand(time_step, num_sample) * 1.
#     cum_sum_points =  np.cumsum(points_, axis = 0)
#     t = np.stack([cum_sum_points, points_], axis = 2)
#     x, t  = marker.tolist(), t.tolist()
#     x = torch.tensor(x)
#     t = torch.tensor(t)
#     data = {'x':x, 't': t}
    data = marker
    return data

In [5]:
class DMM(nn.Module):
    def __init__(self, marker_type='real', marker_dim=3, latent_dim=20):
        super().__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus
        
        self.marker_type = marker_type
        self.latent_dim = latent_dim
        self.inference_hidden_dim = latent_dim
        
        self._inference_net(x_dim=marker_dim, z_dim=latent_dim, inference_hidden_dim=latent_dim)
        self._gated_transition_net(latent_dim, latent_dim)
        self._emission_net(z_dim=latent_dim, x_dim=marker_dim, emission_hidden_dim=latent_dim)

    def _gated_transition_net(self, z_dim, transition_dim):
        # mean
        self.gating_unit_hidden = nn.Sequential(
            nn.Linear(z_dim, transition_dim),
            nn.ReLU()
        )
        self.gating_unit_for_proposed_mean = nn.Sequential(
            nn.Linear(transition_dim, z_dim),
            nn.Sigmoid()
        )
        self.nonlinear_proposed_mean = nn.Linear(transition_dim, z_dim)
        
        self.linear_mean = nn.Linear(z_dim, z_dim)
        
        self.linear_mean.weight.data = torch.eye(z_dim)
        self.linear_mean.bias.data = torch.zeros(z_dim)
        
        # variance
        self.z_next_var_net = nn.Sequential(
            nn.Linear(z_dim, z_dim),
            nn.Softplus()
        )
    def _transition(self, z_prev):
        """
            This is the function that computes the p_theta (z_t | z_{t-1})
            Input:
            z_prev: Tensor of shape TxBSxlatent_dim (if real)
        """
        # Linear mean (linear function of previous state)
        linear_proposed_mean = self.linear_mean(z_prev)
        
        # Nonlinear (proposed) mean using gates
        hidden = self.gating_unit_hidden(z_prev)
        gating_unit = self.gating_unit_for_proposed_mean(hidden)
        proposed_mean = self.nonlinear_proposed_mean(hidden)
        
        # Combine nonlinear and linear options of means to choose
        z_next_mean = (1 - gating_unit) * linear_proposed_mean \
                        + gating_unit*proposed_mean
            
        z_next_var_int = self.relu(proposed_mean)
        z_next_var = self.z_next_var_net(z_next_var_int)
        
#         z_next_dist = Normal(z_next_mean, z_next_var.sqrt())
#         z_next = z_next_dist.sample()
        return z_next_mean, z_next_var

    
    def _emission_net(self, z_dim, x_dim, emission_hidden_dim):
        self.emission_net = nn.Sequential(
            nn.Linear(z_dim, emission_hidden_dim),
            nn.ReLU(),
            nn.Linear(emission_hidden_dim, emission_hidden_dim),
            nn.ReLU()
        )
        if self.marker_type == 'real':
            self.x_mean = nn.Linear(emission_hidden_dim, x_dim)
            self.x_var = nn.Sequential(
                nn.Linear(emission_hidden_dim, x_dim),
                nn.Softplus()
            )
        elif self.marker_type == 'binary':
            pass
    
    def _emission(self, z_seq):
        """
            Input:
            z_t: Tensor of shape TxBSxlatent_dim (if real)
        """
        if self.marker_type == 'real':
            x_mean = self.x_mean(z_seq)
            x_var = self.x_var(z_seq)
            x_dist = Normal(x_mean, x_var.sqrt())    
        elif self.marker_type == 'binary':
            pass
#         x_sample = x_dist.sample()
        return x_dist
    
    
    
    def _combiner_fn(self, z_prev, rnn_state_right):
        hidden_combined = (self.combiner_net(z_prev) + rnn_state_right)/3
        
        return hidden_combined
    
    def _inference_net(self, x_dim, z_dim, inference_hidden_dim):
        
        self.backward_rnn = nn.GRUCell(x_dim, inference_hidden_dim)
        self.combiner_net = nn.Sequential(
            nn.Linear(z_dim, inference_hidden_dim),
            nn.Tanh()
        )
        self.posterior_mean_net = nn.Linear(inference_hidden_dim, z_dim)
        self.posterior_var_net = nn.Sequential(
            nn.Linear(inference_hidden_dim, z_dim),
            nn.Softplus()
        )
        
    def _inference(self, x_seq):
        """
            Input:
            x_seq: Tensor of shape TxBSxmarker_dim (if real)
        """
        # Generate z_0. Can choose to learn it as a parameter or fix it
        # Recurrently, generate h_t
        # Recurrently, generate z_t using z_{t-1} and h_t
        T, BS, _ = x_seq.shape
        z_dim = self.latent_dim
        z_0 = nn.Parameter(torch.zeros(BS, z_dim)).to(device)
        h_t = torch.zeros(BS, self.inference_hidden_dim).to(device)
        z_seq = [z_0]
        z_means = []
        z_vars = []
        
        z_prev = z_0
        # x_t[t] will be BS x marker_dim
        # go backward, starting from T --> 1
        for t in range(T-1,-1,-1):
            h_t_right = self.backward_rnn(x_seq[t], h_t)
            h_combined = self._combiner_fn(z_prev, h_t_right)
            posterior_z_t_mean = self.posterior_mean_net(h_combined)
            posterior_z_t_var = self.posterior_var_net(h_combined)
            epsilon = torch.randn_like(posterior_z_t_mean)
            posterior_z_t_sample = posterior_z_t_mean + posterior_z_t_var.sqrt()*epsilon

            z_means.append(posterior_z_t_mean)
            z_vars.append(posterior_z_t_var)
            z_seq.append(posterior_z_t_sample)
            z_prev = posterior_z_t_sample
        
        z_seq = torch.stack(z_seq, dim=0)
        z_means = torch.stack(z_means, dim=0)
        z_vars = torch.stack(z_vars, dim=0)
        return z_seq, z_means, z_vars
    
    def forward(self, x_seq):
        # Get a sampled sequence (z_0 --> z_T) and
        # the means and the vars for
        # q_phi(z_seq | x_seq) = q_phi(z_t | z_{t-1}, x_seq)
        posterior_seq_sample, posterior_means, posterior_vars = self._inference(x_seq)
        
        # Get means and vars for p_theta (z) = p_theta (z_t | z_{t-1})
        prior_means, prior_vars = self._transition(posterior_seq_sample[:-1])
        
        # Get p_theta (z_t | z_{t-1})
        prior_dist = Normal(prior_means, prior_vars.sqrt())
        
        # Convert means and vars to torch distributions
        posterior_dist = Normal(posterior_means, posterior_vars.sqrt())
        
        # Get p_theta(x_t | z_t)
        reconstructed_x_dist = self._emission(posterior_seq_sample[:-1])
        
        # log p_theta (x_seq | z) = sum( log p_theta (x_t | z_t))
        log_likelihood = -reconstructed_x_dist.log_prob(x_seq).sum()
        
        # KL divergence between the variational posterior distribution and generative prior
        # KL (q_phi (z_t | z_{t-1}, x_seq) || p_theta (z_t | z_{t-1}))
        kl = kl_divergence(posterior_dist, prior_dist).sum()
        
        return log_likelihood, kl

In [6]:
def train(model, epoch, data, optimizer, batch_size, val_data):
    start = time.time()
    model.train()
    train_loss = 0
    n_train, n_val = len(data[0]), 1.

    optimizer.zero_grad()
    idxs = np.random.permutation(n_train)
    for i in range(0, n_train, batch_size):
        ll, kl = model(data)
        # want to maximize ll - kl, therefore minimize the negative of that
        loss = kl - ll
        loss.backward()
        train_loss += loss.item()
    optimizer.step()
    end = time.time()
    
    val_loss = 0.
    if val_data is not None:
        n_val = len(val_data)
        with torch.no_grad():
            val_ll, val_kl = model(val_data)
            loss = val_kl - val_ll
            val_loss += loss.item()
    
    print("Epoch: {}, NLL Loss: {}, Val Loss: {}, Time took: {}".format(epoch, train_loss/n_train,\
                                                                        val_loss/n_val, (end-start)))

In [7]:
def trainer(model: nn.Module, data: torch.Tensor = None, val_data: torch.Tensor = None, lr = 1e-3, epoch=100, batch_size=64):
    optimizer = Adam(model.parameters(), lr=lr)
    
    for epoch_number in range(epoch):
        train(model, epoch_number, data, optimizer, batch_size, val_data)
    return model
    

In [8]:
def main():
    model = DMM().to(device)
    data = create_synthetic_data(num_sample=5000, time_step=25)
    val_data = create_synthetic_data(num_sample=150)
    trainer(model, data=data, val_data=val_data)

In [9]:
main()

Epoch: 0, NLL Loss: -306.35235546875, Val Loss: -321.7994140625, Time took: 0.2870516777038574
Epoch: 1, NLL Loss: -318.53726171875, Val Loss: -306.61875, Time took: 0.25124239921569824
Epoch: 2, NLL Loss: -332.0558828125, Val Loss: -325.8939453125, Time took: 0.2702515125274658
Epoch: 3, NLL Loss: -343.149201171875, Val Loss: -334.184521484375, Time took: 0.35918617248535156
Epoch: 4, NLL Loss: -358.420451171875, Val Loss: -353.5908203125, Time took: 0.420757532119751
Epoch: 5, NLL Loss: -369.276701171875, Val Loss: -357.4369140625, Time took: 0.2658872604370117
Epoch: 6, NLL Loss: -380.89758984375, Val Loss: -355.3448486328125, Time took: 0.2644665241241455
Epoch: 7, NLL Loss: -391.230166015625, Val Loss: -386.770947265625, Time took: 0.24984431266784668
Epoch: 8, NLL Loss: -402.086724609375, Val Loss: -395.80068359375, Time took: 0.3705635070800781
Epoch: 9, NLL Loss: -413.643384765625, Val Loss: -395.46435546875, Time took: 0.24553585052490234
Epoch: 10, NLL Loss: -424.116041015625

Epoch: 86, NLL Loss: -6304.22571875, Val Loss: -6915.47734375, Time took: 0.24769234657287598
Epoch: 87, NLL Loss: -7397.0903125, Val Loss: -7903.09375, Time took: 0.25017309188842773
Epoch: 88, NLL Loss: -8739.37565625, Val Loss: -9640.79296875, Time took: 0.25121188163757324
Epoch: 89, NLL Loss: -10242.9719375, Val Loss: -10105.78125, Time took: 0.2428126335144043
Epoch: 90, NLL Loss: -12209.9978125, Val Loss: -12383.54921875, Time took: 0.2744133472442627
Epoch: 91, NLL Loss: -14566.4756875, Val Loss: -18554.2265625, Time took: 0.34703564643859863
Epoch: 92, NLL Loss: -17671.6365625, Val Loss: -20096.8484375, Time took: 0.2709007263183594
Epoch: 93, NLL Loss: -21692.634375, Val Loss: -25919.8625, Time took: 0.27404356002807617
Epoch: 94, NLL Loss: -25846.2455, Val Loss: -27557.325, Time took: 0.35396409034729004
Epoch: 95, NLL Loss: -32020.518125, Val Loss: -34979.95625, Time took: 0.35999059677124023
Epoch: 96, NLL Loss: -38054.328, Val Loss: -49078.228125, Time took: 0.33701825141