In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import Normal
from torch.distributions import kl_divergence

In [2]:
torch.__version__

'1.0.1'

In [3]:
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

True


# Define RMTPP Module

In [4]:
class RMTPP(nn.Module):
    def __init__(self, marker_type='real', marker_dim=20):
        super().__init__()
        
        self.marker_dim = marker_dim
        
        # Dimensions for embedding inputs
        self.marker_embed_dim = 64
        self.time_embed_dim = 64
        # Networks for embedding inputs
        self.create_embedding_nets()
        
        # This is the layer that encodes the history
        self.hidden_layer_dim = 128
        # Create RNN layer Network
        self.create_rnn()
        
        # Hidden shared layer size (bw mu and var)
        # while generating marker from hidden state
        self.marker_shared_dim = 64
        # Create Network for Marker generation from hidden_seq
        self.create_marker_generation_net()
        self.create_time_likelihood_net()

    ############ UTILITY METHODS #############
        
    def _one_hot_marker(self, marker_seq):
        """
            Input:
            marker_seq: Tensor of shape TxBSx1
        """
        pass


    
    ############ NETWORKS #############    
    
    def create_rnn(self):
        """
            Input:
            marker_embed_dim: dimension of embedded markers
            time_embed_dim: dimension of embedded times,intervals
            hidden_layer_dim: dimension of hidden state of recurrent layer
        """

        ### Vanilla RNN
#         self.rnn = nn.RNN(
#             input_size=self.marker_embed_dim+self.time_embed_dim,
#             hidden_size=self.hidden_layer_dim,
#             nonlinearity='relu'
#         )
        ### GRU
        self.rnn = nn.GRU(
            input_size=self.marker_embed_dim+self.time_embed_dim,
            hidden_size=self.hidden_layer_dim,
        )
    
    def create_embedding_nets(self):
        # marker_dim is passed. timeseries_dim is 2
        self.marker_embedding_net = nn.Linear(self.marker_dim, self.marker_embed_dim)
        
        self.time_embedding_net = nn.Sequential(
            nn.Linear(2, self.time_embed_dim),
            nn.ReLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim)
        )
        
    def create_marker_generation_net(self):
        """
            Generate network to create marker sufficient statistics using
            rnn's hidden layer
        """
        self.marker_gen_hidden = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.marker_shared_dim),
            nn.ReLU()
        )
        self.generated_marker_mu = nn.Linear(self.marker_shared_dim, self.marker_dim)
        self.generated_marker_var = nn.Sequential(
            nn.Linear(self.marker_shared_dim, self.marker_dim),
            nn.Softplus()
        )
    
    def create_time_likelihood_net(self):
        self.h_influence = nn.Linear(self.hidden_layer_dim, 1, bias=False)
        self.base_intensity = nn.Parameter(torch.zeros(1,1,1))
        self.wt = nn.Parameter(torch.tensor([[[1e-6]]]))
    
    
    ############ METHODS #############
    
    def _embed_data(self, marker_seq, time_seq):
        """
            Input:
            marker_seq: Tensor of shape TxBSx marker_dim
            time_seq: Tensor of shape TxBSx 2
            Output:
            marker_seq_emb: Tensor of shape T x BS x marker_embed_dim
            time_seq_emb: Tensor of shape T x BS x time_embed_dim
        """
        marker_seq_emb = self.marker_embedding_net(marker_seq)
        time_seq_emb = self.time_embedding_net(time_seq)
        return marker_seq_emb, time_seq_emb
    
    
    
    def marker_log_likelihood(self, h_seq, marker_seq):
        """
        Use the h_seq to generate the distribution for the markers,
        and use that distribution to compute the log likelihood of the marker_seq
            Input:  
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim (if real)
                    marker_seq : Tensor of shape T x BS x marker_dim
            Output:
                    log_likelihood_marker_seq : T x BS x marker_dim
        """
        marker_gen_shared = self.marker_gen_hidden(h_seq)
        mu, var = self.generated_marker_mu(marker_gen_shared), self.generated_marker_var(marker_gen_shared)
        
        marker_gen_dist = Normal(mu, var.sqrt())
        log_likelihood_marker_seq = marker_gen_dist.log_prob(marker_seq)
        
        return log_likelihood_marker_seq
        
        
    def time_log_likelihood(self, h_seq, time_seq):
        """
        Use the h_seq to compute the log likelihood of the time_seq
        using the formula in the paper
            Input:  
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim (if real)
                    time_seq : Tensor of shape T x BS x 2 . Last dimension is [times, intervals]
            Output:
                    log_likelihood_time_seq : T x BS x 1
        """
#         import pdb; pdb.set_trace()
        past_influence = self.h_influence(h_seq)
        current_influence = self.wt * time_seq[:,:,1:2]
        base_intensity = self.base_intensity
        
        term1 = past_influence + current_influence + base_intensity
        term2 = past_influence + base_intensity
        
        # After factorizing the formula in the paper
        log_likelihood_time_seq = term1 + (term2.exp() - term1.exp())/self.wt
        
        return log_likelihood_time_seq
    
    #############################################
    
    def forward(self, marker_seq, time_seq, **kwargs):
        # Transform markers and timesteps into the embedding spaces
        marker_seq_emb, time_seq_emb = self._embed_data(marker_seq, time_seq)
        T,BS,_ = marker_seq_emb.shape
        
        # Run RNN over the concatenated sequence [marker_seq_emb, time_seq_emb]
        time_marker_combined = torch.cat([marker_seq_emb, time_seq_emb], dim=-1)
        h_0 = torch.zeros(1, BS, self.hidden_layer_dim).to(device)
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)
        hidden_combined = torch.cat([h_0, hidden_seq], dim=0)
        
        # compute the marker and time log likelihoods
        # h_0 is used to generate marker_1 and so on...
        marker_ll = self.marker_log_likelihood(hidden_combined[:-1], marker_seq)
        # h_0 is used to generate timestamp_1 and so on...
        time_ll = self.time_log_likelihood(hidden_combined[:-1], time_seq)
        
        likelihood_loss = marker_ll.sum() + time_ll.sum()
        NLL = -likelihood_loss
        
        # NLL is used for optimization,
        # the individual LL values are used for logging
        return NLL, [-marker_ll.sum().item(), -time_ll.sum().item() ]
        

