In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import Normal
from torch.distributions import kl_divergence

In [2]:
torch.__version__

'1.0.1'

In [3]:
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

True


# Define RMTPP Module

In [24]:
class RMTPP(nn.Module):
    def __init__(self, marker_type='real', marker_dim=32):
        super().__init__()
        
        self.marker_dim = marker_dim
        
        # Dimensions for embedding inputs
        self.marker_embed_dim = 64
        self.time_embed_dim = 64
        # Networks for embedding inputs
        self.create_embedding_nets()
        
        # This is the layer that encodes the history
        self.hidden_layer_dim = 128
        # Create RNN layer Network
        self.create_rnn()
        
        # Hidden shared layer size (bw mu and var)
        # while generating marker from hidden state
        self.marker_shared_dim = 64
        # Create Network for Marker generation from hidden_seq
        self.create_marker_generation_net()
        self.create_time_likelihood_net()

    ############ UTILITY METHODS #############
        
    def _one_hot_marker(self, marker_seq):
        """
            Input:
            marker_seq: Tensor of shape TxBSx1
        """
        pass


    
    ############ NETWORKS #############    
    
    def create_rnn(self):
        """
            Input:
            marker_embed_dim: dimension of embedded markers
            time_embed_dim: dimension of embedded times,intervals
            hidden_layer_dim: dimension of hidden state of recurrent layer
        """

        ### Vanilla RNN
#         self.rnn = nn.RNN(
#             input_size=self.marker_embed_dim+self.time_embed_dim,
#             hidden_size=self.hidden_layer_dim,
#             nonlinearity='relu'
#         )
        ### GRU
        self.rnn = nn.GRU(
            input_size=self.marker_embed_dim+self.time_embed_dim,
            hidden_size=self.hidden_layer_dim,
        )
    
    def create_embedding_nets(self):
        # marker_dim is passed. timeseries_dim is 2
        self.marker_embedding_net = nn.Linear(self.marker_dim, self.marker_embed_dim)
        
        self.time_embedding_net = nn.Sequential(
            nn.Linear(2, self.time_embed_dim),
            nn.ReLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim)
        )
        
    def create_marker_generation_net(self):
        """
            Generate network to create marker sufficient statistics using
            rnn's hidden layer
        """
        self.marker_gen_hidden = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.marker_shared_dim),
            nn.ReLU()
        )
        self.generated_marker_mu = nn.Linear(self.marker_shared_dim, self.marker_dim)
        self.generated_marker_var = nn.Sequential(
            nn.Linear(self.marker_shared_dim, self.marker_dim),
            nn.Softplus()
        )
    
    def create_time_likelihood_net(self):
        self.h_influence = nn.Linear(self.hidden_layer_dim, 1, bias=False)
        self.base_intensity = nn.Parameter(torch.zeros(1,1,1))
        self.wt = nn.Parameter(torch.ones(1,1,1))
    
    
    ############ METHODS #############
    
    def _embed_data(self, marker_seq, time_seq):
        """
            Input:
            marker_seq: Tensor of shape TxBSx marker_dim
            time_seq: Tensor of shape TxBSx 2
            Output:
            marker_seq_emb: Tensor of shape T x BS x marker_embed_dim
            time_seq_emb: Tensor of shape T x BS x time_embed_dim
        """
        marker_seq_emb = self.marker_embedding_net(marker_seq)
        time_seq_emb = self.time_embedding_net(time_seq)
        return marker_seq_emb, time_seq_emb
    
    
    
    def marker_log_likelihood(self, h_seq, marker_seq):
        """
        Use the h_seq to generate the distribution for the markers,
        and use that distribution to compute the log likelihood of the marker_seq
            Input:  
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim (if real)
                    marker_seq : Tensor of shape T x BS x marker_dim
            Output:
                    log_likelihood_marker_seq : T x BS x marker_dim
        """
        marker_gen_shared = self.marker_gen_hidden(h_seq)
        mu, var = self.generated_marker_mu(marker_gen_shared), self.generated_marker_var(marker_gen_shared)
        
        marker_gen_dist = Normal(mu, var.sqrt())
        log_likelihood_marker_seq = marker_gen_dist.log_prob(marker_seq)
        
        return log_likelihood_marker_seq
        
        
    def time_log_likelihood(self, h_seq, time_seq):
        """
        Use the h_seq to compute the log likelihood of the time_seq
        using the formula in the paper
            Input:  
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim (if real)
                    time_seq : Tensor of shape T x BS x 2 . Last dimension is [times, intervals]
            Output:
                    log_likelihood_time_seq : T x BS x 1
        """
#         import pdb; pdb.set_trace()
        past_influence = self.h_influence(h_seq)
        current_influence = self.wt * time_seq[:,:,1:2]
        base_intensity = self.base_intensity
        
        term1 = past_influence + current_influence + base_intensity
        term2 = past_influence + base_intensity
        
        # After factorizing the formula in the paper
        log_likelihood_time_seq = term1 + (term2.exp() - term1.exp())/self.wt
        
        return log_likelihood_time_seq
    
    #############################################
    
    def forward(self, marker_seq, time_seq, **kwargs):
        # Transform markers and timesteps into the embedding spaces
        marker_seq_emb, time_seq_emb = self._embed_data(marker_seq, time_seq)
        T,BS,_ = marker_seq_emb.shape
        
        # Run RNN over the concatenated sequence [marker_seq_emb, time_seq_emb]
        time_marker_combined = torch.cat([marker_seq_emb, time_seq_emb], dim=-1)
        h_0 = torch.zeros(1, BS, self.hidden_layer_dim).to(device)
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)
        hidden_combined = torch.cat([h_0, hidden_seq], dim=0)
        
        # compute the marker and time log likelihoods
        # h_0 is used to generate marker_1 and so on...
        marker_ll = self.marker_log_likelihood(hidden_combined[:-1], marker_seq)
        # h_0 is used to generate timestamp_1 and so on...
        time_ll = self.time_log_likelihood(hidden_combined[:-1], time_seq)
        
        likelihood_loss = marker_ll.sum() + time_ll.sum()
        NLL = -likelihood_loss
        
        # NLL is used for optimization,
        # the individual LL values are used for logging
        return NLL, [-marker_ll.sum().item(), -time_ll.sum().item() ]
        

