
Add mean function support for spatiotemporal GPs #21

Merged: 19 commits, merged Oct 14, 2022


@sethaxen (Contributor) commented Sep 28, 2022

This PR addresses two issues:

  • Though children of SpatioTemporalBase take a mean function as argument, they don't seem to use the function at all.
  • The mean functions currently implemented all support only a single temporal dimension and no spatial dimensions.

This PR adds support for using separable spatial and temporal mean functions within the current children of SpatioTemporalBase. It does this by adding an abstract class SpatioTemporalMeanFunction and a subclass SeparableMeanFunction, which is the sum of a spatial mean function (should be a GPflow mean function) and a temporal mean function (should be a markovflow mean function).
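For illustration, the separable structure described above can be sketched in plain Python. This is a hypothetical sketch, not the code in this PR: the `__call__(space, time)` signature is an assumption, and ordinary Python callables stand in for the GPflow spatial mean and the markovflow temporal mean.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence


class SpatioTemporalMeanFunction(ABC):
    """Hypothetical interface: maps paired spatial and temporal inputs to mean values."""

    @abstractmethod
    def __call__(self, space: Sequence[Sequence[float]], time: Sequence[float]) -> List[float]:
        ...


class SeparableMeanFunction(SpatioTemporalMeanFunction):
    """m(s, t) = m_space(s) + m_time(t)."""

    def __init__(
        self,
        spatial_mean: Callable[[Sequence[float]], float],  # stand-in for a GPflow mean function
        temporal_mean: Callable[[float], float],  # stand-in for a markovflow mean function
    ):
        self.spatial_mean = spatial_mean
        self.temporal_mean = temporal_mean

    def __call__(self, space: Sequence[Sequence[float]], time: Sequence[float]) -> List[float]:
        if len(space) != len(time):
            raise ValueError("space and time must contain the same number of points")
        # Evaluate each component on its own inputs and sum them pointwise.
        return [self.spatial_mean(s) + self.temporal_mean(t) for s, t in zip(space, time)]
```

Because the temporal part only ever sees `t`, it can be handed on its own to code that needs a purely temporal mean, which is the motivation given above for restricting to separable means.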

I'm not certain whether separability is a requirement, but since ConditionalPosterior needs a purely temporal mean function, it seemed much simpler to support separable mean functions than all possible spatiotemporal ones. I'm also not convinced I have overridden all the methods of SpatioTemporalBase needed for every feature to use the mean function correctly.

@vincentadam87 does this look like a suitable approach? If so, I will:

  • add a unit test for the new mean function
  • add a notebook demonstrating use of a spatial mean function, which will also be an integration test
  • add an integration test using the mean function

@vincentadam87 (Collaborator) left a comment


I don't think you can distribute the temporal and spatial means across the prediction the way you do (see detailed comment), so there is no real need for the separable vs. non-separable mean abstraction.

Yes to tests! (I'm not sure what scenario makes sense, though. We actually have no spatio-temporal tests at all, which is bad, so any test is welcome.)

Notebook: we already have one; maybe just add a mean to the existing demo?

Review comment on markovflow/models/spatio_temporal_variational.py (outdated, resolved)
@sethaxen sethaxen marked this pull request as draft October 7, 2022 14:18
@sethaxen sethaxen marked this pull request as ready for review October 11, 2022 12:54
@sethaxen (Contributor, Author) commented Oct 11, 2022

@vincentadam87 I have implemented the mean function approach you suggested (adding the mean function at the end of space_time_predict). With this mean function specified, the tests implemented in #23 now pass for SpatioTemporalSparseVariational, but for SpatioTemporalSparseCVI both the ELBO calculation and the mean prediction disagree with GPR (EDIT: varying the learning rate or the number of times update_sites is called has no effect). Are there other places in SpatioTemporalSparseCVI where the mean function needs to be taken into account?
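Adding the mean function at the end of prediction follows a general pattern: run the zero-mean model on the targets minus the prior mean, then add the mean evaluated at the test inputs. The sketch below shows that pattern for exact GP regression (the GPR baseline used for comparison) in pure Python. The helpers `zero_mean_predict` and `predict_with_mean`, and the RBF kernel with fixed hyperparameters, are illustrative assumptions, not markovflow's `space_time_predict`.

```python
import math
from typing import Callable, List, Sequence


def rbf(a: float, b: float, lengthscale: float = 1.0) -> float:
    """Squared-exponential kernel on scalar inputs."""
    return math.exp(-0.5 * ((a - b) / lengthscale) ** 2)


def solve(A: List[List[float]], b: List[float]) -> List[float]:
    """Solve A x = b by Gaussian elimination with partial pivoting (small systems only)."""
    n = len(b)
    M = [row[:] + [rhs] for row, rhs in zip(A, b)]  # augmented matrix [A | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back-substitution
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x


def zero_mean_predict(X: Sequence[float], y: Sequence[float], X_new: Sequence[float],
                      noise: float = 0.1) -> List[float]:
    """Posterior mean of exact GP regression with a zero prior mean."""
    K = [[rbf(a, b) + (noise if i == j else 0.0) for j, b in enumerate(X)]
         for i, a in enumerate(X)]
    alpha = solve(K, list(y))  # alpha = (K + noise * I)^{-1} y
    return [sum(rbf(x_new, x) * w for x, w in zip(X, alpha)) for x_new in X_new]


def predict_with_mean(X: Sequence[float], y: Sequence[float], X_new: Sequence[float],
                      mean_fn: Callable[[float], float], noise: float = 0.1) -> List[float]:
    """Subtract the prior mean from the targets, predict with the zero-mean model, add the mean back."""
    residuals = [yi - mean_fn(xi) for xi, yi in zip(X, y)]
    f_zero = zero_mean_predict(X, residuals, X_new, noise)
    return [f + mean_fn(x) for f, x in zip(f_zero, X_new)]
```

If the data lie exactly on the mean function, the residuals are zero and the prediction equals `mean_fn`; far from the data the prediction reverts to `mean_fn` rather than to zero. Both are easy checks for a model whose mean handling is suspected to be wrong.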

@vincentadam87 (Collaborator)

My hunch is that you need to subtract the mean function somewhere along the way in the site_update calculation.

I'd need to do the math properly, but it is possibly in the input to the 'gradient_transformation_mean_var_to_expectation' function.
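One way to see why the mean must be subtracted in the site update: the sites describe the zero-mean process g = f - m, and the chain rule from (mean, variance) gradients to expectation-parameter gradients depends on the marginal mean itself, not only on the likelihood gradients. The scalar sketch below uses a Gaussian likelihood, where the correct site is known in closed form. All function names here are hypothetical; this is not markovflow's `gradient_transformation_mean_var_to_expectation`.

```python
from typing import Tuple


def gaussian_expected_loglik_grads(y: float, mu: float, noise: float) -> Tuple[float, float]:
    """Gradients of E_q[log N(y | f, noise)] w.r.t. the mean and variance of q(f) = N(mu, var).

    E_q[log N(y | f, noise)] = -0.5 * log(2 * pi * noise) - ((y - mu)**2 + var) / (2 * noise),
    so d/dmu = (y - mu) / noise and d/dvar = -1 / (2 * noise).
    """
    return (y - mu) / noise, -0.5 / noise


def mean_var_grads_to_expectation_grads(mu: float, d_mu: float, d_var: float) -> Tuple[float, float]:
    """Chain rule from (mean, variance) to expectation parameters (mu, mu**2 + var).

    With eta1 = mu and eta2 = mu**2 + var: mu = eta1 and var = eta2 - eta1**2, hence
    dL/deta1 = dL/dmu - 2 * mu * dL/dvar and dL/deta2 = dL/dvar.
    """
    return d_mu - 2.0 * mu * d_var, d_var


def site_update_grads(y: float, mu_f: float, mean_value: float, noise: float,
                      subtract_mean: bool = True) -> Tuple[float, float]:
    """Hypothetical site-update step for a single observation.

    The likelihood gradients are identical for f and g = f - m because m is fixed,
    but the transformation must be given the mean of g, i.e. mu_f - m.
    """
    d_mu, d_var = gaussian_expected_loglik_grads(y, mu_f, noise)
    mu = mu_f - mean_value if subtract_mean else mu_f
    return mean_var_grads_to_expectation_grads(mu, d_mu, d_var)
```

With `subtract_mean=True`, the first returned value equals `(y - m) / noise`, the natural parameter of the Gaussian likelihood written as a function of g. With `subtract_mean=False` it equals `y / noise`, so the site silently ignores the mean function.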

@sethaxen (Contributor, Author)


That did the trick! Now the tests pass. This approach ends up computing the mean function once more than is necessary, but I don't believe this is the computational bottleneck, and doing it differently would require more code restructuring.

@vincentadam87 (Collaborator) left a comment


Looks perfect to me!

@vincentadam87 vincentadam87 merged commit 22968ce into secondmind-labs:develop Oct 14, 2022
@sethaxen sethaxen deleted the spatialmean branch October 14, 2022 12:17