In [5]:
class HRMTPP(RMTPP):
    def __init__(self, latent_dim=20, **kwargs):
        self.latent_dim = latent_dim
        super().__init__(**kwargs)
        

        
        self.create_inference_net()
        self.create_marker_generation_net()
        self.create_time_likelihood_net()
    
    ## Utility Methods ##
    def _reparameterize(self, mu, var):
        epsilon = torch.randn_like(mu)
        return mu + epsilon*var.sqrt()
    
    def create_inference_net(self):
        self.inference_rnn = nn.GRU(
            input_size = self.marker_embed_dim+self.time_embed_dim,
            hidden_size = self.hidden_layer_dim
        )
        
        self.inference_intermediate_net = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.hidden_layer_dim),
            nn.ReLU(),
        )
        self.posterior_mean_net = nn.Linear(self.hidden_layer_dim, self.latent_dim)
        self.posterior_var_net = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.latent_dim),
            nn.Softplus()
        )
        
    def create_marker_generation_net(self):
        """
            Generate network to create marker sufficient statistics using
            rnn's hidden layer and latent variable
        """
        self.marker_gen_hidden = nn.Sequential(
            nn.Linear(self.hidden_layer_dim+self.latent_dim, self.marker_shared_dim),
            nn.ReLU()
        )
        self.generated_marker_mu = nn.Linear(self.marker_shared_dim, self.marker_dim)
        self.generated_marker_var = nn.Sequential(
            nn.Linear(self.marker_shared_dim, self.marker_dim),
            nn.Softplus()
        )
    
    def create_time_likelihood_net(self):
        self.h_influence = nn.Linear(self.hidden_layer_dim+self.latent_dim, 1, bias=False)
        self.base_intensity = nn.Parameter(torch.zeros(1,1,1))
        self.wt = nn.Parameter(torch.randn(1,1,1))
    
        
    def _inference(self, hidden_seq):
        """
        Use the hidden_seq to compute the posterior
        q(z | x) = NN(RNN(x(1...T))) = NN(h_T)
        Also, compute a sampled value and return
            Input:  
                    hidden_seq   : Tensor of shape (T+1) x BS x hidden_layer_dim (first timestep is h_0)
            Output:
                    z_sampled: Tensor of shape 1 x BS x latent_dim
                    z_mean: Tensor of shape 1 x BS x latent_dim
                    z_var: Tensor of shape 1 x BS x latent_dim
        """
        
        intermediate_layer = self.inference_intermediate_net(hidden_seq[-1:])
        z_mean = self.posterior_mean_net(intermediate_layer)
        z_var = self.posterior_var_net(intermediate_layer)
        z_sampled = self._reparameterize(z_mean, z_var)
        
        return z_sampled, z_mean, z_var
    
    def hz_combined(self, hidden_seq, z):
        """
        Concatenate the z to the hidden_seq along the last dimension
            Input:  
                    hidden_seq   : Tensor of shape (T+1) x BS x hidden_layer_dim (first timestep is h_0)
                    z: Tensor of shape 1 x BS x latent_dim
            Output:
                    hz = Tensor of shape (T+1) x BS x (hidden_layer_dim+latent_dim)
        """
        # Extract timelength
        T, _, _ = hidden_seq.shape
        # Expand z on the time dimension, leaving everything else the same
        z_broadcast = z.expand(T, -1, -1)
        
        hz = torch.cat([hidden_seq, z_broadcast], dim=-1)
        
        return hz
    
    #############################################
    
    def forward(self, marker_seq, time_seq, anneal=1.):
        # Transform markers and timesteps into the embedding spaces
        marker_seq_emb, time_seq_emb = self._embed_data(marker_seq, time_seq)
        T,BS,_ = marker_seq_emb.shape
        
        # Run RNN over the concatenated sequence [marker_seq_emb, time_seq_emb]
        time_marker_combined = torch.cat([marker_seq_emb, time_seq_emb], dim=-1)
        h_0 = torch.zeros(1, BS, self.hidden_layer_dim).to(device)
        # Run RNN
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)
        # Append h_0 to h_1 .. h_T
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)
        
        ## Inference
        # Get the sampled value and (mean + var) latent variable
        # using the hidden state sequence
        posterior_sample, posterior_mean, posterior_var = self._inference(hidden_seq)
        posterior_dist = Normal(posterior_mean, posterior_var.sqrt())

        # Prior is just a Normal(0,1) dist
        prior_dist = Normal(0,1)

        ## Generative Part
        
        # Use the embedded markers and times to create another set of 
        # hidden vectors. Can reuse the h_0 and time_marker combined computed above

        # Use an RNN to summarize the x1 ... xT sequence
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)
        
        # Combine hidden_seq and z to form the input for the generative part
        hz_combined = self.hz_combined(hidden_seq, posterior_sample)
        
        # compute the marker and time log likelihoods
        # z,h_0 is used to generate marker_1 and so on...
        marker_ll = self.marker_log_likelihood(hz_combined[:-1], marker_seq)
        # z,h_0 is used to generate timestamp_1 and so on...
        time_ll = self.time_log_likelihood(hz_combined[:-1], time_seq)
        
        likelihood_loss = marker_ll.sum() + time_ll.sum()
        NLL = -likelihood_loss
        
        KL = kl_divergence(posterior_dist, prior_dist).sum()
        
        loss = NLL + anneal*10*KL
        
        # NLL and KL are used for optimization,
        # the individual LL values are used for logging
        return loss, [-marker_ll.sum().item(), -time_ll.sum().item(), KL.item()]
        

