In [116]:
import numpy as np

import torch
import torch.utils.data as data_utils

from pathlib import Path
from sklearn.model_selection import train_test_split

from data import Sequence,SequenceDataset,get_inter_times,create_seq_data_set
from model import LogNormMix,LogNormalMixtureDistribution
from evaluation import get_prediction_for_all_events, get_prediction_for_last_events

from copy import deepcopy
from torch.distributions import Categorical
from sklearn.metrics import f1_score
import pickle

from util import clamp_preserve_gradients

## Load Datasets

In [55]:
dataset_name = 'data/simulated/hawkes_synthetic_random_2d_20191130-180837.pkl'  # run dpp.data.list_datasets() to see the list of available datasets
dataset = np.load(dataset_name,allow_pickle = True)

## Modify the dataset for IFTPL




In [56]:
sequences = []

# for i  in range(len(dataset['timestamps'])):
for i  in range(400):
    sequence = {'t_start':0,'t_end' :200, 'arrival_times':dataset['timestamps'][i].tolist(),'marks' :dataset['types'][i].tolist()}
    
    sequences.append(sequence)
    
simulated_data = {'sequences':sequences,'num_marks':2}

# np.save('data/simulated/sample_hawkes',simulated_data,allow_pickle = True)

### Training

In [177]:
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)


# Model config
context_size = 64                 # Size of the RNN hidden vector
mark_embedding_size = 32          # Size of the mark embedding (used as RNN input)
num_mix_components = 64           # Number of components for a mixture model
rnn_type = "GRU"                  # What RNN to use as an encoder {"RNN", "GRU", "LSTM"}

# Training config
batch_size = 32        # Number of sequences in a batch
regularization = 1e-5  # L2 regularization parameter
learning_rate = 1e-3   # Learning rate for Adam optimizer
max_epochs = 5      # For how many epochs to train
display_step = 1      # Display training statistics after every display_step
patience = 50          # After how many consecutive epochs without improvement of val loss to stop training

test_dataset = {}
test_dataset['sequences']= [{'t_start': 0,'t_end': 10,'arrival_times': [1,2],'marks':[0,1]},
                            {'t_start': 0,'t_end': 10,'arrival_times': [1,3,5],'marks':[0,1,0]},
                            {'t_start': 0,'t_end': 10,'arrival_times': [1,7,8,9],'marks':[0,1,1,0]},
                            {'t_start': 0,   't_end': 10,'arrival_times': [5,9],'marks':[1,1]}  ]

def get_inter_times(seq: dict):
    """Get inter-event times from a sequence."""
    return np.ediff1d(np.concatenate([[seq["t_start"]], seq["arrival_times"], [seq["t_end"]]]))

sequences = [
    Sequence(
        inter_times=get_inter_times(seq),
        marks=seq.get("marks"),
        t_start=seq.get("t_start"),
        t_end=seq.get("t_end")
    )
    for seq in simulated_data["sequences"]
]

seed = 0
batch_size = 5
dataset = SequenceDataset(sequences=sequences, num_marks=2)

d_train, d_val, d_test = dataset.train_val_test_split(seed=None,shuffle = False)

training_events= d_train.total_num_events

dl_train = d_train.get_dataloader(batch_size=batch_size, shuffle=False)
dl_val = d_val.get_dataloader(batch_size=batch_size, shuffle=False)
dl_test = d_test.get_dataloader(batch_size=batch_size, shuffle=False)

# dataset_name = 'data/stack_overflow.pkl'  # run dpp.data.list_datasets() to see the list of available datasets
# dataset = torch.load(dataset_name)


# def get_inter_times(seq: dict):
#     """Get inter-event times from a sequence."""
#     return np.ediff1d(np.concatenate([[seq["t_start"]], seq["arrival_times"], [seq["t_end"]]]))


# sequences = [
#     Sequence(
#         inter_times=get_inter_times(seq),
#         marks=seq.get("marks"),
#         t_start=seq.get("t_start"),
#         t_end=seq.get("t_end")
#     )
#     for seq in dataset["sequences"]
# ]
# dataset = SequenceDataset(sequences=sequences, num_marks=dataset.get("num_marks", 1))

# d_train, d_val, d_test = dataset.train_val_test_split(seed=seed)