In [25]:
class HRMTPP(RMTPP):
    def __init__(self, latent_dim=20, **kwargs):
        self.latent_dim = latent_dim
        super().__init__(**kwargs)
        

        
        self.create_inference_net()
        self.create_marker_generation_net()
        self.create_time_likelihood_net()
    
    ## Utility Methods ##
    def _reparameterize(self, mu, var):
        epsilon = torch.randn_like(mu)
        return mu + epsilon*var.sqrt()
    
    def create_inference_net(self):
        self.inference_rnn = nn.GRU(
            input_size = self.marker_embed_dim+self.time_embed_dim,
            hidden_size = self.hidden_layer_dim
        )
        
        self.inference_intermediate_net = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.hidden_layer_dim),
            nn.ReLU(),
        )
        self.posterior_mean_net = nn.Linear(self.hidden_layer_dim, self.latent_dim)
        self.posterior_var_net = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.latent_dim),
            nn.Softplus()
        )
        
    def create_marker_generation_net(self):
        """
            Generate network to create marker sufficient statistics using
            rnn's hidden layer and latent variable
        """
        self.marker_gen_hidden = nn.Sequential(
            nn.Linear(self.hidden_layer_dim+self.latent_dim, self.marker_shared_dim),
            nn.ReLU()
        )
        self.generated_marker_mu = nn.Linear(self.marker_shared_dim, self.marker_dim)
        self.generated_marker_var = nn.Sequential(
            nn.Linear(self.marker_shared_dim, self.marker_dim),
            nn.Softplus()
        )
    
    def create_time_likelihood_net(self):
        self.h_influence = nn.Linear(self.hidden_layer_dim+self.latent_dim, 1, bias=False)
        self.base_intensity = nn.Parameter(torch.zeros(1,1,1))
        self.wt = nn.Parameter(torch.randn(1,1,1))
    
        
    def _inference(self, hidden_seq):
        """
        Use the hidden_seq to compute the posterior
        q(z | x) = NN(RNN(x(1...T))) = NN(h_T)
        Also, compute a sampled value and return
            Input:  
                    hidden_seq   : Tensor of shape (T+1) x BS x hidden_layer_dim (first timestep is h_0)
            Output:
                    z_sampled: Tensor of shape 1 x BS x latent_dim
                    z_mean: Tensor of shape 1 x BS x latent_dim
                    z_var: Tensor of shape 1 x BS x latent_dim
        """
        
        intermediate_layer = self.inference_intermediate_net(hidden_seq[-1:])
        z_mean = self.posterior_mean_net(intermediate_layer)
        z_var = self.posterior_var_net(intermediate_layer)
        z_sampled = self._reparameterize(z_mean, z_var)
        
        return z_sampled, z_mean, z_var
    
    def hz_combined(self, hidden_seq, z):
        """
        Concatenate the z to the hidden_seq along the last dimension
            Input:  
                    hidden_seq   : Tensor of shape (T+1) x BS x hidden_layer_dim (first timestep is h_0)
                    z: Tensor of shape 1 x BS x latent_dim
            Output:
                    hz = Tensor of shape (T+1) x BS x (hidden_layer_dim+latent_dim)
        """
        # Extract timelength
        T, _, _ = hidden_seq.shape
        # Expand z on the time dimension, leaving everything else the same
        z_broadcast = z.expand(T, -1, -1)
        
        hz = torch.cat([hidden_seq, z_broadcast], dim=-1)
        
        return hz
    
    #############################################
    
    def forward(self, marker_seq, time_seq, anneal=1.):
        # Transform markers and timesteps into the embedding spaces
        marker_seq_emb, time_seq_emb = self._embed_data(marker_seq, time_seq)
        T,BS,_ = marker_seq_emb.shape
        
        # Run RNN over the concatenated sequence [marker_seq_emb, time_seq_emb]
        time_marker_combined = torch.cat([marker_seq_emb, time_seq_emb], dim=-1)
        h_0 = torch.zeros(1, BS, self.hidden_layer_dim).to(device)
        # Run RNN
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)
        # Append h_0 to h_1 .. h_T
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)
        
        ## Inference
        # Get the sampled value and (mean + var) latent variable
        # using the hidden state sequence
        posterior_sample, posterior_mean, posterior_var = self._inference(hidden_seq)
        posterior_dist = Normal(posterior_mean, posterior_var.sqrt())

        # Prior is just a Normal(0,1) dist
        prior_dist = Normal(0,1)

        ## Generative Part
        
        # Use the embedded markers and times to create another set of 
        # hidden vectors. Can reuse the h_0 and time_marker combined computed above

        # Use an RNN to summarize the x1 ... xT sequence
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)
        
        # Combine hidden_seq and z to form the input for the generative part
        hz_combined = self.hz_combined(hidden_seq, posterior_sample)
        
        # compute the marker and time log likelihoods
        # z,h_0 is used to generate marker_1 and so on...
        marker_ll = self.marker_log_likelihood(hz_combined[:-1], marker_seq)
        # z,h_0 is used to generate timestamp_1 and so on...
        time_ll = self.time_log_likelihood(hz_combined[:-1], time_seq)
        
        likelihood_loss = marker_ll.sum() + time_ll.sum()
        NLL = -likelihood_loss
        
        KL = kl_divergence(posterior_dist, prior_dist).sum()
        
        loss = NLL + anneal*10*KL
        
        # NLL and KL are used for optimization,
        # the individual LL values are used for logging
        return loss, [-marker_ll.sum().item(), -time_ll.sum().item(), KL.item()]
        