# Training

In [6]:
from utils_ import generate_mpp, test_val_split

In [7]:
from trainer import train
from hrmtpp import hrmtpp
from rmtpp import rmtpp

In [8]:
def trainer(model, data = None, val_data=None, lr= 1e-2, l2_reg=1e-2, epoch = 200, batch_size = 200):
    if data == None:
        data, val_data = generate_mpp()

    optimizer = Adam(model.parameters(), lr=lr, weight_decay=l2_reg)

    for epoch_number in range(epoch):
        train(model, epoch_number, data, optimizer, batch_size, val_data)
    return model

In [9]:
def main(model, data, val_data):
    model = model().to(device)
#     data, _ = generate_mpp(type='hawkes', num_sample=1000)
#     val_data, _ = generate_mpp(type='hawkes', num_sample = 200)
    print("Times: Data Shape: {}, Val Data Shape: {}".format(data['t'].shape, val_data['t'].shape))
    print("Markers: Data Shape: {}, Val Data Shape: {}".format(data['x'].shape, val_data['x'].shape))
    trainer(model, data=data, val_data=val_data)

## Cache data

In [10]:
# data, _ = generate_mpp(type='hawkes', num_sample=1000)
# val_data, _ = generate_mpp(type='hawkes', num_sample = 200)
data, _ = generate_mpp(type='autoregressive', time_step=50, num_sample=1000, num_clusters=10, m=8)
data, val_data = test_val_split(data, val_ratio=.2)

# Experimental results

## RMTPP

In [11]:
main(model=RMTPP, data=data, val_data=val_data)

Times: Data Shape: torch.Size([50, 8000, 2]), Val Data Shape: torch.Size([50, 2000, 2])
Markers: Data Shape: torch.Size([50, 8000, 20]), Val Data Shape: torch.Size([50, 2000, 20])
Epoch: 0, NLL Loss: 1534.8873046875, Val Loss: 62688.44921875, Time took: 0.5179219245910645
Train loss Meta Info:  [1470.94384375   63.94345837]
Val Loss Meta Info:  [57992.555, 4695.8971875]

Epoch: 1, NLL Loss: 1567.135203125, Val Loss: 60602.47265625, Time took: 0.4047243595123291
Train loss Meta Info:  [1449.47568359  117.65953296]
Val Loss Meta Info:  [58507.18, 2095.2965625]

Epoch: 2, NLL Loss: 1513.94424609375, Val Loss: 61989.109375, Time took: 0.40695929527282715
Train loss Meta Info:  [1462.4589375    51.48530933]
Val Loss Meta Info:  [57767.58, 4221.53]

Epoch: 3, NLL Loss: 1536.696671875, Val Loss: 59942.40234375, Time took: 0.4000675678253174
Train loss Meta Info:  [1444.10030469   92.59635291]
Val Loss Meta Info:  [57271.87, 2670.53625]

Epoch: 4, NLL Loss: 1484.109796875, Val Loss: 60032.1953

Epoch: 43, NLL Loss: 1464.01020703125, Val Loss: 59310.0546875, Time took: 0.39210081100463867
Train loss Meta Info:  [1418.81104688   45.19917462]
Val Loss Meta Info:  [56758.26, 2551.79453125]

Epoch: 44, NLL Loss: 1463.31809375, Val Loss: 59423.22265625, Time took: 0.396686315536499
Train loss Meta Info:  [1418.78383203   44.53426324]
Val Loss Meta Info:  [56758.415, 2664.81125]

Epoch: 45, NLL Loss: 1464.9036796875, Val Loss: 59143.26953125, Time took: 0.3937511444091797
Train loss Meta Info:  [1418.76585156   46.13782446]
Val Loss Meta Info:  [56759.995, 2383.2771875]

Epoch: 46, NLL Loss: 1463.4636875, Val Loss: 59021.4453125, Time took: 0.3920786380767822
Train loss Meta Info:  [1418.79946484   44.66423236]
Val Loss Meta Info:  [56759.27, 2262.17359375]

Epoch: 47, NLL Loss: 1463.66427734375, Val Loss: 58927.5390625, Time took: 0.4002261161804199
Train loss Meta Info:  [1418.78257031   44.88170203]
Val Loss Meta Info:  [56757.32, 2170.2184375]

Epoch: 48, NLL Loss: 1463.47230468

Epoch: 86, NLL Loss: 1461.310375, Val Loss: 58438.03515625, Time took: 0.40007543563842773
Train loss Meta Info:  [1418.58849609   42.7218855 ]
Val Loss Meta Info:  [56752.24, 1685.79484375]

