In [22]:
# Imports

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
#import pyro
#import pyro.distributions as dist
#from pyro.infer import SVI, TraceEnum_ELBO, Predictive, NUTS, MCMC, config_enumerate
#from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoMultivariateNormal
#from pyro.optim import Adam, ClippedAdam
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.distributions as dist
from torch.distributions import MultivariateNormal as MN

In [2]:
# Load data files

# X tensors for the train / validation / test splits
X_train_tensor = torch.load("./data/X_train_tensor.pt")
X_val_tensor = torch.load("./data/X_val_tensor.pt")
X_test_tensor = torch.load("./data/X_test_tensor.pt")

# U tensors for the same three splits
U_train_tensor = torch.load("./data/U_train_tensor.pt")
U_val_tensor = torch.load("./data/U_val_tensor.pt")
U_test_tensor = torch.load("./data/U_test_tensor.pt")

# N_t arrays for the same three splits; passed as `mask` to the
# log-likelihood computation further down
N_t_train = np.load("./data/N_t_train.npy")
N_t_valid = np.load("./data/N_t_val.npy")
N_t_test = np.load("./data/N_t_test.npy")


In [3]:
# Class for extracting data

class RFNDataset(Dataset):
    """Spatio-temporal demand modelling dataset.

    Wraps two indexable tensors ``X`` and ``U`` and serves them as
    ``(X[idx], U[idx])`` pairs, both converted to float32.
    """

    def __init__(self, X, U):
        self.X = X
        self.U = U

    def __len__(self):
        # The number of samples is taken from X alone.
        return len(self.X)

    def __getitem__(self, idx):
        # Tensor indices are turned into plain Python ints/lists before indexing.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        return self.X[index].float(), self.U[index].float()

In [107]:
# Recurrent Mixture Density network

