In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Normal
from torch.distributions import kl_divergence
import numpy as np
import time

In [2]:
torch.__version__

'1.0.1'

In [3]:
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

True


In [4]:
def create_synthetic_data(time_step = 10, num_sample = 16, marker_dim=3):
    marker = torch.randn(time_step, num_sample, marker_dim).to(device)
#     points_ = np.random.rand(time_step, num_sample) * 1.
#     cum_sum_points =  np.cumsum(points_, axis = 0)
#     t = np.stack([cum_sum_points, points_], axis = 2)
#     x, t  = marker.tolist(), t.tolist()
#     x = torch.tensor(x)
#     t = torch.tensor(t)
#     data = {'x':x, 't': t}
    data = marker
    return data

In [5]:
class DMM(nn.Module):
    def __init__(self, marker_type='real', marker_dim=3, latent_dim=20):
        super().__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus
        
        self.marker_type = marker_type
        self.latent_dim = latent_dim
        self.inference_hidden_dim = latent_dim
        
        self._inference_net(x_dim=marker_dim, z_dim=latent_dim, inference_hidden_dim=latent_dim)
        self._gated_transition_net(latent_dim, latent_dim)
        self._emission_net(z_dim=latent_dim, x_dim=marker_dim, emission_hidden_dim=latent_dim)

    def _gated_transition_net(self, z_dim, transition_dim):
        # mean
        self.gating_unit_hidden = nn.Sequential(
            nn.Linear(z_dim, transition_dim),
            nn.ReLU()
        )
        self.gating_unit_for_proposed_mean = nn.Sequential(
            nn.Linear(transition_dim, z_dim),
            nn.Sigmoid()
        )
        self.nonlinear_proposed_mean = nn.Linear(transition_dim, z_dim)
        
        self.linear_mean = nn.Linear(z_dim, z_dim)
        
        self.linear_mean.weight.data = torch.eye(z_dim)
        self.linear_mean.bias.data = torch.zeros(z_dim)
        
        # variance
        self.z_next_var_net = nn.Sequential(
            nn.Linear(z_dim, z_dim),
            nn.Softplus()
        )
    def _transition(self, z_prev):
        """
            This is the function that computes the p_theta (z_t | z_{t-1})
            Input:
            z_prev: Tensor of shape TxBSxlatent_dim (if real)
        """
        # Linear mean (linear function of previous state)
        linear_proposed_mean = self.linear_mean(z_prev)
        
        # Nonlinear (proposed) mean using gates
        hidden = self.gating_unit_hidden(z_prev)
        gating_unit = self.gating_unit_for_proposed_mean(hidden)
        proposed_mean = self.nonlinear_proposed_mean(hidden)
        
        # Combine nonlinear and linear options of means to choose
        z_next_mean = (1 - gating_unit) * linear_proposed_mean \
                        + gating_unit*proposed_mean
            
        z_next_var_int = self.relu(proposed_mean)
        z_next_var = self.z_next_var_net(z_next_var_int)
        
#         z_next_dist = Normal(z_next_mean, z_next_var.sqrt())
#         z_next = z_next_dist.sample()
        return z_next_mean, z_next_var

    
    def _emission_net(self, z_dim, x_dim, emission_hidden_dim):
        self.emission_net = nn.Sequential(
            nn.Linear(z_dim, emission_hidden_dim),
            nn.ReLU(),
            nn.Linear(emission_hidden_dim, emission_hidden_dim),
            nn.ReLU()
        )
        if self.marker_type == 'real':
            self.x_mean = nn.Linear(emission_hidden_dim, x_dim)
            self.x_var = nn.Sequential(
                nn.Linear(emission_hidden_dim, x_dim),
                nn.Softplus()
            )
        elif self.marker_type == 'binary':
            pass
    
    def _emission(self, z_seq):
        """
            Input:
            z_t: Tensor of shape TxBSxlatent_dim (if real)
        """
        if self.marker_type == 'real':
            x_mean = self.x_mean(z_seq)
            x_var = self.x_var(z_seq)
            x_dist = Normal(x_mean, x_var.sqrt())    
        elif self.marker_type == 'binary':
            pass