Epoch: 87, NLL Loss: 1461.28348046875, Val Loss: 58436.671875, Time took: 0.3983619213104248
Train loss Meta Info:  [1418.58635156   42.69713446]
Val Loss Meta Info:  [56752.28, 1684.393125]

Epoch: 88, NLL Loss: 1461.2494296875, Val Loss: 58435.02734375, Time took: 0.3967139720916748
Train loss Meta Info:  [1418.58641406   42.66301318]
Val Loss Meta Info:  [56752.27, 1682.7590625]

Epoch: 89, NLL Loss: 1461.2265078125, Val Loss: 58431.68359375, Time took: 0.3997375965118408
Train loss Meta Info:  [1418.58488672   42.64161725]
Val Loss Meta Info:  [56752.26, 1679.4259375]

Epoch: 90, NLL Loss: 1461.17796484375, Val Loss: 58428.51953125, Time took: 0.39623260498046875
Train loss Meta Info:  [1418.58432812   42.59363403]
Val Loss Meta Info:  [56752.12, 1676.3978125]

Epoch: 91, NLL Loss: 1461.140300

Epoch: 129, NLL Loss: 1460.5615625, Val Loss: 58410.0546875, Time took: 0.4033043384552002
Train loss Meta Info:  [1418.56865625   41.9929093 ]
Val Loss Meta Info:  [56752.14, 1657.9140625]

Epoch: 130, NLL Loss: 1460.70139453125, Val Loss: 58403.59765625, Time took: 0.4032909870147705
Train loss Meta Info:  [1418.55162891   42.14977258]
Val Loss Meta Info:  [56752.97, 1650.63078125]

Epoch: 131, NLL Loss: 1460.5576328125, Val Loss: 58407.015625, Time took: 0.405057430267334
Train loss Meta Info:  [1418.57099219   41.98663635]
Val Loss Meta Info:  [56752.36, 1654.6565625]

Epoch: 132, NLL Loss: 1460.6497578125, Val Loss: 58402.7578125, Time took: 0.39972686767578125
Train loss Meta Info:  [1418.55222266   42.09753125]
Val Loss Meta Info:  [56753.15, 1649.6115625]

Epoch: 133, NLL Loss: 1460.50901953125, Val Loss: 58405.6328125, Time took: 0.40331006050109863
Train loss Meta Info:  [1418.56961719   41.93940228]
Val Loss Meta Info:  [56752.595, 1653.04]

Epoch: 134, NLL Loss: 1460.565402

Epoch: 172, NLL Loss: 1460.18748046875, Val Loss: 58396.890625, Time took: 0.3963608741760254
Train loss Meta Info:  [1418.51701953   41.67045551]
Val Loss Meta Info:  [56753.52, 1643.37125]

Epoch: 173, NLL Loss: 1460.223515625, Val Loss: 58393.203125, Time took: 0.40113067626953125
Train loss Meta Info:  [1418.51868359   41.70482874]
Val Loss Meta Info:  [56753.305, 1639.90125]

Epoch: 174, NLL Loss: 1460.143515625, Val Loss: 58394.35546875, Time took: 0.3968985080718994
Train loss Meta Info:  [1418.51455469   41.62895648]
Val Loss Meta Info:  [56753.42, 1640.93640625]

Epoch: 175, NLL Loss: 1460.1832109375, Val Loss: 58391.859375, Time took: 0.3946223258972168
Train loss Meta Info:  [1418.51785937   41.6653559 ]
Val Loss Meta Info:  [56753.275, 1638.58484375]

Epoch: 176, NLL Loss: 1460.120875, Val Loss: 58392.75, Time took: 0.39504551887512207
Train loss Meta Info:  [1418.51363672   41.60724707]
Val Loss Meta Info:  [56753.43, 1639.321875]

Epoch: 177, NLL Loss: 1460.13308984375, V

## HRMTPP

In [12]:
main(model=HRMTPP, data=data, val_data=val_data)

Times: Data Shape: torch.Size([50, 8000, 2]), Val Data Shape: torch.Size([50, 2000, 2])
Markers: Data Shape: torch.Size([50, 8000, 20]), Val Data Shape: torch.Size([50, 2000, 20])
Epoch: 0, NLL Loss: 1601.81908203125, Val Loss: 65660.7734375, Time took: 0.7150976657867432
Train loss Meta Info:  [1.50098243e+03 1.00836648e+02 1.12166088e+00]
Val Loss Meta Info:  [59089.05, 5400.19125, 117.153671875]

Epoch: 1, NLL Loss: 1613.85015625, Val Loss: 64006.0390625, Time took: 0.7133045196533203
Train loss Meta Info:  [1477.14755469  136.67326685    2.93572713]
Val Loss Meta Info:  [58174.6, 4771.4565625, 105.99826171875]

Epoch: 2, NLL Loss: 1574.10191015625, Val Loss: 63836.109375, Time took: 0.7147996425628662
Train loss Meta Info:  [1453.75442969  120.29441895    2.65282704]
Val Loss Meta Info:  [57909.82, 4607.50375, 131.87875]

Epoch: 3, NLL Loss: 1563.68284375, Val Loss: 63628.61328125, Time took: 0.7186882495880127
Train loss Meta Info:  [1447.41490234  116.16899927    3.29799633]
Val 

Epoch: 36, NLL Loss: 1483.97135546875, Val Loss: 59359.9921875, Time took: 0.7569100856781006
Train loss Meta Info:  [1.41903935e+03 6.48731423e+01 1.63507867e-01]
Val Loss Meta Info:  [56768.84, 2532.1459375, 5.901016845703125]