class RMDN(nn.Module):
    """
    Pure PyTorch class for Recurrent Mixture Density Network.

    An LSTM consumes flattened input frames; a fully connected layer and four
    linear heads turn each LSTM output into the parameters of a K-component
    Gaussian mixture over ``output_dim``-dimensional points.

    Inputs:
        input_dim:  number of features in one flattened input frame U
        hidden_dim: number of hidden units used in the LSTM and the hidden layer
        output_dim: dimension of output, lat/lon
        K:          number of mixture components

    Outputs of ``forward`` (N = number of frames found in U):
        loc:      (N, K, output_dim) tensor of mean values for the Gaussians
        pi:       (N, K, 1) mixture weights, summing to 1 over the K axis
        cov_tril: (N, K, output_dim, output_dim) lower-triangular scale factors
                  with positive diagonal. They are used as ``scale_tril`` of
                  MultivariateNormal, i.e. they are Cholesky-style factors of
                  the covariance, not the covariance itself.
        hidden:   updated LSTM state ``(h, c)``
    """
    def __init__(self, input_dim, hidden_dim, output_dim, K=10):
        super(RMDN, self).__init__()
        # Define parameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.K = K
        # Row/column indices of the strictly-lower triangle of one scale matrix
        self.tril_indices = torch.tril_indices(row=output_dim, col=output_dim, offset=-1)

        # Define LSTM
        self.lstm = nn.LSTM(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=1)
        # Fully connected layer
        self.lstm_to_hidden     = nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim)
        # Take output of fully connected layer and feed to layers for GMM components
        self.hidden_to_loc      = nn.Linear(in_features=self.hidden_dim, out_features=self.K*self.output_dim)
        self.hidden_to_sigma    = nn.Linear(in_features=self.hidden_dim, out_features=self.K*self.output_dim)
        # One off-diagonal value per component; for output_dim > 2 that single
        # value is broadcast to every strictly-lower entry.
        self.hidden_to_off_diag = nn.Linear(in_features=self.hidden_dim, out_features=self.K)
        self.hidden_to_mix      = nn.Linear(in_features=self.hidden_dim, out_features=self.K)

        # Functions
        self.relu = nn.ReLU()
        # dim=1 is the K axis of the (N, K, 1) logits built in forward()
        self.softmax = nn.Softmax(dim=1)
        self.softplus = nn.Softplus()

    def forward(self, U, hidden):
        """
        Run the LSTM over the frames in U and emit GMM parameters per frame.

        Inputs:
            U:      tensor whose elements can be regrouped into frames of
                    ``input_dim`` values each
            hidden: LSTM state ``(h, c)``, e.g. from ``init_hidden()``

        Returns:
            ((loc, pi, cov_tril), hidden) as described in the class docstring.
        """
        # nn.LSTM's default layout is (seq_len, batch, features): the frames in
        # U are fed as consecutive steps of a single sequence (batch of 1).
        y, hidden = self.lstm(U.reshape(-1, 1, self.input_dim), hidden)
        # Fully connected
        y = self.relu(self.lstm_to_hidden(y))
        # Compute mean values
        loc = self.hidden_to_loc(y).view(-1, self.K, self.output_dim)
        # Diagonal of the scale factor (must be positive)
        sigma = self.softplus(self.hidden_to_sigma(y)).view(-1, self.K, self.output_dim)
        # Off-diagonal entry of the scale factor (unconstrained)
        cov = self.hidden_to_off_diag(y).view(-1, self.K, 1)
        # Compute mixture components (must sum to 1)
        pi = self.softmax(self.hidden_to_mix(y).view(-1, self.K, 1))

        # Build the lower-triangular factors for all components at once.
        # Starting from diag_embed(sigma) keeps device, dtype and the leading
        # dimension consistent with the other outputs.
        cov_tril = torch.diag_embed(sigma)
        rows, cols = self.tril_indices[0], self.tril_indices[1]
        if rows.numel() > 0:
            cov_tril[:, :, rows, cols] = cov

        return (loc, pi, cov_tril), hidden

    def get_loglikelihood(self, U, X, mask):
        """
        Log-likelihood of the observed points under the predicted mixtures.

        Evaluation only: runs under ``torch.no_grad()``, so the result carries
        no gradient.

        Inputs:
            U:    4-D tensor indexed as ``U[:, t, :, :]`` per time step
            X:    4-D tensor indexed as ``X[:, t, j, :]`` for point j of step t
            mask: sequence with the number of valid points per time step;
                  only ``X[:, t, :mask[t], :]`` contributes

        Returns:
            Tensor of shape ``(X.size(0),)`` holding the summed log-likelihood
            (zeros when no time step has any valid point).
        """
        T_max = U.size(1)
        ref = self.lstm_to_hidden.weight
        logprob = torch.zeros(X.size(0), dtype=ref.dtype, device=ref.device)
        with torch.no_grad():
            hidden = self.init_hidden()
            for t in range(T_max):
                # Always advance the LSTM, even when the step has no points
                (loc, pi, cov), hidden = self.forward(U=U[:, t, :, :], hidden=hidden)
                n_points = int(mask[t])
                if n_points <= 0:
                    continue
                # K Gaussians with batch shape (K,)
                components = MN(loc=loc[0], scale_tril=cov[0])
                # (B, n_points, 1, d) broadcasts against the K components
                points = X[:, t, :n_points, :].unsqueeze(-2)
                component_logp = components.log_prob(points)  # (B, n_points, K)
                log_pi = torch.log(pi[0, :, 0])               # (K,)
                # log sum_k pi_k N(x | k), computed in the log domain so that
                # far-away points do not underflow to log(0) = -inf
                point_logp = torch.logsumexp(log_pi + component_logp, dim=-1)
                logprob = logprob + point_logp.sum(dim=-1)
        return logprob

    def init_hidden(self):
        """Return zero-initialised LSTM state ``(h, c)``, each of shape (1, 1, hidden_dim)."""
        # new_zeros follows the device and dtype of the model parameters
        ref = self.lstm_to_hidden.weight
        return (ref.new_zeros(1, 1, self.hidden_dim), ref.new_zeros(1, 1, self.hidden_dim))
    


In [122]:
# Build the network: flattened 48x48 input frames, 2-D outputs, 10 mixture components
model = RMDN(
    input_dim=48 * 48,
    hidden_dim=128,
    output_dim=2,
    K=10,
)

In [128]:
# Log-likelihood of the test split under the current model
model.get_loglikelihood(
    U=U_test_tensor,
    X=X_test_tensor,
    mask=N_t_test,
)

KeyboardInterrupt: 

In [None]:
def criterion(model, U, X, mask, hidden=None):
    """
    Evaluate the model's log-likelihood of X given U.

    Inputs:
        model:  object exposing ``get_loglikelihood(U, X, mask)``
        U:      input tensor forwarded to the model
        X:      observed points forwarded to the model
        mask:   number of valid points per time step, forwarded to the model
        hidden: accepted for backward compatibility but unused;
                ``get_loglikelihood`` takes no hidden-state argument

    Returns:
        Whatever ``model.get_loglikelihood`` returns. Note this is the
        log-likelihood itself (larger is better), not a negated loss.
    """
    return model.get_loglikelihood(U=U, X=X, mask=mask)
    

