In [152]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, Predictive, NUTS, MCMC, config_enumerate
from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoMultivariateNormal
from pyro.optim import Adam, ClippedAdam
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [153]:
X_train_tensor, X_val_tensor, X_test_tensor = torch.load("./data/X_train_tensor.pt"), \
                                                torch.load("./data/X_val_tensor.pt"), \
                                                torch.load("./data/X_test_tensor.pt")
U_train_tensor, U_val_tensor, U_test_tensor = torch.load("./data/U_train_tensor.pt"), \
                                                torch.load("./data/U_val_tensor.pt"), \
                                                torch.load("./data/U_test_tensor.pt")
N_t_train, N_t_valid, N_t_test = np.load("./data/N_t_train.npy"), \
                                np.load("./data/N_t_val.npy"), \
                                np.load("./data/N_t_test.npy")


In [157]:
class RFNDataset(Dataset):
    """Spatio-temporal demand modelling dataset."""
    def __init__(self, X, U):
        self.X = X
        self.U = U

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        X_i, U_i = self.X[idx].float(), self.U[idx].float()
        return X_i, U_i

In [240]:
class RMDN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, K=10):
        super(RMDN, self).__init__()
        # Define parameters
        self.input_dim = input_dim
        self.hidden_dim= hidden_dim
        self.output_dim= output_dim
        self.K = K 

        # Define LSTM
        self.lstm = nn.LSTM(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=1)
        # Fully connected layer 
        self.lstm_to_hidden  = nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim)
        # Take output of fully connected layer and feed to layers for GMM components
        self.hidden_to_loc   = nn.Linear(in_features=self.hidden_dim, out_features=self.K*self.output_dim)
        self.hidden_to_scale = nn.Linear(in_features=self.hidden_dim, out_features=self.K*self.output_dim)
        self.hidden_to_cov   = nn.Linear(in_features=self.hidden_dim, out_features=self.K)
        self.hidden_to_mix   = nn.Linear(in_features=self.hidden_dim, out_features=self.K)
        
        # Functions
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.softplus = nn.Softplus()
        
    def forward(self, U):
        
        # Feed through LSTM
        h, _ = self.lstm(U.view(-1, 1, self.input_dim))
        
        # Fully connected
        h = self.relu(self.lstm_to_hidden(h))
        # Compute mean values
        loc   = self.hidden_to_loc(h).view(-1, self.K, self.output_dim)
        # Compute variances (must be positive)
        scale = self.softplus(self.hidden_to_scale(h)).view(-1, self.K, self.output_dim)
        # Compute covariances
        cov   = self.hidden_to_cov(h)
        # Compute mixture components (must sum to 1)
        pi    = self.softmax(self.hidden_to_mix(h).view(-1,self.K, 1))
        # Create full covariance matrix
        Cov = torch.zeros((self.K, self.output_dim, self.output_dim))
        for k in range(self.K):
            Cov[k, 0, 0] = scale[:, k, 0]
            Cov[k, 1, 1] = scale[:, k, 1]
            Cov[k, 0, 1] = cov[:, :, k]
            Cov[k, 1, 0] = cov[:, :, k]
        
        return loc, pi, Cov
    
class PyroRMDN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, K, grid=32):
        super(PyroRMDN, self).__init__()
        self.RMDN = RMDN(input_dim, hidden_dim, output_dim, K)
        self.K = K
        self.output_dim = output_dim
        self.grid = grid
        self.skipped = 0
        
    def model(self, X=None, U=None, mask=None, batch_size=1):
                
        # Number of sequences
        N = len(U)
        # Number of time steps
        T_max = U.size(1)
        # Batch size
        b = min(N, batch_size)

        assert U.shape == (N, T_max, self.grid, self.grid)
        
        x_samples = torch.zeros((b, T_max, max(mask), 2))
        
        pyro.module("RMDN",self)
        
        with pyro.plate("data", N, dim=-2):
            for t in pyro.markov(range(0, T_max)):
                loc, pi, Cov = self.RMDN(U[:, t, :, :])
                if mask[t] == 0:
                    self.skipped += 1
                    continue
                
                with pyro.plate('density_%d'%t, size=mask[t], dim=-1):
                    assignment = pyro.sample('assignment_%d'%t, dist.Categorical(pi.view(-1, 10)))
                    _loc = loc[:, assignment, :].view(-1, 2)
                    _Cov = Cov[assignment, :, :].view(-1, 2, 2)
                    if X is None:
                        x_samples[:, t, :mask[t]] = pyro.sample('x_%d'%t, dist.MultivariateNormal(loc=_loc,covariance_matrix=_Cov),obs=None)
                    else:
                        x_samples[:, t, :mask[t]] = pyro.sample('x_%d'%t, dist.MultivariateNormal(loc=_loc,covariance_matrix=_Cov),obs=X[:, t, :mask[t], :])
        return x_samples, loc, pi, Cov
    
    def guide(self, X=None, U=None, mask=None):
        pass


In [6]:
#### Nmax = 259
#net = RMDN(input_dim=K*K, hidden_dim=128, output_dim=2, K=64)
#loc, pi, Cov = net(U[:, 0, :, :])

In [250]:
dataset  = RFNDataset(X_test_tensor, U_test_tensor)
dataloader= DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

In [251]:
pyronet = PyroRMDN(input_dim=K*K, hidden_dim=10, output_dim=2, K=10)
optimizer = pyro.optim.ClippedAdam({"lr": 0.001, "clip_norm": 5.})
svi = SVI(pyronet.model, pyronet.guide, optimizer, TraceEnum_ELBO(num_particles=1, max_plate_nesting=1))

In [None]:
pyro.clear_param_store()
num_epochs = 500
train_losses = []
val_losses = []
for i in range(num_epochs):
    for x_i, u_i in dataloader:
        pyronet.train()
        
        loss = svi.step(x_i, u_i, N_t_test) /(N_t_test.sum())
        del x_i, u_i
        train_losses.append(loss)
        #pyronet.eval()
        #validation_loss = pyronet._get_log_likelihood(X=X_val_tensor, U=U_val_tensor, mask=N_t_valid, 
                                                      #num_particles=1).sum().item() / N_t_valid.sum()
        #val_losses.append(validation_loss)
        
        if i%100 == 99 or i == 0:
            print("Epoch: {0}, Loss: {1:.3f}".format(i+1, loss))#, validation_loss))
        

Epoch: 1, Loss: 548132.948
Epoch: 100, Loss: 443999.307


In [233]:
x_samples, loc, pi, Cov = pyronet.model(U=U_val_tensor, mask=N_t_valid, X=X_val_tensor)

(1, 718, 32, 32)


In [224]:
N_t_test[12:14]

array([ 0, 20])

In [205]:
 U_test_tensor[:, 13, :, :].size(1)

32

In [206]:
len(U_test_tensor[:, 13, :, :])

1