# Training

In [16]:
from trainer import train
# from hrmtpp import hrmtpp
# from rmtpp import rmtpp

In [17]:
def trainer(model, data = None, val_data=None, lr= 1e-2, l2_reg=1e-2, epoch = 200, batch_size = 32):
    if data == None:
        data, val_data = generate_mpp()

    optimizer = Adam(model.parameters(), lr=lr, weight_decay=l2_reg)

    for epoch_number in range(epoch):
        train(model, epoch_number, data, optimizer, batch_size, val_data)
    return model

In [18]:
def main(model, data, val_data):
    model = model().to(device)
#     data, _ = generate_mpp(type='hawkes', num_sample=1000)
#     val_data, _ = generate_mpp(type='hawkes', num_sample = 200)
    print("Times: Data Shape: {}, Val Data Shape: {}".format(data['t'].shape, val_data['t'].shape))
    print("Markers: Data Shape: {}, Val Data Shape: {}".format(data['x'].shape, val_data['x'].shape))
    trainer(model, data=data, val_data=val_data)

## Cache data

In [19]:
from utils.synthetic_data import test_val_split, generate_mpp
from utils.mimic_data_tensors import mimic_data_tensors

In [20]:
data = mimic_data_tensors()
data, val_data = test_val_split(data, val_ratio=0.2)

In [21]:
# data, _ = generate_mpp(type='hawkes', num_sample=1000)
# val_data, _ = generate_mpp(type='hawkes', num_sample = 200)
# data, _ = generate_mpp(type='autoregressive', time_step=50, num_sample=100, num_clusters=10, m=8)
# data, val_data = test_val_split(data, val_ratio=.2)

In [22]:
data['x'].shape, data['t'].shape

(torch.Size([37, 1268, 32]), torch.Size([37, 1268, 2]))

# Experimental results

## RMTPP

In [26]:
main(model=RMTPP, data=data, val_data=val_data)

Times: Data Shape: torch.Size([37, 1268, 2]), Val Data Shape: torch.Size([37, 317, 2])
Markers: Data Shape: torch.Size([37, 1268, 32]), Val Data Shape: torch.Size([37, 317, 32])
Epoch: 0, NLL Loss: 4492099781.047318, Val Loss: 21996843008.0, Time took: 0.34491729736328125
Train loss Meta Info:  [ 4.49209978e+09 -8.90796882e-01]
Val Loss Meta Info:  [21996843118.7027, -224.17884290540542]

Epoch: 1, NLL Loss: 2474084121.0347004, Val Loss: 15251885056.0, Time took: 0.29399776458740234
Train loss Meta Info:  [ 2.47408412e+09 -2.61504056e+01]
Val Loss Meta Info:  [15251885083.675676, -677.379856418919]

Epoch: 2, NLL Loss: 1740661559.7223976, Val Loss: 12580389888.0, Time took: 0.2978246212005615
Train loss Meta Info:  [ 1.74066167e+09 -7.92468317e+01]
Val Loss Meta Info:  [12580390635.243244, -1247.790962837838]