In [129]:
# Smoke test: one forward step on index 0 along axis 1 of U, starting from a
# freshly initialised hidden state.
# NOTE(review): `U` is not assigned in any visible cell, so this relies on
# leftover kernel state and fails after Restart & Run All. Confirm which
# tensor it should be.
model(U=U[:, 0, :, :], hidden=model.init_hidden())

((tensor([[[ 0.0763, -0.0403],
           [ 0.0314,  0.0460],
           [ 0.0090, -0.0716],
           [-0.0544, -0.0400],
           [ 0.0343, -0.0209],
           [-0.0853, -0.0392],
           [-0.0504, -0.0655],
           [-0.0294,  0.0762],
           [-0.0146,  0.0440],
           [-0.0540,  0.0273]]], grad_fn=<ViewBackward>),
  tensor([[[0.1042],
           [0.0952],
           [0.1004],
           [0.0952],
           [0.1064],
           [0.1012],
           [0.0974],
           [0.0965],
           [0.1019],
           [0.1016]]], grad_fn=<SoftmaxBackward>),
  tensor([[[[ 0.6735,  0.0000],
            [ 0.0604,  0.7017]],
  
           [[ 0.7204,  0.0000],
            [-0.0033,  0.6792]],
  
           [[ 0.7085,  0.0000],
            [ 0.0698,  0.6797]],
  
           [[ 0.7022,  0.0000],
            [-0.0272,  0.6868]],
  
           [[ 0.6991,  0.0000],
            [ 0.0258,  0.6965]],
  
           [[ 0.7128,  0.0000],
            [-0.1092,  0.6724]],
  
           [[ 0

In [127]:
# Inspect the shape of U.
# NOTE(review): `U` is not assigned in any visible cell; define it explicitly
# before this cell so the notebook runs top to bottom.
U.shape

torch.Size([1, 302, 48, 48])

In [115]:
# Wrap the training tensors and serve them one sample at a time, in shuffled order
dataset = RFNDataset(X_train_tensor, U_train_tensor)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

In [116]:
# Re-create the model and set up the optimizer for training.
# NOTE(review): `pyro` is referenced here, but every pyro import in the
# imports cell is commented out, so this raises NameError on a fresh kernel.
# The training loop below drives this through `svi`, so the optimizer and the
# loop need to be migrated (or the imports restored) together.
model = RMDN(input_dim=48*48, hidden_dim=128, output_dim=2, K=10)
optimizer = pyro.optim.Adam({"lr": 0.001})

In [None]:
# Training loop: one optimisation step per sample served by `dataloader`, with
# periodic logging of training / validation log-likelihood.
# NOTE(review): `pyronet`, `svi` and `hidden` are not defined in any visible
# cell, and the pyro imports are commented out, so this cell cannot run on a
# fresh kernel. `model` and `optimizer` from the previous cell are not used.
num_epochs = 500
train_losses = []
# NOTE(review): val_losses is never appended to in this cell.
val_losses = []
train_ll = []
validation_ll =  []
for i in range(num_epochs):
    for X_i, U_i in dataloader:
        # Set model to training
        pyronet.train()
        # Take step and update parameters; the loss is divided by the total of N_t_train
        loss = svi.step(X_i, U_i, N_t_train, hidden) /(N_t_train.sum())
        # Save current loss
        train_losses.append(loss)       
        # Evaluate in the first epoch and every 100th epoch. This check sits
        # inside the batch loop, so it runs once per batch in those epochs.
        if i%100 == 99 or i == 0:
            # Set model to evaluation
            pyronet.eval()
            # Compute and save training and validation log likelihood
            train_ll_i      = pyronet.model(X=X_i, U=U_i, mask=N_t_train, validation=True)
            validation_ll_i = pyronet.model(X=X_val_tensor, U=U_val_tensor, mask=N_t_valid, validation=True)
            train_ll.append(train_ll_i)
            validation_ll.append(validation_ll_i)
            
            print(f"Epoch: {i+1}, \tLoss: {loss:.3f}, \tTraining LL: {train_ll_i.numpy()[0]},\tValidation LL: {validation_ll_i.numpy()[0]}")
        