ENH: Improve debug docs and function helpers #6841

Open
Tracked by #7053
twiecki opened this issue Aug 1, 2023 · 12 comments

Comments

@twiecki
Member

twiecki commented Aug 1, 2023

Before

x = pm.Normal("x")
x_print = Print("x")(x)

After

x = pm.Normal("x", debug=True)

Context for the issue:

On https://www.pymc.io/projects/examples/en/latest/howto/howto_debugging.html we show how pytensor.printing.Print can be used to debug-print RVs, which is helpful for debugging model issues. This could be done with a nicer API if we added a debug kwarg that wraps the RV internally with a Print Op.

twiecki changed the title from "ENH: Better debug support for RVs" to "ENH: Better debug support for RVs with debug kwarg" on Aug 1, 2023

@ricardoV94
Member

ricardoV94 commented Aug 2, 2023

It seems like the notebook is doing its job?

I don't think we should add a debug kwarg to RVs: it's a bit vague what it means, and it's not really discoverable anyway (none of the distribution arguments are). I'd rather have Distribution do fewer things than more.

Instead I would propose to:

  1. Update the notebook
  2. Add an initial point failed example and show model.debug
  3. Link to the notebook in model.debug output, for the cases where that is not sufficient and the print thing may be better.
  4. Add a nicer print_value helper so users don't have to initialize the Print Op manually which is certainly weird:
from pytensor.printing import Print

def print_value(var, name=None):
    """Print value of variable when it is computed during sampling.

    This is likely to affect sampling performance.
    """
    # Fall back to the variable's own name as the printed label.
    if name is None:
        name = var.name
    return Print(name)(var)
  5. Implement the Print Op in JAX (I think it can be done with https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html) for NumPyro/BlackJAX-based samplers

@twiecki
Member Author

twiecki commented Aug 2, 2023

@ricardoV94 I like all these suggestions.

ricardoV94 changed the title from "ENH: Better debug support for RVs with debug kwarg" to "ENH: Improve debug docs and function helpers" on Aug 2, 2023

@itsdivya1309

I would like to work on this issue.

@twiecki
Member Author

twiecki commented Feb 1, 2024

Great @itsdivya1309, do you have any questions on how to get going? Otherwise, feel free to open a draft PR and we can take it from there.

@itsdivya1309

Correct me if I am wrong, but I need to update this notebook as suggested above, right?

@twiecki
Member Author

twiecki commented Feb 1, 2024

@itsdivya1309 Correct.

@itsdivya1309

itsdivya1309 commented Feb 3, 2024

I don't understand what you mean by 'Implement PrintOp in JAX'. Could you please explain?
Also, the print_value() function uses the Print Op anyway, which means it's still being initialized.

@twiecki
Member Author

twiecki commented Feb 5, 2024

@itsdivya1309 You can treat that as a separate issue. As you can see in the NB, we're using Print() to output debug values. Print is implemented for the C backend, but not currently for the JAX backend. In JAX, this functionality seems to be available here: https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html.

But you can just do the changes to the NB for now.
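
For reference, the JAX facility described on that page is jax.debug.print. The sketch below only shows that function used inside a jit-compiled function; it is not a PyTensor dispatch for Print, and the function name scaled_square is a hypothetical example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_square(x):
    # jax.debug.print runs at execution time, so it shows the concrete
    # value; a plain Python print here would only see a tracer object
    # while the function is being traced.
    jax.debug.print("x = {x}", x=x)
    return 2.0 * x**2


result = scaled_square(jnp.asarray(3.0))
```

A JAX implementation of the Print Op would presumably emit a call like this when the graph is converted, but how that is wired into PyTensor is left open here.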

@itsdivya1309

Alright

@AryanNanda17
Contributor

@twiecki, I am also interested in this issue, and it is not assigned to anyone. Can I open a PR, since none has been opened yet?

@AryanNanda17
Contributor

My understanding of the issue:

  • You suggested introducing a new debug keyword argument to the pm.Normal distribution, which, when set to True, internally wraps the random variable with a Print Op.
  • @ricardoV94 suggested using a print_value function instead, to hide the awkwardness of initializing the Print Op manually.
  • He also suggested adding some examples to the notebook.
  • And, mainly, implementing the Print Op in JAX.

@itsdivya1309

itsdivya1309 commented Feb 9, 2024

@AryanNanda17 you can work on the JAX part if you want.