Epoch: 37, NLL Loss: 1483.5619921875, Val Loss: 59339.078125, Time took: 0.7597320079803467
Train loss Meta Info:  [1.41904129e+03 6.44667440e+01 1.45860201e-01]
Val Loss Meta Info:  [56765.74, 2519.2709375, 5.407215576171875]

Epoch: 38, NLL Loss: 1483.1693828125, Val Loss: 59321.5625, Time took: 0.7438199520111084
Train loss Meta Info:  [1.41898652e+03 6.41320876e+01 1.33566773e-01]
Val Loss Meta Info:  [56766.64, 2504.28, 5.064647216796875]

Epoch: 39, NLL Loss: 1482.76603125, Val Loss: 59303.96875, Time took: 0.7451727390289307
Train loss Meta Info:  [1.41896883e+03 6.37483879e+01 1.25161936e-01]
Val Loss Meta Info:  [56766.505, 2489.6221875, 4.784716491699219]

Epoch: 40, NLL Loss: 1482.3961328125, Val Loss: 59287.19921875, Time took: 0.7389070987701416
T

Epoch: 72, NLL Loss: 1472.2817890625, Val Loss: 58857.9296875, Time took: 0.7538714408874512
Train loss Meta Info:  [1.41868885e+03 5.35823617e+01 1.46895776e-02]
Val Loss Meta Info:  [56755.83, 2095.5475, 0.6551890563964844]

Epoch: 73, NLL Loss: 1472.161171875, Val Loss: 58847.69921875, Time took: 0.7594311237335205
Train loss Meta Info:  [1.41869821e+03 5.34512360e+01 1.60661631e-02]
Val Loss Meta Info:  [56755.175, 2085.93875, 0.6585354614257812]

Epoch: 74, NLL Loss: 1471.910546875, Val Loss: 58835.9453125, Time took: 0.7491292953491211
Train loss Meta Info:  [1.41868139e+03 5.32170966e+01 1.62826793e-02]
Val Loss Meta Info:  [56755.17, 2074.06046875, 0.6716293334960938]

Epoch: 75, NLL Loss: 1471.61283984375, Val Loss: 58821.7421875, Time took: 0.7422051429748535
Train loss Meta Info:  [1.41868419e+03 5.29161710e+01 1.66402432e-02]
Val Loss Meta Info:  [56756.175, 2059.5253125, 0.6045537567138672]

Epoch: 76, NLL Loss: 1471.24565234375, Val Loss: 58812.2421875, Time took: 0.73867

Epoch: 108, NLL Loss: 1464.798515625, Val Loss: 58561.1796875, Time took: 0.7426764965057373
Train loss Meta Info:  [1.41862432e+03 4.61681857e+01 5.57211579e-03]
Val Loss Meta Info:  [56753.95, 1803.32, 0.39119041442871094]

Epoch: 109, NLL Loss: 1464.74552734375, Val Loss: 58552.5234375, Time took: 0.7535536289215088
Train loss Meta Info:  [1.41864985e+03 4.60849615e+01 9.82255645e-03]
Val Loss Meta Info:  [56753.285, 1796.3440625, 0.2893510437011719]

Epoch: 110, NLL Loss: 1464.525140625, Val Loss: 58547.203125, Time took: 0.7432377338409424
Train loss Meta Info:  [1.41862531e+03 4.58917554e+01 7.33092272e-03]
Val Loss Meta Info:  [56753.51, 1791.9084375, 0.17855804443359374]

Epoch: 111, NLL Loss: 1464.41723828125, Val Loss: 58541.796875, Time took: 0.745527982711792
Train loss Meta Info:  [1.41863562e+03 4.57765696e+01 4.55262494e-03]
Val Loss Meta Info:  [56753.78, 1786.2309375, 0.17918380737304687]

Epoch: 112, NLL Loss: 1464.2936171875, Val Loss: 58534.76953125, Time took: 0.75

Epoch: 144, NLL Loss: 1461.79541796875, Val Loss: 58438.93359375, Time took: 0.743781566619873
Train loss Meta Info:  [1.41861286e+03 4.31745219e+01 5.56858321e-03]
Val Loss Meta Info:  [56752.645, 1684.3475, 0.19470390319824218]

Epoch: 145, NLL Loss: 1461.68980078125, Val Loss: 58438.25, Time took: 0.7449524402618408
Train loss Meta Info:  [1.41860843e+03 4.30744684e+01 4.75681464e-03]
Val Loss Meta Info:  [56754.11, 1682.0065625, 0.2132670211791992]

Epoch: 146, NLL Loss: 1461.67071484375, Val Loss: 58437.40234375, Time took: 0.7397582530975342
Train loss Meta Info:  [1.41864138e+03 4.30217408e+01 5.20458040e-03]
Val Loss Meta Info:  [56754.19, 1681.0175, 0.21931499481201172]

Epoch: 147, NLL Loss: 1461.61662890625, Val Loss: 58432.29296875, Time took: 0.7387700080871582
Train loss Meta Info:  [1.41863033e+03 4.29783585e+01 5.40261221e-03]
Val Loss Meta Info:  [56752.85, 1677.411875, 0.20341194152832032]

Epoch: 148, NLL Loss: 1461.5138046875, Val Loss: 58431.375, Time took: 0.74262

Epoch: 180, NLL Loss: 1460.8832265625, Val Loss: 58409.515625, Time took: 0.7450413703918457
Train loss Meta Info:  [1.41859200e+03 4.22856729e+01 3.07962089e-03]
Val Loss Meta Info:  [56753.69, 1654.5390625, 0.12841489791870117]