# dl_train = d_train.get_dataloader(batch_size=batch_size, shuffle=True)
# dl_val = d_val.get_dataloader(batch_size=batch_size, shuffle=False)
# dl_test = d_test.get_dataloader(batch_size=batch_size, shuffle=False)





In [178]:

# Define the model
print('Building model...')
mean_log_inter_time, std_log_inter_time = d_train.get_inter_time_statistics()

model =LogNormMix(
    num_marks=d_train.num_marks,
    mean_log_inter_time=mean_log_inter_time,
    std_log_inter_time=std_log_inter_time,
    context_size=context_size,
    mark_embedding_size=mark_embedding_size,
    rnn_type=rnn_type,
    num_mix_components=num_mix_components,
)

opt = torch.optim.Adam(model.parameters(), weight_decay=regularization, lr=learning_rate)
# Traning
print('Starting training...')

def aggregate_loss_over_dataloader(dl):
    total_loss = 0.0
    total_count = 0
    with torch.no_grad():
        for batch in dl:
            total_loss += -model.log_prob(batch).sum()
            total_count += batch.mask.sum().item()
    return total_loss / total_count


impatient = 0
best_loss = np.inf
best_model = deepcopy(model.state_dict())
training_val_losses = []

Building model...
Starting training...


In [179]:
for epoch in range(max_epochs):
    epoch_train_loss = 0
    model.train()
    for batch in dl_train:
        opt.zero_grad()
        # loss = -model.log_prob(batch)
        loss = -model.log_prob(batch).sum()
        loss.backward()
        epoch_train_loss += loss.detach()

        opt.step()

    model.eval()
    with torch.no_grad():
        loss_val = aggregate_loss_over_dataloader(dl_val)
        loss_test = aggregate_loss_over_dataloader(dl_test)

        training_val_losses.append(loss_val)

    if (best_loss - loss_val) < 1e-4:
        impatient += 1
        if loss_val < best_loss:
            best_loss = loss_val
            best_model = deepcopy(model.state_dict())
    else:
        best_loss = loss_val
        best_model = deepcopy(model.state_dict())
        impatient = 0

    if impatient >= patience:
        print(f'Breaking due to early stopping at epoch {epoch}')
        break
    
    
    epoch_train_loss = epoch_train_loss/training_events

    if epoch % display_step == 0:
        print(f"Epoch {epoch:4d}: Training loss = {epoch_train_loss.item():.4f}, loss_val = {loss_val:.4f}")


# Evaluation
model.load_state_dict(best_model)
model.eval()

# All training & testing sequences stacked into a single batch
with torch.no_grad():
    final_loss_train = aggregate_loss_over_dataloader(dl_train)
    final_loss_val = aggregate_loss_over_dataloader(dl_val)
    final_loss_test = aggregate_loss_over_dataloader(dl_test)

print(f'Negative log-likelihood:\n'
      f' - Train: {final_loss_train:.4f}\n'
      f' - Val:   {final_loss_val:.4f}\n'
      f' - Test:  {final_loss_test:.4f}')




Epoch    0: Training loss = 1.8333, loss_val = 1.8100
Epoch    1: Training loss = 1.7969, loss_val = 1.8056
Epoch    2: Training loss = 1.7926, loss_val = 1.8009
Epoch    3: Training loss = 1.7895, loss_val = 1.7984
Epoch    4: Training loss = 1.7871, loss_val = 1.7972
Negative log-likelihood:
 - Train: 1.8
 - Val:   1.8
 - Test:  1.7


### Event Prediction 

In [201]:
def get_prediction_for_all_events(model,dl):
    
    total_num_events = dl.dataset.total_num_events
    all_event_time_predictions = []
    all_event_time_values = []
    all_actual_marks = []
    all_predicted_marks = []
    
    
    for batch in dl:
        
        lengths = batch.mask.sum(-1)-1 ## Minus 1 because they also calculate the survival for the end of time.

        
        features = model.get_features(batch)
        context = model.get_context(features)
        inter_time_dist = model.get_inter_time_dist(context)
        inter_times = batch.inter_times.clamp(1e-10)
        
        ## Arrival Time Prediction
        predicted_times = inter_time_dist.mean        
        all_predicted_times = torch.nn.utils.rnn.pack_padded_sequence(predicted_times.T,lengths,batch_first=False,enforce_sorted=False)[0]
        all_actual_times = torch.nn.utils.rnn.pack_padded_sequence(inter_times.T,lengths,batch_first=False,enforce_sorted=False)[0]
        all_event_time_values.append(all_actual_times[:-1])
        all_event_time_predictions.append(all_predicted_times[:-1])
        