#         x_sample = x_dist.sample()
        return x_dist
    
    
    
    def _combiner_fn(self, z_prev, rnn_state_right):
        hidden_combined = (self.combiner_net(z_prev) + rnn_state_right)/3
        
        return hidden_combined
    
    def _inference_net(self, x_dim, z_dim, inference_hidden_dim):
        
        self.backward_rnn = nn.GRUCell(x_dim, inference_hidden_dim)
        self.combiner_net = nn.Sequential(
            nn.Linear(z_dim, inference_hidden_dim),
            nn.Tanh()
        )
        self.posterior_mean_net = nn.Linear(inference_hidden_dim, z_dim)
        self.posterior_var_net = nn.Sequential(
            nn.Linear(inference_hidden_dim, z_dim),
            nn.Softplus()
        )
        
    def _inference(self, x_seq):
        """
            Input:
            x_seq: Tensor of shape TxBSxmarker_dim (if real)
        """
        # Generate z_0. Can choose to learn it as a parameter or fix it
        # Recurrently, generate h_t
        # Recurrently, generate z_t using z_{t-1} and h_t
        T, BS, _ = x_seq.shape
        z_dim = self.latent_dim
        z_0 = nn.Parameter(torch.zeros(BS, z_dim)).to(device)
        h_t = torch.zeros(BS, self.inference_hidden_dim).to(device)
        z_seq = [z_0]
        z_means = []
        z_vars = []
        
        z_prev = z_0
        # x_t[t] will be BS x marker_dim
        # go backward, starting from T --> 1
        for t in range(T-1,-1,-1):
            h_t_right = self.backward_rnn(x_seq[t], h_t)
            h_combined = self._combiner_fn(z_prev, h_t_right)
            posterior_z_t_mean = self.posterior_mean_net(h_combined)
            posterior_z_t_var = self.posterior_var_net(h_combined)
            epsilon = torch.randn_like(posterior_z_t_mean)
            posterior_z_t_sample = posterior_z_t_mean + posterior_z_t_var.sqrt()*epsilon

            z_means.append(posterior_z_t_mean)
            z_vars.append(posterior_z_t_var)
            z_seq.append(posterior_z_t_sample)
            z_prev = posterior_z_t_sample
        
        z_seq = torch.stack(z_seq, dim=0)
        z_means = torch.stack(z_means, dim=0)
        z_vars = torch.stack(z_vars, dim=0)
        return z_seq, z_means, z_vars
    
    def forward(self, x_seq):
        # Get a sampled sequence (z_0 --> z_T) and
        # the means and the vars for
        # q_phi(z_seq | x_seq) = q_phi(z_t | z_{t-1}, x_seq)
        posterior_seq_sample, posterior_means, posterior_vars = self._inference(x_seq)
        
        # Get means and vars for p_theta (z) = p_theta (z_t | z_{t-1})
        prior_means, prior_vars = self._transition(posterior_seq_sample[:-1])
        
        # Get p_theta (z_t | z_{t-1})
        prior_dist = Normal(prior_means, prior_vars.sqrt())
        
        # Convert means and vars to torch distributions
        posterior_dist = Normal(posterior_means, posterior_vars.sqrt())
        
        # Get p_theta(x_t | z_t)
        reconstructed_x_dist = self._emission(posterior_seq_sample[:-1])
        
        # log p_theta (x_seq | z) = sum( log p_theta (x_t | z_t))
        neg_log_likelihood = -reconstructed_x_dist.log_prob(x_seq).sum()
        
        # KL divergence between the variational posterior distribution and generative prior
        # KL (q_phi (z_t | z_{t-1}, x_seq) || p_theta (z_t | z_{t-1}))
        kl = kl_divergence(posterior_dist, prior_dist).sum()
        
        return neg_log_likelihood, kl

In [21]:
def train(model, epoch, data, optimizer, batch_size, val_data):
    start = time.time()
    model.train()
    train_loss_nll = 0
    train_loss_kl = 0
    T, n_train, _ = data.shape
    _, n_val, _ = val_data.shape

    optimizer.zero_grad()
    idxs = np.random.permutation(n_train)
    for i in range(0, n_train, batch_size):
        nll, kl = model(data)
        # want to maximize ll and minimize kl
        loss = kl + nll
        loss.backward()
        
        train_loss_nll += nll.item()
        train_loss_kl += kl.item()
        
    optimizer.step()
    end = time.time()
    
    val_loss = 0.
    if val_data is not None:
        n_val = len(val_data)
        with torch.no_grad():
            val_nll, val_kl = model(val_data)
            loss = val_kl + val_nll
            val_loss += loss.item()
    
    print("Epoch: {}, NLL Loss: {}, KL Div: {}, Val Loss: {}, Time took: {}".format(epoch, train_loss_nll/(n_train*T), \
                                                                        train_loss_kl/(n_train*T), val_loss/n_val, \
                                                                        (end-start)))

