
Sparse logistic regression with spike slab prior example #247

Open

gaow opened this issue May 17, 2020 · 29 comments

gaow commented May 17, 2020

I am trying to adapt #154 to use a spike-and-slab prior pi * delta + (1 - pi) * N(mu, sigma), where pi, mu and sigma are given and a uniform prior is used for the intercept. I came up with the code below, but I get an error from the p = tf.linalg.matvec(X, xi * beta) line:

tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a int32 tensor but is a float tensor [Op:Mul]

If we explicitly cast the type, p = tf.linalg.matvec(X, tf.cast(xi, tf.float32) * beta), the error above goes away, but the code still fails with:

ValueError: Encountered `None` gradient.
  fn_arg_list: [<tf.Tensor 'init:0' shape=(3, 3) dtype=int32>, <tf.Tensor 'init_1:0' shape=(3, 3) dtype=float32>, <tf.Tensor 'init_2:0' shape=(3,) dtype=float32>]
  grads: [None, <tf.Tensor 'mcmc_sample_chain/dual_averaging_step_size_adaptation___init__/_bootstrap_results/NoUTurnSampler/.bootstrap_results/process_args/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/dual_averaging_step_size_adaptation___init__/_bootstrap_results/NoUTurnSampler/.bootstrap_results/process_args/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/PartitionedCall/pfor/PartitionedCall_grad/PartitionedCall:0' shape=(1, 3) dtype=float32>, <tf.Tensor 'mcmc_sample_chain/dual_averaging_step_size_adaptation___init__/_bootstrap_results/NoUTurnSampler/.bootstrap_results/process_args/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/dual_averaging_step_size_adaptation___init__/_bootstrap_results/NoUTurnSampler/.bootstrap_results/process_args/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/PartitionedCall/pfor/PartitionedCall_grad/PartitionedCall:1' shape=(1,) dtype=float32>]

I'm wondering,

  1. How do we properly "mask" variables (beta) here? In pymc3 it is just p = pm.math.dot(X, xi * beta). In pymc4, is explicit casting the best thing to do, or is there a better way to implement it?
  2. Any insight into the error message? It suggests the first parameter xi gets a None gradient, but I'm not sure what's going on.
  3. Additionally, any suggestions on how to improve my code below are very much appreciated. In particular, do we really have to use those tf.zeros() and tf.ones() calls to create the model, or is there a better way to specify the dimensions of the parameters?

Thanks in advance for your input!
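
For reference, here is a minimal, self-contained snippet (plain TensorFlow only, no PyMC4; shapes and values are arbitrary) that reproduces the dtype mismatch in isolation and shows the explicit cast that works around it:

import tensorflow as tf

X = tf.random.normal([5, 3])           # float32 design matrix
xi = tf.ones([3], dtype=tf.int32)      # Bernoulli samples come back as int32
beta = tf.random.normal([3])           # float32 coefficients

# Raises InvalidArgumentError: TensorFlow does not promote int32 to float32,
# so int32 * float32 is rejected.
# p = tf.linalg.matvec(X, xi * beta)

# Explicitly casting the indicator avoids the error:
p = tf.linalg.matvec(X, tf.cast(xi, tf.float32) * beta)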

Code and data to reproduce the problem

Data: issue_247.tar.gz

The model,

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
import pymc4 as pm

@pm.model
def get_model(y, X, pi0=0.5, mu=0, sigma=1, lower=-1, upper=1):
    xi = yield pm.Bernoulli('xi', pi0 * tf.ones(X.shape[1], 'float32'))
    beta_offset = yield pm.Normal('beta_offset', tf.zeros(X.shape[1], 'float32'), tf.ones(X.shape[1], 'float32'))
    beta = yield pm.Deterministic("beta", mu + beta_offset * sigma)
    alpha_offset = yield pm.Uniform("alpha_offset", -1, 1)
    alpha = yield pm.Deterministic("alpha", lower + (alpha_offset + 1) / 2 * (upper - lower))
    # this is the line in question; without explicit casting,
    # p = tf.linalg.matvec(X,  xi * beta)
    # will not work
    p = tf.linalg.matvec(X, tf.cast(xi, tf.float32) * beta)
    yield pm.Bernoulli('y_obs', tf.math.sigmoid(p+alpha), observed = y)

Load data and set parameters,

import pickle
data = pickle.load(open("issue_247.pkl", "rb"))

iteration = 1000
tune_prop = 0.25
n_chain = 3

pi0 = 0.051366009925488
mu = 0.783230896500752
sigma = 0.816999481742865
lower = -2.94
upper = 0

Inference,

model = get_model(data['y'], tf.constant(data['X']), pi0, mu, sigma, lower, upper)
trace = pm.inference.sampling.sample(
        model, step_size=0.01, num_chains=n_chain, num_samples=iteration,
        burn_in=int(tune_prop*iteration), nuts_kwargs={},
        xla=False, use_auto_batching=True
    )

Sayam753 (Member) commented

Hi @gaow
Looking at the error message, it seems there is a dtype incompatibility between X and xi * beta as passed to the tf.linalg.matvec function. I think that even though float32 is specified for the event probabilities, tfp does not propagate that dtype to the samples drawn from the Bernoulli distribution:

>>> xi = pm.Bernoulli('xi', probs=0.5*np.ones((20), 'float32'))
>>> xi.dtype
tf.int32

So PyMC4 needs a way to accept a dtype for its distributions and forward it to tfp, so that the underlying distribution can be constructed with an explicit dtype.

gaow commented May 17, 2020

Thanks @Sayam753, yes, xi is an integer and it does make sense that some automatic casting should be performed ... To consolidate information I updated my original post with an explicit type cast, followed by additional questions.

Sayam753 commented May 18, 2020

Hi @gaow
I am not sure that tf.cast can change the dtype of a distribution itself. I would rather opt to explicitly pass dtype as an argument to Bernoulli, but this argument does not get passed through to tfp due to PyMC4's design.

xi = pm.Bernoulli('xi', probs=0.5*np.ones((20), 'float32'), dtype=tf.float32)

It may be a good idea for PyMC4 to support more control over tfp.distributions.
Pinging @lucianopaz to provide insight into how distributions are handled.
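
For context, the underlying tfp distribution does accept a dtype argument controlling the sample dtype, so the missing piece is just forwarding it (a sketch with raw tfp, not PyMC4):

import tensorflow as tf
from tensorflow_probability import distributions as tfd

# With raw tfp, the sample dtype can be requested directly.
xi_default = tfd.Bernoulli(probs=0.5 * tf.ones(20)).sample()
xi_float = tfd.Bernoulli(probs=0.5 * tf.ones(20), dtype=tf.float32).sample()
print(xi_default.dtype)  # tf.int32 (the default)
print(xi_float.dtype)    # tf.float32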

lucianopaz (Contributor) commented

@gaow, @Sayam753, the error being raised comes from TensorFlow's type promotion rules. The problem is that an int32 can't be promoted to a float32, because float32 can't represent all of the integers that an int32 can; the user must manually (and unsafely) cast xi to a float32 or to an int16. JAX has a table of the allowed type promotions that appears to match TensorFlow's.
About the distribution dtype: yes, there is no built-in mechanism for this. It came up in another issue (#236), and there is an open PR (#239) that tries to add a way to set some of the tfp distribution parameters. It stalled when I asked for it to be made more generic so that more arguments could be added as pymc4 development advances.

gaow commented May 18, 2020

@lucianopaz Thank you for your feedback. I believe @Sayam753 was suggesting adding support for casting the type of a distribution's output. #236 seems relevant to input parameter checks, so even with a better version of #239 I still don't see how explicit type casting would work here.

I suspect the second error message I ran into after explicitly casting types is an artifact of casting like this. But I don't see another way to implement my model without proper type-casting support.

(Perhaps I am wrong and the second error has nothing to do with the type cast? In that case @lucianopaz, is there someone you can ping to help look at the model? I'm pinging @kyleabeauchamp from #154 to see if there is some insight into it -- thanks in advance!)

lucianopaz (Contributor) commented

@gaow, I hadn't seen the updated error. You are getting a None gradient because xi is a discrete variable and cannot be differentiated. Currently, pymc4 only performs NUTS sampling, which works only if all of your unobserved variables are continuous. @rrkarim is working on #229, which will enable pymc4 to sample from your model by performing some nested steps (Gibbs steps for the discrete variables followed by a NUTS step for the remaining continuous variables, or something along those lines).
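
To illustrate why the gradient comes back as None (a minimal sketch in plain TensorFlow; the values are arbitrary):

import tensorflow as tf

xi = tf.Variable([1, 0, 1], dtype=tf.int32)   # discrete indicator, like xi above
beta = tf.Variable([0.5, -1.2, 2.0])          # continuous coefficients

with tf.GradientTape() as tape:
    p = tf.reduce_sum(tf.cast(xi, tf.float32) * beta)

grad_xi, grad_beta = tape.gradient(p, [xi, beta])
print(grad_xi)    # None -- no gradient is defined w.r.t. an int32 variable
print(grad_beta)  # [1., 0., 1.] -- the continuous part differentiates fine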

gaow commented May 18, 2020

Thank you @lucianopaz for the clarification. I suppose I will wait for #229 to land and then try implementing this example? It would be nice if #239 were made generic enough to also support variable type specifications. Please advise if there is anything else I can do at this point. Thanks a lot!

rrkarim commented May 18, 2020

@gaow @lucianopaz I'm not sure when #229 will be merged into master since there were some priority changes lately. But I will try to provide some documentation on the compound step usage there soon; then it could be used as an experimental feature.

gaow commented May 19, 2020

@rrkarim thank you! I tried to install your branch and modified my code to the following:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import tensorflow as tf
import numpy as np
import pymc4 as pm
from pymc4.mcmc.samplers import NUTS, RandomWalkM

@pm.model
def get_model(y, X, pi0=0.5, mu=0, sigma=1, lower=-1, upper=1):
    xi = yield pm.Bernoulli('xi', pi0 * tf.ones(X.shape[1], 'float32'))
    beta_offset = yield pm.Normal('beta_offset', tf.zeros(X.shape[1], 'float32'), tf.ones(X.shape[1], 'float32'))
    beta = yield pm.Deterministic("beta", mu + beta_offset * sigma)
    alpha_offset = yield pm.Uniform("alpha_offset", -1, 1)
    alpha = yield pm.Deterministic("alpha", lower + (alpha_offset + 1) / 2 * (upper - lower))
    # this is the line in question; without explicit casting,
    # p = tf.linalg.matvec(X,  xi * beta)
    # will not work
    p = tf.linalg.matvec(X, tf.cast(xi, tf.float32) * beta)
    yield pm.Bernoulli('y_obs', tf.math.sigmoid(p+alpha), observed = y)

import pickle
data = pickle.load(open("issue_247.pkl", "rb"))

iteration = 1000
tune_prop = 0.25
n_chain = 3
n_thread = 4

pi0 = 0.051366009925488
mu = 0.783230896500752
sigma = 0.816999481742865
lower = -2.94
upper = 0

tf.config.threading.set_intra_op_parallelism_threads(n_thread)

model = get_model(data['y'], tf.constant(data['X']), pi0, mu, sigma, lower, upper)
trace = pm.sample(
        model, step_size=0.01, num_chains=n_chain, num_samples=iteration,
        burn_in=int(tune_prop*iteration), nuts_kwargs={},
        xla=False, use_auto_batching=True,
        sampler_type="compound", 
        sampler_methods=[
        ("xi", RandomWalkM), 
        ("beta", NUTS),
        ("alpha", NUTS),
        ("beta_offset", NUTS),
        ("alpha_offset", NUTS)
        ]
    )

Unfortunately it failed quickly,

Traceback (most recent call last):
  File "/opt/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py", line 324, in _AssertCompatible
    fn(values)
  File "/opt/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py", line 263, in inner
    _ = [_check_failed(v) for v in nest.flatten(values)
  File "/opt/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py", line 264, in <listcomp>
    if not isinstance(v, expected_types)]
  File "/opt/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py", line 248, in _check_failed
    raise ValueError(v)
ValueError: 1.0

I'd appreciate your feedback on this to get the sparse regression example working. Please take your time. I hope this is something that's not hard to fix.

rrkarim commented May 19, 2020

@gaow the current code doesn't support sampling from a Bernoulli distribution since there is no proposal-generation op defined for it. For the categorical distribution we have one in state_functions.py. It is pretty easy to add; I'm working on it.

rrkarim commented May 19, 2020

@gaow you can check it on the head commit now. Please share if you run into any issues. Still, I wouldn't encourage using the branch code for experiments and such; progress is still at an early stage, so there could be many potential bugs.

gaow commented May 19, 2020

Thank you @rrkarim for your prompt response! The code works on a smaller dataset if I replace the data y and X above with something smaller:

from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
X = data["data"].astype('float32')
# Standardize to avoid overflow issues
X -= X.mean(0)
X /= X.std(0)
y = data["target"]

n_samples, n_features = X.shape

X = tf.constant(X)

"works" means it completed without error messages. I am currently running that larger data I posted in this ticket -- will analyze its output to see if the estimated parameters are in line with the simulation truth. I will follow up on that.

Still, I wouldn't encourage using the branch code for experiments and such; progress is still at an early stage, so there could be many potential bugs.

Duly noted, thank you! Still, this might be the most promising way forward at this point: I've tried pymc3 for this model and it uses a lot of resources at the scale of my problem, although the inference results are somewhat okay. I'll check the resource usage and correctness of this implementation to see if it is promising. Hopefully some of our preliminary experiments can help with identifying bugs.

gaow commented May 19, 2020

@rrkarim Here is the result, in HTML format, of running the example data from the first post of this ticket. I explained in the notebook how the data is generated and what result I expect to see:

20200519_PyMC4_explore.zip

It is faster and uses a lot less memory than PyMC3 on my computer. The result makes sense for some of the parameters, but it is off for others and overall doesn't look as good as the PyMC3 result.

Do you (or does anyone here) see anything obviously wrong with my code implementing the spike-and-slab logistic regression model?

junpenglao (Member) commented

We don't yet have a good tuning strategy, so you will likely see poorer-quality samples from PyMC4 right now. We are working on a better tuning strategy, but for now you can try assigning a different step size to each RV (passing a tuple/list of step_size values proportional to the posterior standard deviations); that should improve performance quite a bit.
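
For example (a sketch only; the exact format pm.sample expects for per-RV step sizes and the ordering of the free RVs are assumptions here), one might replace the single scalar step_size with something like:

# hypothetical per-RV step sizes, roughly proportional to each posterior sd
step_sizes = [0.1,    # xi
              0.05,   # beta_offset
              0.01]   # alpha_offset

trace = pm.sample(
    model, step_size=step_sizes, num_chains=n_chain, num_samples=iteration,
    burn_in=int(tune_prop * iteration), nuts_kwargs={},
    xla=False, use_auto_batching=True,
)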

rrkarim commented May 19, 2020

@gaow I see your prior distribution for \alpha is uniform, while in the kaggle pymc3 implementation it is a normal distribution. The point is that I wasn't able to draw strict parallels between the two model implementations.

Also, just to note, for now the compound sampler performs an unconditional step (a compound step, not Gibbs), but that should be clear from the sampler name. I don't think the choice of sampler is important in this model, though (I might be wrong). If you want to perform a Gibbs step you can modify _target_log_prob_fn_part to return the modified state part; here is the definition of the function for the compound step:

def _target_log_prob_fn_part(state_part, idx):
    temp_value = state[idx]
    state[idx] = state_part
    log_prob = self._target_log_prob_fn(*state)
    state[idx] = temp_value
    return log_prob
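
A rough sketch of the Gibbs-style modification described above (the return signature the compound sampler actually expects is an assumption here):

def _target_log_prob_fn_part(state_part, idx):
    # Gibbs-style variant (sketch): keep the proposed state part rather than
    # restoring the previous value, so later kernels condition on the update.
    state[idx] = state_part
    log_prob = self._target_log_prob_fn(*state)
    return log_prob, state[idx]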

And yes, to @junpenglao's point, there is no tuning strategy implemented, so parameters should be hard-coded for now.

Also, you can enable xla for sampling; it is way faster (or was, when I tested it some time ago).

gaow commented May 19, 2020

@rrkarim sorry, I forgot to mention that I did implement a uniform alpha with pymc3 as well. Let me put together a more formal notebook soon with a smaller example analyzed using both pymc3 and pymc4, and push it somewhere so we can easily compare the differences. We can then fiddle with the sampler, etc.

I don't think I understand @junpenglao's pointer on how to manually tune these parameters (how do I know what the right step_size is?). I guess it will be easier to discuss after I provide such an example. Will get back on that soon!

gaow commented May 20, 2020

@rrkarim @junpenglao here is a notebook I posted on Google Drive implementing the same model using both pymc4 and pymc3. It simulates some data and analyzes it, so you can see the pymc4 result is way off for some quantities.

To open the notebook, please click the link below:

https://drive.google.com/file/d/161KAaWM-ur6PaqfhNUqpoJMePu8-EosA/view?usp=sharing

For those who haven't worked with Colab -- Google Drive should prompt you to either download the file or "Open with" another app by "Connecting to another app". Choose that option, search for "colab", then click "connect". You should then be able to open the notebook online and comment on it (I set the permission to "anyone can comment").

Hopefully it is a useful example and can be fixed in PyMC4!

rrkarim commented May 22, 2020

@gaow thank you for the notebook. I see the results are now far off. I should dive into the issue to solve it. @junpenglao you can assign this to me.

rrkarim commented Jul 31, 2020

@gaow can you test the model with the newest PR on compound step support, #306?
You don't need to define a sampler for each variable; you can just write:

trace = pm.sample(
        model, step_size=0.01, num_chains=n_chain, num_samples=iteration,
        burn_in=int(tune_prop*iteration), nuts_kwargs={},
        xla=True, use_auto_batching=True,
        sampler_type="compound", 
        sampler_methods=[
        ("xi", RandomWalkM), 
        ]
    )

Here are my results on the new code: results_notebook.

gaow commented Jul 31, 2020

@rrkarim thanks a lot! The result in your notebook seems promising. However, I am having issues running the code ... if I use tf-nightly via pip install tf-nightly -U I get an error when importing pymc4:

>>> import pymc4 as pm4
  File "/opt/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/internal/distribution_util.py", line 24, in <module>
    import tensorflow.compat.v2 as tf
ModuleNotFoundError: No module named 'tensorflow.compat'

If I use tensorflow via pip install tensorflow tensorflow-probability I get version 2.3.0. Then import pymc4 as pm4 works, but when I run the pm.sample() call as in your thread above, I get:

    /opt/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py:444 make_tensor_proto
        raise ValueError("None values not supported.")

    ValueError: None values not supported.

I do seem to recall having lots of issues with pymc4 and tensorflow compatibility, and it seems this is happening again -- hopefully you can reproduce it by updating to the latest tf-nightly?

rrkarim commented Jul 31, 2020

Yeap, I'm really sorry, my mistake. I've fixed it. I think it would be better to wait for 1-2 days so I can fix some issues.

gaow commented Jul 31, 2020

Thanks @rrkarim but still,

/opt/miniconda3/lib/python3.7/site-packages/pymc4/distributions/continuous.py in <module>
      4 import numpy as np
      5 import tensorflow as tf
----> 6 from tensorflow_probability import distributions as tfd
      7 from tensorflow_probability import bijectors as bij
      8 from tensorflow_probability.python.internal import distribution_util as dist_util

ImportError: cannot import name 'distributions' from 'tensorflow_probability' (unknown location)

However

>>> import tensorflow_probability 
>>> 

works.

$ pip show tfp-nightly
Name: tfp-nightly
Version: 0.12.0.dev20200730
Summary: Probabilistic modeling and statistical inference in TensorFlow
Home-page: http://github.com/tensorflow/probability

is the version I used for tensorflow_probability. Hopefully it is not a hard fix. Thanks!

rrkarim commented Jul 31, 2020

I'm not sure about this issue; can you test from tensorflow_probability import distributions as tfd? It should import in 0.12.0-dev20200731. If not, maybe upgrade the nightly build.

gaow commented Jul 31, 2020

@rrkarim interestingly, version 0.12.0-dev20200730 does not work, but version 0.12.0-dev20200731 does.

However, after upgrading both tf-nightly and tfp-nightly to today's build, I am now back to getting the same error I got on the release version of tensorflow.

    /opt/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py:444 make_tensor_proto
        raise ValueError("None values not supported.")

    ValueError: None values not supported.

This is running your notebook here. Could you reproduce it? Thanks!

rrkarim commented Jul 31, 2020

Have you updated the code from the PR? The notebook runs on:

print(tf.__version__)
print(tfp.__version__)
2.4.0-dev20200731
0.12.0-dev20200731

[UPD] Ok, check it with xla=False. XLA worked with previous builds of tfp (or tf). I will return to the bug later.

gaow commented Jul 31, 2020

@rrkarim I am up to date with the PR and I have the same versions of tf and tfp as you (built today). Thanks for checking the xla bug! I confirm that xla=False works.

rrkarim commented Aug 27, 2020

@gaow Here are the results for the toy dataset you provided. Are there any other issues?

gaow commented Aug 27, 2020

Thanks @rrkarim, this notebook works as expected after I upgraded to your current master. My only complaint for now is that, compared to the same implementation in pymc3, the current pymc4 code is much slower, as you can test and see in the notebook above.

rrkarim commented Aug 30, 2020

@gaow yeap, I will analyze it more; I'm not sure it is easily solvable, though.