Epoch: 181, NLL Loss: 1460.8628359375, Val Loss: 58411.5625, Time took: 0.7387077808380127
Train loss Meta Info:  [1.41861630e+03 4.22409369e+01 3.09599959e-03]
Val Loss Meta Info:  [56752.47, 1657.52953125, 0.15657829284667968]

Epoch: 182, NLL Loss: 1460.87019140625, Val Loss: 58408.25390625, Time took: 0.738818883895874
Train loss Meta Info:  [1.41858611e+03 4.22769916e+01 3.89910845e-03]
Val Loss Meta Info:  [56754.35, 1652.65953125, 0.12470794677734375]

Epoch: 183, NLL Loss: 1460.81618359375, Val Loss: 58405.25390625, Time took: 0.742990255355835
Train loss Meta Info:  [1.41863239e+03 4.21782543e+01 3.03084622e-03]
Val Loss Meta Info:  [56753.19, 1650.906875, 0.11609197616577148]

Epoch: 184, NLL Loss: 1460.72501171875, Val Loss: 58410.29296875, Time to

## rmtpp

In [13]:
main(model=rmtpp, data=data, val_data=val_data)

Times: Data Shape: torch.Size([50, 8000, 2]), Val Data Shape: torch.Size([50, 2000, 2])
Markers: Data Shape: torch.Size([50, 8000, 20]), Val Data Shape: torch.Size([50, 2000, 20])
Epoch: 0, NLL Loss: 1519.8495, Val Loss: 57723.35546875, Time took: 0.44893574714660645
Train loss Meta Info:  [1444.00798437   75.84150208]
Val Loss Meta Info:  [57307.535, 415.824921875]

Epoch: 1, NLL Loss: 1442.43723046875, Val Loss: 59497.359375, Time took: 0.42549681663513184
Train loss Meta Info:  [1432.83265625    9.6045724 ]
Val Loss Meta Info:  [57215.04, 2282.32]

Epoch: 2, NLL Loss: 1487.3058203125, Val Loss: 59576.58203125, Time took: 0.42228198051452637
Train loss Meta Info:  [1430.13851953   57.16730798]
Val Loss Meta Info:  [57271.32, 2305.2621875]

Epoch: 3, NLL Loss: 1488.50864453125, Val Loss: 59348.046875, Time took: 0.4171335697174072
Train loss Meta Info:  [1431.61246484   56.89618091]
Val Loss Meta Info:  [56879.19, 2468.86203125]

Epoch: 4, NLL Loss: 1480.43335546875, Val Loss: 59564.2

Epoch: 43, NLL Loss: 1462.82735546875, Val Loss: 134957.5625, Time took: 0.41729021072387695
Train loss Meta Info:  [1418.78383594   44.04353467]
Val Loss Meta Info:  [56759.12, 78198.45]

Epoch: 44, NLL Loss: 1463.2329609375, Val Loss: 135050.203125, Time took: 0.4096224308013916
Train loss Meta Info:  [1418.770875     44.46210297]
Val Loss Meta Info:  [56758.64, 78291.57]

Epoch: 45, NLL Loss: 1462.3774765625, Val Loss: 144228.609375, Time took: 0.41159939765930176
Train loss Meta Info:  [1418.76211719   43.61536249]
Val Loss Meta Info:  [56759.435, 87469.14]

Epoch: 46, NLL Loss: 1463.23995703125, Val Loss: 124219.4296875, Time took: 0.4104886054992676
Train loss Meta Info:  [1418.77313281   44.4668443 ]
Val Loss Meta Info:  [56758.155, 67461.28]

Epoch: 47, NLL Loss: 1462.88539453125, Val Loss: 125056.6953125, Time took: 0.41052818298339844
Train loss Meta Info:  [1418.75455469   44.13084631]
Val Loss Meta Info:  [56759.415, 68297.29]

Epoch: 48, NLL Loss: 1462.38334765625, Val Los

Epoch: 87, NLL Loss: 1461.8253203125, Val Loss: 97263.9140625, Time took: 0.4161362648010254
Train loss Meta Info:  [1418.61060547   43.21470374]
Val Loss Meta Info:  [56755.32, 40508.6]

Epoch: 88, NLL Loss: 1461.902234375, Val Loss: 110024.8984375, Time took: 0.41614675521850586
Train loss Meta Info:  [1418.61696484   43.28526318]
Val Loss Meta Info:  [56754.54, 53270.365]

Epoch: 89, NLL Loss: 1461.44580859375, Val Loss: 106974.078125, Time took: 0.4220263957977295
Train loss Meta Info:  [1418.60521484   42.84059485]
Val Loss Meta Info:  [56754.555, 50219.525]

Epoch: 90, NLL Loss: 1461.7337265625, Val Loss: 101962.234375, Time took: 0.4188528060913086
Train loss Meta Info:  [1418.61328125   43.12045563]
Val Loss Meta Info:  [56754.65, 45207.595]

Epoch: 91, NLL Loss: 1461.4495625, Val Loss: 98897.9296875, Time took: 0.42453598976135254
Train loss Meta Info:  [1418.61021094   42.83934485]
Val Loss Meta Info:  [56754.57, 42143.365]

Epoch: 92, NLL Loss: 1461.61555859375, Val Loss: 10

Epoch: 130, NLL Loss: 1460.6780625, Val Loss: 93300.9140625, Time took: 0.42639994621276855
Train loss Meta Info:  [1418.51830469   42.15975555]
Val Loss Meta Info:  [56754.23, 36546.685]