#         ## Mark Prediction
        predicted_marks= torch.log_softmax(model.mark_linear(context), dim=-1).argmax(-1)
        predicted_marks = torch.nn.utils.rnn.pack_padded_sequence(predicted_marks.T,lengths,batch_first=False,enforce_sorted=False)[0]
        actual_marks = torch.nn.utils.rnn.pack_padded_sequence(batch.marks.T,lengths,batch_first=False,enforce_sorted =False)[0]
        all_actual_marks.append(actual_marks)
        all_predicted_marks.append(predicted_marks)
        
#         all_event_accuracy = (predicted_marks ==batch.marks)*batch.mask
#         all_event_accuracies.append(all_event_accuracy)
#         last_event_accuracy = all_event_accuracy[x_index,y_index]
#         last_event_accuracies.append(last_event_accuracy)
        
    
#     last_event_rmse = (torch.cat(last_event_errors,axis = 0)**2).mean().sqrt()
#     all_event_rmse = ((torch.cat(total_errors,-1)**2).sum()/total_num_events).sqrt()
#     all_event_accuracy = torch.cat(all_event_accuracies,-1).sum()/total_num_events
#     last_event_accuracy = torch.cat(last_event_accuracies,0)
# #     last_event_accuracy = last_event_accuracy.sum()/len(last_event_accuracy)
    
#     return last_event_rmse,all_event_rmse,all_event_accuracy,last_event_accuracy
    
    
    
    all_event_time_values = torch.cat(all_event_time_values)
    all_event_time_predictions = torch.cat(all_event_time_predictions)
    all_actual_marks = torch.cat(all_actual_marks)
    all_predicted_marks = torch.cat(all_predicted_marks)
    
    return all_event_time_values,all_event_time_predictions,all_actual_marks,all_predicted_marks

In [202]:
def get_prediction_for_last_events(model,dl):
    
    event_time_predictions = []
    event_time_values = []
    actual_marks = []
    predicted_marks = []
    
    
    for batch in dl:
        
        y_index = batch.mask.sum(-1).long() -1  ## Minus 1 because they also calculate the survival for the end of time.
        features = model.get_features(batch)
        context = model.get_context(features)
        inter_time_dist = model.get_inter_time_dist(context)
        inter_times = batch.inter_times.clamp(1e-10)
        x_index = torch.arange(0,len(inter_times))
        
        ## Arrival Time Prediction
        actual_time = inter_times[x_index,y_index]   
        predicted_time = inter_time_dist.mean[x_index,y_index]     

        ## Mark Prediction
        actual_mark = batch.marks[x_index,y_index]
        predicted_mark= torch.log_softmax(model.mark_linear(context), dim=-1).argmax(-1)[x_index,y_index]
        
        event_time_predictions.append(predicted_time)
        event_time_values.append(actual_time)
        actual_marks.append(actual_mark)
        predicted_marks.append(predicted_mark)
    
    event_time_values = torch.cat(event_time_values)
    event_time_predictions = torch.cat(event_time_predictions)
    actual_marks = torch.cat(actual_marks)
    predicted_marks = torch.cat(predicted_marks)
    
    return event_time_values,event_time_predictions,actual_marks,predicted_marks

In [208]:
batch.marks.device

device(type='cpu')

In [203]:
actual_times,predicted_times,actual_marks,predicted_marks = get_prediction_for_all_events(model,dl_test)

RMSE = (((predicted_times - actual_times)/actual_times)**2).mean().sqrt()
f1= f1_score(predicted_marks.detach().numpy(),actual_marks.detach().numpy())

print(RMSE.item())
print(f1)

218.84422302246094
0.550259965337955


In [204]:
actual_times,predicted_times,actual_marks,predicted_marks = get_prediction_for_last_events(model,dl_test)

RMSE = (((predicted_times - actual_times)/actual_times)**2).mean().sqrt()
f1= f1_score(predicted_marks.detach().numpy(),actual_marks.detach().numpy())

print(RMSE.item())
print(f1)

13.017912864685059
0.4571428571428572


In [217]:
batch.marks.cpu()

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0]])

## Check Results

### Simulated

