
Learning with Marks #8

Closed
cjchristopher opened this issue Aug 24, 2020 · 9 comments

@cjchristopher

Thanks for the release of your paper and code.
While trying to implement learning with marks in the provided interactive notebook, adapting the remarks in the paper, I'm running into some trouble. Based on Appendix F.2, I assume it's a case of just adding the terms?

model.log_prob in this case returns (time_log_prob, mark_nll, accuracy) - so, adapting the training loop, is it as simple as changing the lines below?:

        log_prob = model.log_prob(input)
        loss = -model.aggregate(log_prob, input.length)

for

        if use_marks:
            log_prob, mark_nll, mark_acc = model.log_prob(input)
        else:
            log_prob = model.log_prob(input)
            mark_nll = 0.0  # no mark term when marks are unused
        loss = -model.aggregate(log_prob + mark_nll, input.length)

As a side problem: when doing the above with my custom dataset (which conforms to the same formatting as the example datasets, i.e. arrival_times and marks), all loss terms are NaN. I'm wondering if you might have some insight as to why this is occurring! When using the reddit dataset with the above modifications, I get non-zero loss terms for both log_prob and mark_nll.

@shchur
Owner

shchur commented Aug 24, 2020

It could be that mark_nll already contains the negative log-likelihood, so you don't need to negate it, as you do with log_prob (I'm not 100% certain though).
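In that case, a minimal sketch of the training-loop change under that assumption (treating mark_nll as a value that is already negated) would be:

        log_prob, mark_nll, mark_acc = model.log_prob(input)
        # Assumption: mark_nll is already a *negative* log-likelihood,
        # so subtract it from log_prob instead of adding it.
        loss = -model.aggregate(log_prob - mark_nll, input.length)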

If this doesn't solve the problem, here are a few follow-up questions:

  1. Do the NaNs appear after several iterations or do you get them from the start?
  2. Do the NaNs appear in the log_prob term or in mark_nll?
  3. Does your data contain duplicate time stamps (i.e. two or more events happening exactly at the same time)? This may lead to numerical issues (a quick check is sketched below).
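Regarding point 3, here is a quick way to check, assuming your dataset is a hypothetical list sequences of dicts with an arrival_times array, like the example datasets:

        import numpy as np

        # Flag any sequence that contains repeated arrival times.
        for i, seq in enumerate(sequences):
            t = np.asarray(seq["arrival_times"])
            if np.any(np.diff(t) == 0):
                print(f"Sequence {i} contains duplicate timestamps")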

@cjchristopher
Author

Thanks for the pointers - the data did indeed have duplicate timestamps - I've cleaned those up now.
My new problem, which I'd also appreciate any insight on, is the log_prob terms turning very negative very quickly (the values in the tensor all around ~-7, say), with the aggregation subsequently also going negative quite quickly. My dataset is perhaps a little denser in time than the reddit one (surely just a scaling issue?), but has many fewer mark classes. I don't suppose you would know what feature of the dataset could cause this?

I should note that my mark_nll terms are still positive.

@shchur
Owner

shchur commented Aug 26, 2020

In general, there is nothing wrong with having very negative values. As you said, this just reflects the different scaling of the data. Simply rescaling the arrival (or inter-event) times should fix this. I guess a good idea is to rescale the times such that the average inter-arrival time is equal to one. It's important, though, that you scale all the sequences by the same factor.
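For example, a minimal sketch of such a rescaling (again assuming a hypothetical list sequences of dicts with an arrival_times array):

        import numpy as np

        # One global scale factor: the average inter-event time over the whole dataset.
        all_deltas = np.concatenate([np.diff(seq["arrival_times"]) for seq in sequences])
        scale = all_deltas.mean()

        # Divide every sequence by the same factor, preserving relative spacing.
        for seq in sequences:
            seq["arrival_times"] = np.asarray(seq["arrival_times"]) / scale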

A simple example to demonstrate the above point: Imagine having a uniform distribution p(x) = Uniform([0, 1]), the log-density of any sample x \in [0, 1] is log p(x) = 0. However, if you simply rescale all the samples y = x * 1000, then p(y) = Uniform([0, 1000]) with log p(y) = -log(1000) for any y \in [0, 1000]. A very similar thing happens with TPP densities, but the scaling is not as straightforward, as it also depends on the number of events.
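You can check this numerically with torch.distributions:

        import torch
        from torch.distributions import Uniform

        x = torch.rand(5)                                # samples from Uniform([0, 1])
        print(Uniform(0.0, 1.0).log_prob(x))             # all 0.0
        print(Uniform(0.0, 1000.0).log_prob(x * 1000))   # all -log(1000) ≈ -6.91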

Do you still get NaNs now after removing the duplicates?

@cjchristopher
Author

cjchristopher commented Aug 27, 2020

Yes, removing the duplicates got rid of the NaNs. Thanks much

I've rescaled such that the average delta is 1 - although I do still see negative loss; some of my data points, even after scaling, are still very close together. The distribution of the inter-event times in my dataset is very bimodal, which maybe poses a problem - I'll reduce the number of mixture components as per #5 (comment).

It's mentioned in the paper that you normalise the loss by subtracting the score of LogNormMix - is this already done in the code you have provided here? I see that model.log_prob ultimately ends up calling self.decoder.log_prob (where decoder = LogNormMix), so I guess it is, or is there something else required?

Lastly, I'm wondering if you had implemented at some point simulation/sampling with marks as well? - with reference to your response #6 (comment), I guess it would need to draw from model.mark_layer. Is there a decoder for marks that would be required?

Thank you very much again for your time!

@shchur
Owner

shchur commented Aug 27, 2020

Subtracting the loss of LogNormMix is done only for the visualization in Figure 3. As I said before, we could arbitrarily shift the loss values of all models by the same amount by rescaling the inter-event times, so the absolute value of the loss for each model is irrelevant; only the differences between the models matter (e.g. if two models have losses 200.1 and 200.5, we could change them to 0.1 and 0.5 by simple rescaling).

In the case of marks, you would need to create a categorical distribution to sample the marks from:

        import torch
        import torch.nn.functional as F

        x = self.mark_layer(h)         # unnormalized mark scores
        x = F.log_softmax(x, dim=-1)   # log-probabilities over the marks
        mark_distribution = torch.distributions.Categorical(logits=x)

@cjchristopher
Author

Ah okay. Thanks for clarifying.

Since it relates specifically to learning and simulation with marks, I'll mention it here - I was able to use the code you provided in the other issue for simulation without marks, but I hit some errors, which I'm not entirely sure how to correct, when trying to sample from a model that has been trained with marks;

Notably:
RuntimeError: input.size(-1) must be equal to input_size. Expected <history_size+1>, got 1
with respect to next_in_time = torch.zeros(1, 1, 1).
Naively changing the last dimension to the expected size then produces:
RuntimeError: Expected hidden[0] size (1, 1, history_size), got (1, history_size)

I don't quite think I'm resolving that correctly - any guidance is appreciated.

@shchur
Owner

shchur commented Aug 27, 2020

Here is the code that should work:

import torch
from torch.distributions import Categorical

# mean_in_train, std_in_train, history_size, general_config and model are
# assumed to be defined as in the interactive notebook.
next_in_time = torch.zeros(1, 1, 1)
next_mark_emb = torch.zeros(1, 1, general_config.mark_embedding_size)
h = torch.zeros(1, 1, history_size)
inter_times = []
marks = []
t_max = 1000
with torch.no_grad():
    while sum(inter_times) < t_max:
        # Feed the previous inter-event time and mark embedding to the RNN
        rnn_input = torch.cat([next_in_time, next_mark_emb], dim=-1)
        _, h = model.rnn.step(rnn_input, h)

        # Sample the next inter-event time from the decoder
        tau = model.decoder.sample(1, h)
        inter_times.append(tau.item())
        # Standardize the log inter-event time, as during training
        next_in_time = ((tau + 1e-8).log() - mean_in_train) / std_in_train

        # Sample the next mark from a categorical distribution over marks
        mark_logits = model.mark_layer(h)
        mark_dist = Categorical(logits=mark_logits)
        next_in_mark = mark_dist.sample()
        marks.append(next_in_mark.item())
        next_mark_emb = model.rnn.mark_embedding(next_in_mark)
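Note that the sampled inter-event times live on the (rescaled) training time axis. If you rescaled your data by a global factor as discussed above, something like the following (with a hypothetical scale variable) maps the samples back to arrival times on the original axis:

        import numpy as np

        arrival_times = np.cumsum(inter_times) * scale  # back to the original time axis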

@cjchristopher
Author

Great! Thanks.

I've managed to modify it slightly so that it works with an LSTM, although that raised one additional question, since the LSTM hidden state output is a tuple.
For _, h = model.rnn.step(rnn_input, h), I'm not sure why we are passing the hidden state, rather than the output encoding, to the decoder. Should it be h, _ = to retrieve the encoding of the history for the decoder? Apologies if there is some fundamental misunderstanding!

@shchur
Owner

shchur commented Aug 29, 2020

It's up to you to decide whether to use the hidden state or the output of the LSTM to obtain the conditional distribution. I don't have a strong intuition here; probably, both versions should work equally well.
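For reference, a minimal sketch of the two options with a plain nn.LSTM (hypothetical sizes, not the exact API of the model in this repo):

        import torch
        import torch.nn as nn

        input_size, history_size = 65, 64
        lstm = nn.LSTM(input_size, history_size, batch_first=True)

        rnn_input = torch.zeros(1, 1, input_size)
        state = (torch.zeros(1, 1, history_size),   # h: hidden state
                 torch.zeros(1, 1, history_size))   # c: cell state

        out, state = lstm(rnn_input, state)   # out has shape (1, 1, history_size)
        context = out                         # option 1: condition on the output
        # context = state[0].transpose(0, 1)  # option 2: condition on the hidden state h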