Epoch: 131, NLL Loss: 1461.11359765625, Val Loss: 95401.4609375, Time took: 0.41133594512939453
Train loss Meta Info:  [1418.51592578   42.59768646]
Val Loss Meta Info:  [56754.4, 38647.065]

Epoch: 132, NLL Loss: 1460.64925, Val Loss: 99029.0390625, Time took: 0.41181230545043945
Train loss Meta Info:  [1418.51773047   42.13153912]
Val Loss Meta Info:  [56754.72, 42274.31]

Epoch: 133, NLL Loss: 1460.97366796875, Val Loss: 97087.125, Time took: 0.41220974922180176
Train loss Meta Info:  [1418.524        42.44965088]
Val Loss Meta Info:  [56754.545, 40332.585]

Epoch: 134, NLL Loss: 1460.47495703125, Val Loss: 87549.6015625, Time took: 0.4136812686920166
Train loss Meta Info:  [1418.52108203   41.95388385]
Val Loss Meta Info:  [56754.425, 30795.17]

Epoch: 135, NLL Loss: 1460.8127421875, Val Loss: 92

Epoch: 173, NLL Loss: 1459.9007578125, Val Loss: 87331.2265625, Time took: 0.42354869842529297
Train loss Meta Info:  [1418.44944141   41.45131122]
Val Loss Meta Info:  [56756.355, 30574.875]

Epoch: 174, NLL Loss: 1460.0692265625, Val Loss: 102931.359375, Time took: 0.4225592613220215
Train loss Meta Info:  [1418.44421875   41.62498981]
Val Loss Meta Info:  [56756.73, 46174.61]

Epoch: 175, NLL Loss: 1460.17372265625, Val Loss: 88596.546875, Time took: 0.42085766792297363
Train loss Meta Info:  [1418.44694922   41.72678265]
Val Loss Meta Info:  [56756.83, 31839.7225]

Epoch: 176, NLL Loss: 1460.03976171875, Val Loss: 92093.109375, Time took: 0.41774749755859375
Train loss Meta Info:  [1418.44430469   41.59547321]
Val Loss Meta Info:  [56756.75, 35336.35]

Epoch: 177, NLL Loss: 1460.0770546875, Val Loss: 91416.609375, Time took: 0.4291188716888428
Train loss Meta Info:  [1418.44056641   41.63649988]
Val Loss Meta Info:  [56756.66, 34659.9475]

Epoch: 178, NLL Loss: 1459.87686328125, Va

## hrmtpp

In [14]:
main(model=hrmtpp, data=data, val_data=val_data)

Times: Data Shape: torch.Size([50, 8000, 2]), Val Data Shape: torch.Size([50, 2000, 2])
Markers: Data Shape: torch.Size([50, 8000, 20]), Val Data Shape: torch.Size([50, 2000, 20])
Epoch: 0, NLL Loss: 1516.4166796875, Val Loss: 62809.91015625, Time took: 0.8191773891448975
Train loss Meta Info:  [1.45660838e+03 5.97120386e+01 9.62728846e-02]
Val Loss Meta Info:  [61924.735, 865.41125, 19.767064208984376]

Epoch: 1, NLL Loss: 1570.9606171875, Val Loss: 63048.6171875, Time took: 0.7823503017425537
Train loss Meta Info:  [1.54753283e+03 2.29328391e+01 4.94971273e-01]
Val Loss Meta Info:  [60706.565, 2333.8765625, 8.181640625]

Epoch: 2, NLL Loss: 1576.44079296875, Val Loss: 61415.35546875, Time took: 0.7858977317810059
Train loss Meta Info:  [1.51797221e+03 5.82640049e+01 2.04570010e-01]
Val Loss Meta Info:  [59194.1, 2190.0465625, 31.204609375]

