Compare exp and softplus transform using synthetic data (cell2location model) #14

Merged
merged 22 commits into from
May 10, 2021

Contributor

@vitkl vitkl commented Apr 12, 2021

This PR adds experiments comparing the effect of the exp and softplus transforms on estimation in the cell2location model (https://github.com/BayraktarLab/cell2location). ELBO stability and accuracy against ground truth (R^2, PR curves) are compared on synthetic data. The cell2location model is ported from pymc3 to both numpyro and pyro:

  1. Pyro, 2021-03-softplus_scales/cell2location_model.py, 2021-03-softplus_scales/cell2location_synthetic_data.ipynb
  2. Numpyro, 2021-03-softplus_scales/cell2location_model_numpyro.py, 2021-03-softplus_scales/cell2location_synthetic_data_numpyro.ipynb

The model is slightly different to the original pymc3 implementation:

  1. Gamma(mu, sigma) had to be re-parameterised to Gamma(alpha, beta) because PyTorch and numpyro do not support Gamma(mu, sigma).
  2. The Negative Binomial distribution is parameterised differently in all 3 cases: pymc3 uses NB(mu, alpha), where alpha is also called theta or total count; pyro uses NB(total count = alpha, logits = log(mu) - log(alpha)); and numpyro uses GammaPoisson(alpha, beta = alpha / mu).
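
The two re-parameterisations above can be checked numerically. The following is a minimal standard-library sketch, not the code from the PR: the function names are hypothetical, and it assumes that Gamma(mu, sigma) means mean and standard deviation, that Gamma(alpha, beta) means shape and rate, and that the logits form uses sigmoid(logits) as the success probability.

```python
import math


def gamma_alpha_beta(mu, sigma):
    """Convert a Gamma given by mean `mu` and sd `sigma` to (shape, rate)."""
    alpha = mu ** 2 / sigma ** 2  # shape
    beta = mu / sigma ** 2        # rate
    return alpha, beta


def _log_nb_coef(k, alpha):
    # log of Gamma(k + alpha) / (Gamma(alpha) * k!)
    return math.lgamma(k + alpha) - math.lgamma(alpha) - math.lgamma(k + 1)


def nb_logpmf_mu_alpha(k, mu, alpha):
    """NB(mu, alpha): mean mu, overdispersion alpha (pymc3-style)."""
    return (_log_nb_coef(k, alpha)
            + alpha * math.log(alpha / (alpha + mu))
            + k * math.log(mu / (alpha + mu)))


def nb_logpmf_logits(k, total_count, logits):
    """NB(total_count, logits) with success probability sigmoid(logits)."""
    log_p = -math.log1p(math.exp(-logits))    # log sigmoid(logits)
    log_1mp = -math.log1p(math.exp(logits))   # log sigmoid(-logits)
    return _log_nb_coef(k, total_count) + total_count * log_1mp + k * log_p


def gamma_poisson_logpmf(k, alpha, beta):
    """GammaPoisson(alpha, beta): Poisson rate ~ Gamma(shape alpha, rate beta)."""
    return (_log_nb_coef(k, alpha)
            + alpha * math.log(beta / (1.0 + beta))
            - k * math.log1p(beta))


mu, alpha = 3.0, 2.0
for k in range(5):
    # All three columns should agree up to floating-point error.
    print(k,
          nb_logpmf_mu_alpha(k, mu, alpha),
          nb_logpmf_logits(k, alpha, math.log(mu) - math.log(alpha)),
          gamma_poisson_logpmf(k, alpha, alpha / mu))
```

Under these assumptions the three log-probabilities coincide, so the differences between the frameworks are purely notational.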

Three conditions are compared:

  1. Exp used for all positive transformations
  2. Softplus used for transforming AutoNormal scales.
  3. Softplus used for all positive transformations
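
For reference, the two transforms being compared can be sketched in plain Python. This is a generic illustration rather than the code used in the notebooks, and the helper names are hypothetical. It shows one standard property that may be relevant to the stability results: the derivative of exp grows without bound, while the derivative of softplus is the sigmoid and never exceeds 1. Whether that is the actual cause of the ELBO behaviour reported here is an assumption, not something this thread establishes.

```python
import math


def exp_transform(x):
    """Map an unconstrained real to a positive value via exp."""
    return math.exp(x)


def softplus(x):
    """Map an unconstrained real to a positive value via log(1 + exp(x)).

    Written as max(x, 0) + log1p(exp(-|x|)) to avoid overflow for large x.
    """
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def exp_grad(x):
    """d/dx exp(x) = exp(x): unbounded as x grows."""
    return math.exp(x)


def softplus_grad(x):
    """d/dx softplus(x) = sigmoid(x): always in [0, 1]."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


for x in (-5.0, 0.0, 5.0, 20.0):
    print(f"x={x:6.1f}  exp={exp_transform(x):.4g}  softplus={softplus(x):.4g}  "
          f"d_exp={exp_grad(x):.4g}  d_softplus={softplus_grad(x):.4g}")
```

For large positive inputs softplus behaves like the identity, so a gradient step of a given size in unconstrained space changes the constrained value by about the same amount instead of multiplying it.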

For numpyro, the results are as follows:

  1. Exp leads to exploding ELBO (see plot below).
  2. Softplus for scales improves the stability of ELBO (consistent with the findings reported in numpyro#855, "Softplus transform as a more numerically stable way to enforce positive constraint") but, surprisingly, has low accuracy compared to pymc3. I am not sure what might be driving this. I did not use pyro.plates when I did the original experiments (with numpyro version 0.4.1, I did not use pyro).
  3. Softplus for all gives accuracy similar to the original.

For pyro, for some reason, this implementation gives NaN in all three comparisons after just a few iterations. Any thoughts about potential solutions would be appreciated. Maybe I am using the plates interface incorrectly?

image

@fritzo fritzo added the WIP label Apr 12, 2021
@vitkl
Contributor Author

vitkl commented Apr 13, 2021

The issue with pyro was overlapping plate dimensions; the plates are now replaced with expand/to_event (excluding obs_axis).

The analysis essentially confirms the same observation: only setting 3, which uses softplus for all positive transformations, retains the original accuracy (see the 2D histogram below for pymc3 and the notebook for pyro and numpyro):

image

  1. Exp leads to exploding ELBO (see plot below).
  2. Softplus for scales improves the stability of ELBO (consistent with the findings reported in numpyro#855, "Softplus transform as a more numerically stable way to enforce positive constraint") but, surprisingly, has lower accuracy compared to pymc3 (especially clearly seen in 2D histograms). I did not use pyro.plates when I did the original experiments (with numpyro version 0.4.1, I did not use pyro).
  3. Softplus for all positive transformations gives accuracy similar to the original.

image

@fritzo what do you think? Please let me know if you have any questions or if I need to give more explanation in the notebook.

This model can be described as a GLM/non-negative factor analysis where factor loadings for variables are provided as fixed and the goal is to learn loadings for observations. I can also do the analysis of ELBO stability using a model where factor loadings for variables also need to be learned. I expect this trend to be stronger because the model is less constrained - but I do not have the same ground truth for evaluating accuracy.
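
To make the GLM / non-negative factor analysis description concrete, here is a deliberately simplified sketch of the expected-count computation. It is an illustration only and omits everything else in the cell2location model; the names `w`, `g` and `expected_counts` are hypothetical and the numbers are made up.

```python
def expected_counts(w, g):
    """Return mu with mu[s][j] = sum_f w[s][f] * g[f][j].

    w: loadings for observations (rows: observations, columns: factors), learned.
    g: loadings for variables (rows: factors, columns: variables), given as fixed.
    """
    n_factors = len(g)
    n_vars = len(g[0])
    return [
        [sum(w_s[f] * g[f][j] for f in range(n_factors)) for j in range(n_vars)]
        for w_s in w
    ]


# Two observations, two factors, three variables (toy, non-negative values).
w = [[1.0, 2.0],
     [0.0, 3.0]]
g = [[1.0, 0.0, 2.0],
     [0.5, 1.0, 0.0]]

mu = expected_counts(w, g)
print(mu)
```

In the setting described above only `w` is estimated; in the less constrained variant, `g` would be learned as well.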

@vitkl
Contributor Author

vitkl commented Apr 13, 2021

The same pattern (ELBO stability) and accuracy are also reproduced with 5x larger data.

image

@fehiepsi
Member

@vitkl Thanks for setting up the experiments! This is so cool. I would like to walk through this to better understand the behavior. I am a bit busy this week, so I will get back to you sometime early next week. :)

Member

@fritzo fritzo left a comment

Hi @vitkl, these experiments look great! @fehiepsi and I discussed and agreed that:

  1. Your experiments are sufficient evidence to change autoguide scale parameters to use softplus transforms by default.
  2. While it would be too big a change to alter the default transform for constraints.positive, we would be happy to add documentation on how to change the default, e.g. adding to the new Tips & Tricks tutorial.

@fritzo
Member

fritzo commented Apr 25, 2021

@vitkl is this ready to merge?

@vitkl
Contributor Author

vitkl commented Apr 26, 2021 via email

@fritzo fritzo added awaiting review and removed WIP labels May 4, 2021
@fritzo
Member

fritzo commented May 4, 2021

@vitkl mind if I merge this?

@vitkl
Contributor Author

vitkl commented May 10, 2021

Thanks for your patience! I just cleaned up the notebooks a bit. I think you can merge this.

@fritzo fritzo changed the title from "[WIP] Comparing exp and softplus transform using synthetic data (cell2location model)" to "Compare exp and softplus transform using synthetic data (cell2location model)" May 10, 2021
@fritzo fritzo merged commit 723ebd0 into pyro-ppl:master May 10, 2021
3 participants