In [22]:
def trainer(model: nn.Module, data: torch.Tensor = None, val_data: torch.Tensor = None, lr = 1e-3, epoch=100, batch_size=64):
    optimizer = Adam(model.parameters(), lr=lr)
    
    for epoch_number in range(epoch):
        train(model, epoch_number, data, optimizer, batch_size, val_data)
    return model
    

# Generate synthetic data

In [23]:
from dmm_data.load import loadSyntheticData

In [24]:
data = loadSyntheticData()
torch.tensor(data['valid']['tensor']).to(device).shape

torch.Size([1000, 10, 3])

# Train

In [25]:
def main():
    model = DMM().to(device)
#     data = create_synthetic_data(num_sample=5000, time_step=25)
    data = loadSyntheticData()
    train_data = torch.tensor(data['train']['tensor']).transpose(0,1).float().to(device)
#     val_data = create_synthetic_data(num_sample=150)
    val_data = torch.tensor(data['valid']['tensor']).transpose(0,1).float().to(device)
    trainer(model, data=train_data, val_data=val_data)

In [26]:
main()

Epoch: 0, NLL Loss: 6531.77752, KL Div: 482.53597125, Val Loss: 45087.096875, Time took: 4.147772312164307
Epoch: 1, NLL Loss: 6502.3032575, KL Div: 473.157955, Val Loss: 44730.43125, Time took: 4.257425308227539
Epoch: 2, NLL Loss: 6468.971035, KL Div: 464.02153375, Val Loss: 44401.15, Time took: 4.1038818359375
Epoch: 3, NLL Loss: 6440.5760875, KL Div: 455.2881903125, Val Loss: 44360.003125, Time took: 4.263967752456665
Epoch: 4, NLL Loss: 6410.5647925, KL Div: 446.7399396875, Val Loss: 44056.04375, Time took: 4.250853061676025
Epoch: 5, NLL Loss: 6382.286125, KL Div: 438.4461834375, Val Loss: 43901.9625, Time took: 4.143693923950195
Epoch: 6, NLL Loss: 6352.29315, KL Div: 430.3370934375, Val Loss: 43619.346875, Time took: 4.125028371810913
Epoch: 7, NLL Loss: 6324.6470325, KL Div: 422.500525, Val Loss: 43088.50625, Time took: 4.2200026512146
Epoch: 8, NLL Loss: 6296.6050075, KL Div: 414.8400184375, Val Loss: 42938.353125, Time took: 4.402120113372803
Epoch: 9, NLL Loss: 6268.45923, 

Epoch: 75, NLL Loss: 3911.964495, KL Div: 157.282063515625, Val Loss: 25934.334375, Time took: 4.065443515777588
Epoch: 76, NLL Loss: 3850.25502, KL Div: 156.675699375, Val Loss: 25259.784375, Time took: 4.058127164840698
Epoch: 77, NLL Loss: 3787.831135, KL Div: 156.05655234375, Val Loss: 24957.309375, Time took: 4.201474905014038
Epoch: 78, NLL Loss: 3726.237705, KL Div: 155.400989765625, Val Loss: 24559.0046875, Time took: 3.957115411758423
Epoch: 79, NLL Loss: 3665.3924375, KL Div: 154.688304140625, Val Loss: 24247.1671875, Time took: 4.100284099578857
Epoch: 80, NLL Loss: 3605.688645, KL Div: 153.870247734375, Val Loss: 23763.8265625, Time took: 4.148946523666382
Epoch: 81, NLL Loss: 3546.691155, KL Div: 152.952621953125, Val Loss: 23490.5890625, Time took: 4.242663145065308
Epoch: 82, NLL Loss: 3489.479225, KL Div: 151.910897890625, Val Loss: 22979.2203125, Time took: 4.147438287734985
Epoch: 83, NLL Loss: 3433.3776875, KL Div: 150.70961203125, Val Loss: 22756.240625, Time took: 