In [224]:

dataset_name = 'simulated'
num_marks = 2
dataset_path = 'data/' + dataset_name + '/'  # run dpp.data.list_datasets() to see the list of available datasets

with open(dataset_path + 'train.pkl', 'rb') as f:
    train = pickle.load(f)
with open(dataset_path + 'valid.pkl', 'rb') as f:
    valid = pickle.load(f)
with open(dataset_path + 'test.pkl', 'rb') as f:
    test = pickle.load(f)


def create_seq_data_set(dataset, num_marks,device):
    sequences = [
        Sequence(
            inter_times=get_inter_times(seq),
            marks=seq.get("marks"),
            t_start=seq.get("t_start"),
            t_end=seq.get("t_end"),device = device
        )
        for seq in dataset["sequences"]
    ]
    dataset = SequenceDataset(sequences=sequences, num_marks=num_marks)

    return dataset

device ='cpu'
d_train = create_seq_data_set(train, num_marks,device)
d_val = create_seq_data_set(valid, num_marks,device)
d_test = create_seq_data_set(test, num_marks,device)

dl_train = d_train.get_dataloader(batch_size=batch_size, shuffle=False)
dl_val = d_val.get_dataloader(batch_size=batch_size, shuffle=False)
dl_test = d_test.get_dataloader(batch_size=batch_size, shuffle=False)

# Define the model
print('Building model...')
mean_log_inter_time, std_log_inter_time = d_train.get_inter_time_statistics()

model =LogNormMix(
    num_marks=d_train.num_marks,
    mean_log_inter_time=mean_log_inter_time,
    std_log_inter_time=std_log_inter_time,
    context_size=context_size,
    mark_embedding_size=mark_embedding_size,
    rnn_type=rnn_type,
    num_mix_components=num_mix_components,
)

model = model.to(device)


model_dict =torch.load('intensity_free_modelsimulated',map_location=torch.device('cpu'))
model.load_state_dict(model_dict)

Building model...


<All keys matched successfully>

In [226]:
actual_times,predicted_times,actual_marks,predicted_marks = get_prediction_for_last_events(model,dl_test)

RMSE = (((predicted_times - actual_times)/actual_times)**2).mean().sqrt()
f1= f1_score(predicted_marks.detach().numpy(),actual_marks.detach().numpy())

print(RMSE.item())
print(f1)

18.669763565063477
0.588495575221239


### Mimic

In [23]:
batch_size = 65
dataset_name = 'mimic'
num_marks = 75
dataset_path = 'data/' + dataset_name + '/'  # run dpp.data.list_datasets() to see the list of available datasets

with open(dataset_path + 'train.pkl', 'rb') as f:
    train = pickle.load(f)
with open(dataset_path + 'valid.pkl', 'rb') as f:
    valid = pickle.load(f)
with open(dataset_path + 'test.pkl', 'rb') as f:
    test = pickle.load(f)




device ='cpu'
d_train = create_seq_data_set(train, num_marks,device)
d_val = create_seq_data_set(valid, num_marks,device)
d_test = create_seq_data_set(test, num_marks,device)

dl_train = d_train.get_dataloader(batch_size=batch_size, shuffle=False)
dl_val = d_val.get_dataloader(batch_size=batch_size, shuffle=False)
dl_test = d_test.get_dataloader(batch_size=batch_size, shuffle=False)

context_size = 64  # Size of the RNN hidden vector
mark_embedding_size = 32  # Size of the mark embedding (used as RNN input)
num_mix_components = 64  # Number of components for a mixture model
rnn_type = "GRU"  # What RNN to use as an encoder {"RNN", "GRU", "LSTM"}

# Training config
batch_size = 32  # Number of sequences in a batch
regularization = 1e-5  # L2 regularization parameter
learning_rate = 1e-4  # Learning rate for Adam optimizer
max_epochs = 5  # For how many epochs to train
display_step = 5  # Display training statistics after every display_step
patience = 50  # After how many consecutive epochs without improvement of val loss to stop training

# Define the model
print('Building model...')
mean_log_inter_time, std_log_inter_time = d_train.get_inter_time_statistics()

model =LogNormMix(
    num_marks=d_train.num_marks,
    mean_log_inter_time=mean_log_inter_time,
    std_log_inter_time=std_log_inter_time,
    context_size=context_size,
    mark_embedding_size=mark_embedding_size,
    rnn_type=rnn_type,
    num_mix_components=num_mix_components,
)