Epoch: 3, NLL Loss: 1534.813140625, Val Loss: 60538.9140625, Time took: 0.7918968200683594
Train loss Meta Info:  [1.47959564e+03 5.44361400e+01 

Epoch: 36, NLL Loss: 1472.6558984375, Val Loss: 233686.78125, Time took: 0.912043571472168
Train loss Meta Info:  [1.42062735e+03 5.19639509e+01 6.46149240e-02]
Val Loss Meta Info:  [56821.565, 176862.7, 2.512484130859375]

Epoch: 37, NLL Loss: 1469.76530859375, Val Loss: 303803.25, Time took: 0.9261817932128906
Train loss Meta Info:  [1.42019372e+03 4.95087677e+01 6.28020414e-02]
Val Loss Meta Info:  [56832.74, 246968.12, 2.417215118408203]

Epoch: 38, NLL Loss: 1476.67630859375, Val Loss: 233066.015625, Time took: 0.8218085765838623
Train loss Meta Info:  [1.42049880e+03 5.61170577e+01 6.04277351e-02]
Val Loss Meta Info:  [56828.6, 176235.14, 2.3116915893554686]

Epoch: 39, NLL Loss: 1466.91131640625, Val Loss: 177829.15625, Time took: 0.8085103034973145
Train loss Meta Info:  [1.42040465e+03 4.64488681e+01 5.77913826e-02]
Val Loss Meta Info:  [56814.39, 121012.58, 2.2045576477050783]

Epoch: 40, NLL Loss: 1469.13708984375, Val Loss: 174373.59375, Time took: 0.7892203330993652
Train 

Epoch: 73, NLL Loss: 1462.6758125, Val Loss: 151667.25, Time took: 0.7836148738861084
Train loss Meta Info:  [1.41889889e+03 4.37725499e+01 4.42096000e-03]
Val Loss Meta Info:  [56774.05, 94893.01, 0.1654181671142578]

Epoch: 74, NLL Loss: 1462.94342578125, Val Loss: 171977.875, Time took: 0.8017199039459229
Train loss Meta Info:  [1.41897706e+03 4.39622522e+01 4.13953975e-03]
Val Loss Meta Info:  [56768.995, 115208.74, 0.15500106811523437]

Epoch: 75, NLL Loss: 1462.580546875, Val Loss: 164839.5625, Time took: 0.8095505237579346
Train loss Meta Info:  [1.41883827e+03 4.37383844e+01 3.87902986e-03]
Val Loss Meta Info:  [56772.385, 108067.03, 0.14542343139648437]

Epoch: 76, NLL Loss: 1462.69693359375, Val Loss: 139525.171875, Time took: 0.8025224208831787
Train loss Meta Info:  [1.41891356e+03 4.37798141e+01 3.63939515e-03]
Val Loss Meta Info:  [56767.475, 82757.56, 0.13668516159057617]

Epoch: 77, NLL Loss: 1462.35306640625, Val Loss: 135727.609375, Time took: 0.8195743560791016
Train

Epoch: 109, NLL Loss: 1462.52032421875, Val Loss: 113844.515625, Time took: 0.8265259265899658
Train loss Meta Info:  [1.41863086e+03 4.38888154e+01 5.70653412e-04]
Val Loss Meta Info:  [56766.93, 57077.56, 0.02152087688446045]

Epoch: 110, NLL Loss: 1462.68553515625, Val Loss: 98876.3359375, Time took: 0.8267881870269775
Train loss Meta Info:  [1.41869123e+03 4.39938277e+01 5.40515340e-04]
Val Loss Meta Info:  [56763.47, 42112.845, 0.020397939682006837]

Epoch: 111, NLL Loss: 1462.720203125, Val Loss: 108804.890625, Time took: 0.7865138053894043
Train loss Meta Info:  [1.41861566e+03 4.41040575e+01 5.12210951e-04]
Val Loss Meta Info:  [56761.98, 52042.89, 0.019310375452041628]

Epoch: 112, NLL Loss: 1461.52734375, Val Loss: 127624.4375, Time took: 0.7904365062713623
Train loss Meta Info:  [1.41857416e+03 4.29526973e+01 4.85025334e-04]
Val Loss Meta Info:  [56764.8, 70859.625, 0.018296420574188232]

Epoch: 113, NLL Loss: 1464.18470703125, Val Loss: 104005.890625, Time took: 0.803702831

Epoch: 145, NLL Loss: 1461.4047421875, Val Loss: 91369.0, Time took: 0.8151271343231201
Train loss Meta Info:  [1.41851662e+03 4.28880074e+01 9.27737650e-05]
Val Loss Meta Info:  [56764.09, 34604.905, 0.0033290630578994753]

Epoch: 146, NLL Loss: 1461.3312265625, Val Loss: 111276.2578125, Time took: 0.8181271553039551
Train loss Meta Info:  [1.41852264e+03 4.28085059e+01 8.79193179e-05]
Val Loss Meta Info:  [56763.08, 54513.18, 0.0031751957535743713]

Epoch: 147, NLL Loss: 1460.66658203125, Val Loss: 109094.140625, Time took: 0.8416290283203125
Train loss Meta Info:  [1.41847478e+03 4.21917319e+01 8.38501276e-05]
Val Loss Meta Info:  [56764.61, 52329.535, 0.003019447922706604]

Epoch: 148, NLL Loss: 1460.6030703125, Val Loss: 92733.328125, Time took: 0.8080394268035889
Train loss Meta Info:  [1.41850906e+03 4.20939410e+01 7.95103858e-05]
Val Loss Meta Info:  [56762.94, 35970.385, 0.0028577205538749696]

Epoch: 149, NLL Loss: 1460.91134765625, Val Loss: 124015.25, Time took: 0.821834802

Epoch: 181, NLL Loss: 1460.1328828125, Val Loss: 105614.28125, Time took: 0.8182809352874756
Train loss Meta Info:  [1.41835580e+03 4.17770632e+01 2.25267973e-05]
Val Loss Meta Info:  [56764.64, 48849.66, 0.0007879181206226349]

Epoch: 182, NLL Loss: 1460.454734375, Val Loss: 84346.578125, Time took: 0.8206558227539062
Train loss Meta Info:  [1.41836198e+03 4.20927329e+01 2.22260865e-05]
Val Loss Meta Info:  [56764.27, 27582.305, 0.0007685727626085281]

Epoch: 183, NLL Loss: 1460.49633984375, Val Loss: 103949.5546875, Time took: 0.7859094142913818
Train loss Meta Info:  [1.41835362e+03 4.21427414e+01 2.16277951e-05]
Val Loss Meta Info:  [56764.09, 47185.465, 0.0007667206972837448]

Epoch: 184, NLL Loss: 1460.11751953125, Val Loss: 95732.8984375, Time took: 0.787747859954834
Train loss Meta Info:  [1.41834431e+03 4.17732163e+01 2.16699491e-05]
Val Loss Meta Info:  [56764.13, 38968.76, 0.000762847512960434]

Epoch: 185, NLL Loss: 1460.01125, Val Loss: 84505.3203125, Time took: 0.81962084