Epoch: 3, NLL Loss: 1448588902.5615141, Val Loss: 11147462656.0, Time took: 0.2848541736602783
Train loss Meta Info:  [ 1.44858903e+09 -1.46089554e+02]
Val Loss Meta Info:  [1114

Epoch: 37, NLL Loss: -7003981820580964.0, Val Loss: -1.716853047033856e+17, Time took: 0.22749114036560059
Train loss Meta Info:  [ 2.63011042e+08 -7.00398202e+15]
Val Loss Meta Info:  [2235944627.891892, -1.716853084179519e+17]

Epoch: 38, NLL Loss: -2.089652756094268e+16, Val Loss: -1.7221635163973222e+17, Time took: 0.23256993293762207
Train loss Meta Info:  [ 2.61542663e+08 -2.08965276e+16]
Val Loss Meta Info:  [2242853140.756757, -1.7221634281763725e+17]

Epoch: 39, NLL Loss: -2.0963144861193576e+16, Val Loss: -1.728920874343465e+17, Time took: 0.23165631294250488
Train loss Meta Info:  [ 2.61758426e+08 -2.09631449e+16]
Val Loss Meta Info:  [2232021545.5135136, -1.7289208186249702e+17]

Epoch: 40, NLL Loss: -2.1048747583394296e+16, Val Loss: -1.7394139948436685e+17, Time took: 0.2338259220123291
Train loss Meta Info:  [ 2.60086454e+08 -2.10487476e+16]
Val Loss Meta Info:  [2204198856.6486487, -1.739414022702916e+17]

Epoch: 41, NLL Loss: -2.1148271256565636e+16, Val Loss: -1.75988

Epoch: 73, NLL Loss: -7.460797445012366e+16, Val Loss: -6.375831534473052e+17, Time took: 0.23341774940490723
Train loss Meta Info:  [ 1.30404345e+08 -7.46079745e+16]
Val Loss Meta Info:  [1095323869.4054055, -6.3758313858904e+17]

Epoch: 74, NLL Loss: -7.460467548167898e+16, Val Loss: -6.375823288135844e+17, Time took: 0.2326946258544922
Train loss Meta Info:  [ 1.28825379e+08 -7.46046755e+16]
Val Loss Meta Info:  [1082122018.5945945, -6.375823065261865e+17]

Epoch: 75, NLL Loss: -7.459490387558006e+16, Val Loss: -6.375823288135844e+17, Time took: 0.2324821949005127
Train loss Meta Info:  [ 1.27274022e+08 -7.45949039e+16]
Val Loss Meta Info:  [1069381576.6486486, -6.375823065261865e+17]

Epoch: 76, NLL Loss: -7.459547509268165e+16, Val Loss: -6.375821913746309e+17, Time took: 0.22575950622558594
Train loss Meta Info:  [ 1.25779001e+08 -7.45954751e+16]
Val Loss Meta Info:  [1057030365.4054054, -6.375821876600646e+17]

Epoch: 77, NLL Loss: -7.459452906177794e+16, Val Loss: -6.3758219137

Epoch: 109, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21796965599060059
Train loss Meta Info:  [ 9.24489214e+07 -7.46182197e+16]
Val Loss Meta Info:  [779670555.6756756, -6.375821876600646e+17]

Epoch: 110, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21327424049377441
Train loss Meta Info:  [ 9.17481048e+07 -7.46182197e+16]
Val Loss Meta Info:  [773811781.1891892, -6.375821876600646e+17]

Epoch: 111, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21297311782836914
Train loss Meta Info:  [ 9.10590253e+07 -7.46182197e+16]
Val Loss Meta Info:  [768049594.8108108, -6.375821876600646e+17]

Epoch: 112, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21341633796691895
Train loss Meta Info:  [ 9.03813184e+07 -7.46182197e+16]
Val Loss Meta Info:  [762382169.945946, -6.375821876600646e+17]

Epoch: 113, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821

Epoch: 145, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21096372604370117
Train loss Meta Info:  [ 7.29268959e+07 -7.46182197e+16]
Val Loss Meta Info:  [616131390.2702702, -6.375821876600646e+17]

Epoch: 146, NLL Loss: -7.461821993879717e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21141648292541504
Train loss Meta Info:  [ 7.25125545e+07 -7.46182199e+16]
Val Loss Meta Info:  [612653111.3513514, -6.375821876600646e+17]

Epoch: 147, NLL Loss: -7.461822015557784e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21100783348083496
Train loss Meta Info:  [ 7.21033864e+07 -7.46182202e+16]
Val Loss Meta Info:  [609218449.2972972, -6.375821876600646e+17]

Epoch: 148, NLL Loss: -7.461822080591989e+16, Val Loss: -6.375821913746309e+17, Time took: 0.21153831481933594
Train loss Meta Info:  [ 7.16993711e+07 -7.46182208e+16]
Val Loss Meta Info:  [605826407.7837838, -6.375821876600646e+17]

Epoch: 149, NLL Loss: -7.461822188982331e+16, Val Loss: -6.37582

Epoch: 181, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.2409355640411377
Train loss Meta Info:  [ 6.07118428e+07 -7.46182197e+16]
Val Loss Meta Info:  [513483969.7297297, -6.375821876600646e+17]

Epoch: 182, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.24295878410339355
Train loss Meta Info:  [ 6.04366415e+07 -7.46182197e+16]
Val Loss Meta Info:  [511187497.5135135, -6.375821876600646e+17]

Epoch: 183, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.2458193302154541
Train loss Meta Info:  [ 6.01664368e+07 -7.46182197e+16]
Val Loss Meta Info:  [508883829.6216216, -6.375821876600646e+17]

Epoch: 184, NLL Loss: -7.461821972201648e+16, Val Loss: -6.375821913746309e+17, Time took: 0.24047136306762695
Train loss Meta Info:  [ 5.98953937e+07 -7.46182197e+16]
Val Loss Meta Info:  [506611435.2432432, -6.375821876600646e+17]

Epoch: 185, NLL Loss: -7.461821972201648e+16, Val Loss: -6.3758219

## HRMTPP

In [27]:
main(model=HRMTPP, data=data, val_data=val_data)

Times: Data Shape: torch.Size([37, 1268, 2]), Val Data Shape: torch.Size([37, 317, 2])
Markers: Data Shape: torch.Size([37, 1268, 32]), Val Data Shape: torch.Size([37, 317, 32])
Epoch: 0, NLL Loss: 4872830145.817035, Val Loss: 18511527936.0, Time took: 0.4079751968383789
Train loss Meta Info:  [4.87283015e+09 1.26922887e+00 6.12307640e-01]
Val Loss Meta Info:  [18511527050.37838, -50.07133195206926, 7.438031170819257]

Epoch: 1, NLL Loss: 2090164325.7539432, Val Loss: 9887761408.0, Time took: 0.3835105895996094
Train loss Meta Info:  [ 2.09016433e+09 -5.81910316e+00  8.67580469e-01]
Val Loss Meta Info:  [9887759941.18919, -49.544618348817565, 107.32084037162163]

Epoch: 2, NLL Loss: 1135579965.3753943, Val Loss: 4026192640.0, Time took: 0.39615440368652344
Train loss Meta Info:  [ 1.13557997e+09 -4.61438786e+00  1.25672726e+01]
Val Loss Meta Info:  [4026184676.324324, -181.4596706081081, 777.5174725506756]

Epoch: 3, NLL Loss: 451795458.01892745, Val Loss: 1908115840.0, Time took: 0.38

Epoch: 31, NLL Loss: -4.671873479167338e+18, Val Loss: -3.6828847858822152e+19, Time took: 0.3783395290374756
Train loss Meta Info:  [ 1.39612628e+07 -4.67187348e+18  9.25954356e+03]
Val Loss Meta Info:  [123692751.56756757, -3.682884779938909e+19, 80645.01351351352]

Epoch: 32, NLL Loss: -4.3515548828323026e+18, Val Loss: -3.716851118989797e+19, Time took: 0.4096670150756836
Train loss Meta Info:  [ 1.37244808e+07 -4.35155488e+18  9.43002073e+03]
Val Loss Meta Info:  [121741007.56756757, -3.716850821824492e+19, 82151.24324324324]

Epoch: 33, NLL Loss: -5.050017178279948e+18, Val Loss: -3.913759137970992e+19, Time took: 0.376331090927124
Train loss Meta Info:  [ 1.34971515e+07 -5.05001718e+18  9.60704150e+03]
Val Loss Meta Info:  [119222105.94594595, -3.913758924011972e+19, 83332.80405405405]

Epoch: 34, NLL Loss: -4.2734227714671693e+18, Val Loss: -3.2955470704306815e+19, Time took: 0.3898346424102783
Train loss Meta Info:  [ 1.32321720e+07 -4.27342277e+18  9.74676974e+03]
Val Loss Me

Epoch: 62, NLL Loss: -5.397895963120762e+18, Val Loss: -5.043830152124747e+19, Time took: 0.35905933380126953
Train loss Meta Info:  [ 9.79387322e+06 -5.39789596e+18  1.33066270e+04]
Val Loss Meta Info:  [87163281.2972973, -5.043829985712176e+19, 114791.3918918919]

Epoch: 63, NLL Loss: -5.051538482685387e+18, Val Loss: -5.508151273706645e+19, Time took: 0.3579444885253906
Train loss Meta Info:  [ 9.64862013e+06 -5.05153848e+18  1.34287277e+04]
Val Loss Meta Info:  [85892400.43243243, -5.5081511072940745e+19, 115845.25675675676]

Epoch: 64, NLL Loss: -5.215434748599177e+18, Val Loss: -4.957073407037407e+19, Time took: 0.3403048515319824
Train loss Meta Info:  [ 9.59150596e+06 -5.21543475e+18  1.35514299e+04]
Val Loss Meta Info:  [85692685.83783785, -4.95707331194451e+19, 116900.1081081081]

Epoch: 65, NLL Loss: -5.771758052963102e+18, Val Loss: -4.684281493553624e+19, Time took: 0.3378565311431885
Train loss Meta Info:  [ 9.47954069e+06 -5.77175805e+18  1.36744366e+04]
Val Loss Meta In

Epoch: 93, NLL Loss: -6.725872989315882e+18, Val Loss: -6.28100767257272e+19, Time took: 0.40389180183410645
Train loss Meta Info:  [ 7.64813799e+06 -6.72587299e+18  1.66400908e+04]
Val Loss Meta Info:  [68656003.45945945, -6.281007601253047e+19, 143052.02702702704]

Epoch: 94, NLL Loss: -7.073936311748634e+18, Val Loss: -6.057273329092395e+19, Time took: 0.3958902359008789
Train loss Meta Info:  [ 7.58647401e+06 -7.07393631e+18  1.67339219e+04]
Val Loss Meta Info:  [68330876.54054055, -6.05727303192709e+19, 143856.16216216216]

Epoch: 95, NLL Loss: -6.877540774790296e+18, Val Loss: -6.1137965830484525e+19, Time took: 0.3834645748138428
Train loss Meta Info:  [ 7.51150638e+06 -6.87754077e+18  1.68281278e+04]
Val Loss Meta Info:  [67841024.0, -6.113796630594902e+19, 144663.7027027027]

Epoch: 96, NLL Loss: -7.645281691992802e+18, Val Loss: -5.853234317498096e+19, Time took: 0.3840293884277344
Train loss Meta Info:  [ 7.50225627e+06 -7.64528169e+18  1.69227127e+04]
Val Loss Meta Info:  [

Epoch: 124, NLL Loss: -6.843215332860704e+18, Val Loss: -6.454226053826504e+19, Time took: 0.34308815002441406
Train loss Meta Info:  [ 6.27607477e+06 -6.84321533e+18  1.84527571e+04]
Val Loss Meta Info:  [56130127.567567565, -6.454225780434424e+19, 158137.32432432432]

Epoch: 125, NLL Loss: -7.527424964304692e+18, Val Loss: -6.009210597209748e+19, Time took: 0.34441256523132324
Train loss Meta Info:  [ 6.22448277e+06 -7.52742496e+18  1.85023671e+04]
Val Loss Meta Info:  [56212120.216216214, -6.009210418910565e+19, 158566.05405405405]

Epoch: 126, NLL Loss: -6.389590920100787e+18, Val Loss: -5.961607461187512e+19, Time took: 0.3758084774017334
Train loss Meta Info:  [ 6.20890902e+06 -6.38959092e+18  1.85525044e+04]
Val Loss Meta Info:  [55979741.4054054, -5.961607294774941e+19, 158988.87837837837]

Epoch: 127, NLL Loss: -6.983811792374438e+18, Val Loss: -5.916100873937119e+19, Time took: 0.3736100196838379
Train loss Meta Info:  [ 6.13320439e+06 -6.98381179e+18  1.86019779e+04]
Val Los

Epoch: 155, NLL Loss: -6.276064516638191e+18, Val Loss: -5.0471031783383106e+19, Time took: 0.3631882667541504
Train loss Meta Info:  [ 5.26511097e+06 -6.27606452e+18  1.97138219e+04]
Val Loss Meta Info:  [46667820.972972974, -5.047103083245413e+19, 168885.45945945947]

Epoch: 156, NLL Loss: -5.525419871051063e+18, Val Loss: -4.833004075154512e+19, Time took: 0.3408629894256592
Train loss Meta Info:  [ 5.24307035e+06 -5.52541987e+18  1.97590275e+04]
Val Loss Meta Info:  [46992307.89189189, -4.833004087041125e+19, 169270.47297297296]

Epoch: 157, NLL Loss: -5.532309129962941e+18, Val Loss: -4.758200780679702e+19, Time took: 0.363619327545166
Train loss Meta Info:  [ 5.22704442e+06 -5.53230913e+18  1.98042472e+04]
Val Loss Meta Info:  [46226871.35135135, -4.758200590493907e+19, 169652.82432432432]

Epoch: 158, NLL Loss: -6.108472965567992e+18, Val Loss: -5.700311161674452e+19, Time took: 0.34709692001342773
Train loss Meta Info:  [ 5.16830116e+06 -6.10847297e+18  1.98491925e+04]
Val Loss

Epoch: 186, NLL Loss: -7.996690110331252e+18, Val Loss: -6.1917545962856776e+19, Time took: 0.35157155990600586
Train loss Meta Info:  [ 4.43737343e+06 -7.99669011e+18  2.11410809e+04]
Val Loss Meta Info:  [876889226.3783784, -6.1917545487392285e+19, 181100.17567567568]

Epoch: 187, NLL Loss: -7.143634353833354e+18, Val Loss: -7.021790437647267e+19, Time took: 0.3415226936340332
Train loss Meta Info:  [ 4.42290901e+06 -7.14363435e+18  2.11906504e+04]
Val Loss Meta Info:  [39455439.567567565, -7.02179040198743e+19, 181522.86486486485]

Epoch: 188, NLL Loss: -6.70129573423437e+18, Val Loss: -7.46581853282228e+19, Time took: 0.35215020179748535
Train loss Meta Info:  [ 4.39962093e+06 -6.70129573e+18  2.12401497e+04]
Val Loss Meta Info:  [39577641.51351351, -7.465818699234851e+19, 181946.05405405405]

Epoch: 189, NLL Loss: -7.328231270839902e+18, Val Loss: -4.6988764109007225e+19, Time took: 0.38112974166870117
Train loss Meta Info:  [ 4.36606045e+06 -7.32823127e+18  2.12897361e+04]
Val Lo

## rmtpp

In [13]:
main(model=rmtpp, data=data, val_data=val_data)

Times: Data Shape: torch.Size([50, 8000, 2]), Val Data Shape: torch.Size([50, 2000, 2])
Markers: Data Shape: torch.Size([50, 8000, 20]), Val Data Shape: torch.Size([50, 2000, 20])
Epoch: 0, NLL Loss: 1519.8495, Val Loss: 57723.35546875, Time took: 0.44893574714660645
Train loss Meta Info:  [1444.00798437   75.84150208]
Val Loss Meta Info:  [57307.535, 415.824921875]

Epoch: 1, NLL Loss: 1442.43723046875, Val Loss: 59497.359375, Time took: 0.42549681663513184
Train loss Meta Info:  [1432.83265625    9.6045724 ]
Val Loss Meta Info:  [57215.04, 2282.32]

Epoch: 2, NLL Loss: 1487.3058203125, Val Loss: 59576.58203125, Time took: 0.42228198051452637
Train loss Meta Info:  [1430.13851953   57.16730798]
Val Loss Meta Info:  [57271.32, 2305.2621875]

Epoch: 3, NLL Loss: 1488.50864453125, Val Loss: 59348.046875, Time took: 0.4171335697174072
Train loss Meta Info:  [1431.61246484   56.89618091]
Val Loss Meta Info:  [56879.19, 2468.86203125]

Epoch: 4, NLL Loss: 1480.43335546875, Val Loss: 59564.2

Epoch: 43, NLL Loss: 1462.82735546875, Val Loss: 134957.5625, Time took: 0.41729021072387695
Train loss Meta Info:  [1418.78383594   44.04353467]
Val Loss Meta Info:  [56759.12, 78198.45]

Epoch: 44, NLL Loss: 1463.2329609375, Val Loss: 135050.203125, Time took: 0.4096224308013916
Train loss Meta Info:  [1418.770875     44.46210297]
Val Loss Meta Info:  [56758.64, 78291.57]

Epoch: 45, NLL Loss: 1462.3774765625, Val Loss: 144228.609375, Time took: 0.41159939765930176
Train loss Meta Info:  [1418.76211719   43.61536249]
Val Loss Meta Info:  [56759.435, 87469.14]

Epoch: 46, NLL Loss: 1463.23995703125, Val Loss: 124219.4296875, Time took: 0.4104886054992676
Train loss Meta Info:  [1418.77313281   44.4668443 ]
Val Loss Meta Info:  [56758.155, 67461.28]

Epoch: 47, NLL Loss: 1462.88539453125, Val Loss: 125056.6953125, Time took: 0.41052818298339844
Train loss Meta Info:  [1418.75455469   44.13084631]
Val Loss Meta Info:  [56759.415, 68297.29]

Epoch: 48, NLL Loss: 1462.38334765625, Val Los

Epoch: 87, NLL Loss: 1461.8253203125, Val Loss: 97263.9140625, Time took: 0.4161362648010254
Train loss Meta Info:  [1418.61060547   43.21470374]
Val Loss Meta Info:  [56755.32, 40508.6]

Epoch: 88, NLL Loss: 1461.902234375, Val Loss: 110024.8984375, Time took: 0.41614675521850586
Train loss Meta Info:  [1418.61696484   43.28526318]
Val Loss Meta Info:  [56754.54, 53270.365]

Epoch: 89, NLL Loss: 1461.44580859375, Val Loss: 106974.078125, Time took: 0.4220263957977295
Train loss Meta Info:  [1418.60521484   42.84059485]
Val Loss Meta Info:  [56754.555, 50219.525]

Epoch: 90, NLL Loss: 1461.7337265625, Val Loss: 101962.234375, Time took: 0.4188528060913086
Train loss Meta Info:  [1418.61328125   43.12045563]
Val Loss Meta Info:  [56754.65, 45207.595]

Epoch: 91, NLL Loss: 1461.4495625, Val Loss: 98897.9296875, Time took: 0.42453598976135254
Train loss Meta Info:  [1418.61021094   42.83934485]
Val Loss Meta Info:  [56754.57, 42143.365]

Epoch: 92, NLL Loss: 1461.61555859375, Val Loss: 10

Epoch: 130, NLL Loss: 1460.6780625, Val Loss: 93300.9140625, Time took: 0.42639994621276855
Train loss Meta Info:  [1418.51830469   42.15975555]
Val Loss Meta Info:  [56754.23, 36546.685]

Epoch: 131, NLL Loss: 1461.11359765625, Val Loss: 95401.4609375, Time took: 0.41133594512939453
Train loss Meta Info:  [1418.51592578   42.59768646]
Val Loss Meta Info:  [56754.4, 38647.065]

Epoch: 132, NLL Loss: 1460.64925, Val Loss: 99029.0390625, Time took: 0.41181230545043945
Train loss Meta Info:  [1418.51773047   42.13153912]
Val Loss Meta Info:  [56754.72, 42274.31]

Epoch: 133, NLL Loss: 1460.97366796875, Val Loss: 97087.125, Time took: 0.41220974922180176
Train loss Meta Info:  [1418.524        42.44965088]
Val Loss Meta Info:  [56754.545, 40332.585]

Epoch: 134, NLL Loss: 1460.47495703125, Val Loss: 87549.6015625, Time took: 0.4136812686920166
Train loss Meta Info:  [1418.52108203   41.95388385]
Val Loss Meta Info:  [56754.425, 30795.17]

Epoch: 135, NLL Loss: 1460.8127421875, Val Loss: 92

Epoch: 173, NLL Loss: 1459.9007578125, Val Loss: 87331.2265625, Time took: 0.42354869842529297
Train loss Meta Info:  [1418.44944141   41.45131122]
Val Loss Meta Info:  [56756.355, 30574.875]

Epoch: 174, NLL Loss: 1460.0692265625, Val Loss: 102931.359375, Time took: 0.4225592613220215
Train loss Meta Info:  [1418.44421875   41.62498981]
Val Loss Meta Info:  [56756.73, 46174.61]

Epoch: 175, NLL Loss: 1460.17372265625, Val Loss: 88596.546875, Time took: 0.42085766792297363
Train loss Meta Info:  [1418.44694922   41.72678265]
Val Loss Meta Info:  [56756.83, 31839.7225]

Epoch: 176, NLL Loss: 1460.03976171875, Val Loss: 92093.109375, Time took: 0.41774749755859375
Train loss Meta Info:  [1418.44430469   41.59547321]
Val Loss Meta Info:  [56756.75, 35336.35]

Epoch: 177, NLL Loss: 1460.0770546875, Val Loss: 91416.609375, Time took: 0.4291188716888428
Train loss Meta Info:  [1418.44056641   41.63649988]
Val Loss Meta Info:  [56756.66, 34659.9475]

Epoch: 178, NLL Loss: 1459.87686328125, Va

## hrmtpp

In [14]:
main(model=hrmtpp, data=data, val_data=val_data)

Times: Data Shape: torch.Size([50, 8000, 2]), Val Data Shape: torch.Size([50, 2000, 2])
Markers: Data Shape: torch.Size([50, 8000, 20]), Val Data Shape: torch.Size([50, 2000, 20])
Epoch: 0, NLL Loss: 1516.4166796875, Val Loss: 62809.91015625, Time took: 0.8191773891448975
Train loss Meta Info:  [1.45660838e+03 5.97120386e+01 9.62728846e-02]
Val Loss Meta Info:  [61924.735, 865.41125, 19.767064208984376]

Epoch: 1, NLL Loss: 1570.9606171875, Val Loss: 63048.6171875, Time took: 0.7823503017425537
Train loss Meta Info:  [1.54753283e+03 2.29328391e+01 4.94971273e-01]
Val Loss Meta Info:  [60706.565, 2333.8765625, 8.181640625]

Epoch: 2, NLL Loss: 1576.44079296875, Val Loss: 61415.35546875, Time took: 0.7858977317810059
Train loss Meta Info:  [1.51797221e+03 5.82640049e+01 2.04570010e-01]
Val Loss Meta Info:  [59194.1, 2190.0465625, 31.204609375]

Epoch: 3, NLL Loss: 1534.813140625, Val Loss: 60538.9140625, Time took: 0.7918968200683594
Train loss Meta Info:  [1.47959564e+03 5.44361400e+01 

Epoch: 36, NLL Loss: 1472.6558984375, Val Loss: 233686.78125, Time took: 0.912043571472168
Train loss Meta Info:  [1.42062735e+03 5.19639509e+01 6.46149240e-02]
Val Loss Meta Info:  [56821.565, 176862.7, 2.512484130859375]

Epoch: 37, NLL Loss: 1469.76530859375, Val Loss: 303803.25, Time took: 0.9261817932128906
Train loss Meta Info:  [1.42019372e+03 4.95087677e+01 6.28020414e-02]
Val Loss Meta Info:  [56832.74, 246968.12, 2.417215118408203]

Epoch: 38, NLL Loss: 1476.67630859375, Val Loss: 233066.015625, Time took: 0.8218085765838623
Train loss Meta Info:  [1.42049880e+03 5.61170577e+01 6.04277351e-02]
Val Loss Meta Info:  [56828.6, 176235.14, 2.3116915893554686]

Epoch: 39, NLL Loss: 1466.91131640625, Val Loss: 177829.15625, Time took: 0.8085103034973145
Train loss Meta Info:  [1.42040465e+03 4.64488681e+01 5.77913826e-02]
Val Loss Meta Info:  [56814.39, 121012.58, 2.2045576477050783]

Epoch: 40, NLL Loss: 1469.13708984375, Val Loss: 174373.59375, Time took: 0.7892203330993652
Train 

Epoch: 73, NLL Loss: 1462.6758125, Val Loss: 151667.25, Time took: 0.7836148738861084
Train loss Meta Info:  [1.41889889e+03 4.37725499e+01 4.42096000e-03]
Val Loss Meta Info:  [56774.05, 94893.01, 0.1654181671142578]

Epoch: 74, NLL Loss: 1462.94342578125, Val Loss: 171977.875, Time took: 0.8017199039459229
Train loss Meta Info:  [1.41897706e+03 4.39622522e+01 4.13953975e-03]
Val Loss Meta Info:  [56768.995, 115208.74, 0.15500106811523437]

Epoch: 75, NLL Loss: 1462.580546875, Val Loss: 164839.5625, Time took: 0.8095505237579346
Train loss Meta Info:  [1.41883827e+03 4.37383844e+01 3.87902986e-03]
Val Loss Meta Info:  [56772.385, 108067.03, 0.14542343139648437]

Epoch: 76, NLL Loss: 1462.69693359375, Val Loss: 139525.171875, Time took: 0.8025224208831787
Train loss Meta Info:  [1.41891356e+03 4.37798141e+01 3.63939515e-03]
Val Loss Meta Info:  [56767.475, 82757.56, 0.13668516159057617]

Epoch: 77, NLL Loss: 1462.35306640625, Val Loss: 135727.609375, Time took: 0.8195743560791016
Train

Epoch: 109, NLL Loss: 1462.52032421875, Val Loss: 113844.515625, Time took: 0.8265259265899658
Train loss Meta Info:  [1.41863086e+03 4.38888154e+01 5.70653412e-04]
Val Loss Meta Info:  [56766.93, 57077.56, 0.02152087688446045]

Epoch: 110, NLL Loss: 1462.68553515625, Val Loss: 98876.3359375, Time took: 0.8267881870269775
Train loss Meta Info:  [1.41869123e+03 4.39938277e+01 5.40515340e-04]
Val Loss Meta Info:  [56763.47, 42112.845, 0.020397939682006837]

Epoch: 111, NLL Loss: 1462.720203125, Val Loss: 108804.890625, Time took: 0.7865138053894043
Train loss Meta Info:  [1.41861566e+03 4.41040575e+01 5.12210951e-04]
Val Loss Meta Info:  [56761.98, 52042.89, 0.019310375452041628]

Epoch: 112, NLL Loss: 1461.52734375, Val Loss: 127624.4375, Time took: 0.7904365062713623
Train loss Meta Info:  [1.41857416e+03 4.29526973e+01 4.85025334e-04]
Val Loss Meta Info:  [56764.8, 70859.625, 0.018296420574188232]

Epoch: 113, NLL Loss: 1464.18470703125, Val Loss: 104005.890625, Time took: 0.803702831

Epoch: 145, NLL Loss: 1461.4047421875, Val Loss: 91369.0, Time took: 0.8151271343231201
Train loss Meta Info:  [1.41851662e+03 4.28880074e+01 9.27737650e-05]
Val Loss Meta Info:  [56764.09, 34604.905, 0.0033290630578994753]

Epoch: 146, NLL Loss: 1461.3312265625, Val Loss: 111276.2578125, Time took: 0.8181271553039551
Train loss Meta Info:  [1.41852264e+03 4.28085059e+01 8.79193179e-05]
Val Loss Meta Info:  [56763.08, 54513.18, 0.0031751957535743713]

Epoch: 147, NLL Loss: 1460.66658203125, Val Loss: 109094.140625, Time took: 0.8416290283203125
Train loss Meta Info:  [1.41847478e+03 4.21917319e+01 8.38501276e-05]
Val Loss Meta Info:  [56764.61, 52329.535, 0.003019447922706604]

Epoch: 148, NLL Loss: 1460.6030703125, Val Loss: 92733.328125, Time took: 0.8080394268035889
Train loss Meta Info:  [1.41850906e+03 4.20939410e+01 7.95103858e-05]
Val Loss Meta Info:  [56762.94, 35970.385, 0.0028577205538749696]

Epoch: 149, NLL Loss: 1460.91134765625, Val Loss: 124015.25, Time took: 0.821834802

Epoch: 181, NLL Loss: 1460.1328828125, Val Loss: 105614.28125, Time took: 0.8182809352874756
Train loss Meta Info:  [1.41835580e+03 4.17770632e+01 2.25267973e-05]
Val Loss Meta Info:  [56764.64, 48849.66, 0.0007879181206226349]

Epoch: 182, NLL Loss: 1460.454734375, Val Loss: 84346.578125, Time took: 0.8206558227539062
Train loss Meta Info:  [1.41836198e+03 4.20927329e+01 2.22260865e-05]
Val Loss Meta Info:  [56764.27, 27582.305, 0.0007685727626085281]

Epoch: 183, NLL Loss: 1460.49633984375, Val Loss: 103949.5546875, Time took: 0.7859094142913818
Train loss Meta Info:  [1.41835362e+03 4.21427414e+01 2.16277951e-05]
Val Loss Meta Info:  [56764.09, 47185.465, 0.0007667206972837448]

Epoch: 184, NLL Loss: 1460.11751953125, Val Loss: 95732.8984375, Time took: 0.787747859954834
Train loss Meta Info:  [1.41834431e+03 4.17732163e+01 2.16699491e-05]
Val Loss Meta Info:  [56764.13, 38968.76, 0.000762847512960434]

Epoch: 185, NLL Loss: 1460.01125, Val Loss: 84505.3203125, Time took: 0.81962084