model = model.to(device)


model_dict =torch.load('intensity_free_modelmimic',map_location=torch.device('cpu'))
model.load_state_dict(model_dict)

Building model...


<All keys matched successfully>

In [24]:
actual_times,predicted_times,actual_marks,predicted_marks = get_prediction_for_last_events(model,dl_test)

RMSE = (((predicted_times - actual_times)/actual_times)**2).mean().sqrt()
f1= f1_score(predicted_marks.detach().numpy(),actual_marks.detach().numpy(),average = 'micro')

print(RMSE.item())
print(f1)

inf
0.8769230769230769


In [60]:
predicted_times

tensor([1.1749e+02, 4.3881e-01, 3.1367e+01, 1.4356e+07, 6.0000e+00, 1.6093e+00,
        7.5512e+00, 9.4867e+00, 6.8709e+00, 6.4088e+02, 5.1162e+07, 1.5415e+00,
        4.3675e-01, 6.8879e-01, 2.8085e+00, 5.0541e-01,        inf, 1.0922e+00,
        5.8352e+00, 1.8536e+00, 7.4858e+00, 6.9754e+00, 4.5153e+03, 6.8216e+00,
        1.4159e+00,        inf, 3.4644e+00, 5.3584e-01,        inf, 5.9234e-01,
        1.7005e+00, 4.4791e-01, 2.8797e+00, 1.6537e+00, 1.5182e+00, 1.5238e+00,
        5.4567e-01, 7.1876e+00, 2.2442e+00, 3.3967e+00, 3.3967e+00, 1.6982e+00,
        4.9693e-01, 2.9801e+00, 3.8397e+00, 7.3661e+00, 1.2535e+00, 4.8983e-01,
        5.7855e-01, 1.0930e+00, 2.0347e+09, 9.0467e+00, 3.6998e-01, 3.8227e+00,
        7.1085e+00, 5.9449e+00, 3.0415e+00, 4.5754e-01, 7.6509e+00, 1.6982e+00,
        5.3584e-01, 7.8676e+00, 1.6659e+00, 5.2245e-01,        inf],
       grad_fn=<CatBackward>)

In [28]:
features = model.get_features(batch)
context = model.get_context(features)
inter_time_dist = model.get_inter_time_dist(context)

In [62]:
batch.inter_times[-1]

tensor([1.0000e-10, 3.0769e-01, 3.4615e-01, 5.3462e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])

In [76]:
batch.inter_times[-1]

tensor([1.0000e-10, 3.0769e-01, 3.4615e-01, 5.3462e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])

In [75]:
batch.marks[-1]

tensor([20, 20, 20,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [73]:
batch.mask[-1].

tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [63]:
inter_time_dist.mean[-1]

tensor([1.3683e+30, 2.6320e+35,        inf,        inf, 7.9683e+01, 2.1732e+01,
        1.0686e+01, 9.4735e+00, 9.6479e+00, 1.0049e+01, 1.0429e+01, 1.0733e+01],
       grad_fn=<SelectBackward>)

In [38]:
batch.inter_times[-1]

tensor([1.0000e-10, 3.0769e-01, 3.4615e-01, 5.3462e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])

In [25]:
for batch in dl_test:
    pass

In [130]:
y_index = batch.mask.sum(
    -1).long() - 1  ## Minus 1 because they also calculate the survival for the end of time.
features = model.get_features(batch)
context = model.get_context(features)
inter_time_dist = model.get_inter_time_dist(context)


raw_params = model.linear(context)  # (batch_size, seq_len, 3 * num_mix_components)
# Slice the tensor to get the parameters of the mixture
locs = raw_params[..., :model.num_mix_components]
log_scales = raw_params[..., model.num_mix_components: (2 * model.num_mix_components)]
log_weights = raw_params[..., (2 * model.num_mix_components):]

log_scales = clamp_preserve_gradients(log_scales, -5.0, 1.0)
log_weights = torch.log_softmax(log_weights, dim=-1)
inter_time_dist= LogNormalMixtureDistribution(
    locs=locs,
    log_scales=log_scales,
    log_weights=log_weights,
    mean_log_inter_time=model.mean_log_inter_time,
    std_log_inter_time=model.std_log_inter_time
)



a = inter_time_dist.std_log_inter_time
b = inter_time_dist.mean_log_inter_time
loc = inter_time_dist.base_dist._component_distribution.loc
variance = inter_time_dist.base_dist._component_distribution.variance
log_weights = inter_time_dist.base_dist._mixture_distribution.logits
(log_weights + a * loc + b + 0.5 * a ** 2 * variance).logsumexp(-1)[-1][2]

tensor(350.4573, grad_fn=<SelectBackward>)

In [131]:
model.std_log_inter_time

tensor(9.7428)

In [None]:
\sum_k w_k \exp(a * \mu_k + b + a^2 * s_k^2 / 2)

In [108]:
(a * loc[-1][2]) + b+ (0.5*variance[-1][2] * a**2)

tensor([ 5.3827e-02,  2.4959e+00,  5.6551e+00,  9.6110e-01,  1.6284e+01,
         1.0101e+00,  3.3274e+00,  4.1098e+01,  7.5747e+00,  1.8718e+00,
        -1.3737e+00,  3.4822e+01,  8.1448e+00,  1.8703e+01, -2.8655e+00,
         5.7551e+02,  6.7150e-01, -1.7918e+00,  4.2643e+00,  6.0229e+00,
         1.8080e+01,  2.7982e+00,  7.7480e+00,  2.0706e+00,  9.0482e-01,
         1.3496e+00,  1.0050e+01,  1.2869e+00,  3.7434e+00,  5.2432e+00,
        -7.3683e-01,  7.0787e-01,  9.2646e+00,  2.1809e-01,  2.8906e+00,
         1.1255e+00, -1.1422e+00, -5.0909e-01,  3.0159e+00,  5.1239e+00,
         6.7002e-01,  1.7750e+02,  2.3232e+00,  2.4128e+00,  1.8722e+00,
         5.1205e+00,  7.2887e+00,  2.4980e-01,  1.3400e+00, -3.3546e+00,
         3.8095e-01,  1.0204e+01, -3.1551e+00,  2.8819e-01,  6.4451e+01,
        -2.5328e+00,  1.1071e+01, -7.6883e-01,  3.5442e+01, -4.6588e-01,
        -2.8729e+00,  1.5414e+00,  3.0847e+02, -2.2779e+00],
       grad_fn=<AddBackward0>)

In [93]:
(log_weights + a * loc + b + 0.5 * a ** 2 * variance)[-1][2]

tensor([-2.7719e+00, -2.4933e+00,  3.2127e-01, -2.6305e+00,  1.0638e+01,
        -3.2886e+00, -1.0176e+00,  3.4976e+01,  2.0127e+00, -2.0520e+00,
        -5.5669e+00,  2.8941e+01,  2.8677e+00,  1.3045e+01, -6.8826e+00,
         5.6830e+02, -3.5138e+00, -5.4087e+00, -9.0726e-01,  1.1996e+00,
         1.2257e+01, -1.5404e+00,  2.2187e+00, -2.3010e+00, -2.3071e+00,
        -1.6139e+00,  4.6739e+00, -1.7721e+00, -1.1501e+00,  3.2883e-01,
        -4.4762e+00, -3.1934e+00,  3.7948e+00, -3.6728e+00, -1.2871e+00,
        -3.8649e+00, -4.4348e+00, -4.7654e+00, -1.4894e+00, -5.6178e-01,
        -3.3523e+00,  1.7069e+02, -2.0465e+00, -2.2022e+00, -2.6822e+00,
         4.8712e-02,  1.9735e+00, -3.6402e+00, -2.7830e+00, -6.9390e+00,
        -4.2677e+00,  4.9924e+00, -6.5031e+00, -3.9428e+00,  5.8077e+01,
        -6.1407e+00,  5.1859e+00, -4.4552e+00,  2.9361e+01, -3.4964e+00,
        -6.2605e+00, -2.7168e+00,  3.0133e+02, -5.2975e+00],
       grad_fn=<SelectBackward>)

In [92]:
a = inter_time_dist.std_log_inter_time
b = inter_time_dist.mean_log_inter_time
loc = inter_time_dist.base_dist._component_distribution.loc
variance = inter_time_dist.base_dist._component_distribution.variance
log_weights = inter_time_dist.base_dist._mixture_distribution.logits


tensor(568.3008, grad_fn=<SelectBackward>)