Questions regarding changes to JaxRDDLCompiler - params to jax_expr #186

Closed
pecey opened this issue Mar 30, 2023 · 3 comments

@pecey
Contributor

pecey commented Mar 30, 2023

There have been some changes in JaxRDDLCompiler that have altered the lambda functions returned for CPF evaluation. In the file, the relevant logic is jax_cpfs[cpf] = self._jax(expr, info, dtype=dtype). Previously, the function returned for expr was a lambda with two params: 1. a dictionary of state, action and interim variables with their current values, and 2. a PRNG key.

Now it seems to expect three arguments, and I couldn't find documentation of what the three params should be. Could someone point me to where I should be looking, or explain the three params it now expects?

Thank you.
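A hypothetical sketch of the two calling conventions described above, written in plain Python so it runs without JAX. The argument order `(x, params, key)`, the node id `ge_0`, the weight value, and the `(value, key)` return pair are assumptions for illustration, not the actual JaxRDDLCompiler API; a real compiled CPF would work on JAX arrays and a `jax.random` key.

```python
import math

# Hypothetical old-style compiled CPF: (dict of state/action/interim values,
# PRNG key). The relaxation weight is a single baked-in global constant.
GLOBAL_W = 10.0


def old_cpf(x, key):
    value = 1.0 / (1.0 + math.exp(-GLOBAL_W * (x["pos"] - x["goal"])))
    return value, key


# Hypothetical new-style compiled CPF: an extra `params` dict is threaded
# through, so this node's weight is looked up instead of hard-coded.
def new_cpf(x, params, key):
    w = params["ge_0"]  # hypothetical node id
    value = 1.0 / (1.0 + math.exp(-w * (x["pos"] - x["goal"])))
    return value, key


subs = {"pos": 2.0, "goal": 1.0}
key = 0  # stand-in for a jax.random.PRNGKey
old_value, _ = old_cpf(subs, key)
new_value, _ = new_cpf(subs, {"ge_0": 10.0}, key)
```

In this sketch the only structural difference is the extra dictionary argument: a weight that used to be a module-level constant becomes a value the caller supplies.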

@mike-gimelfarb
Collaborator

Hi,

If I understand correctly, you are referring to the new 'params' argument in the wrapped JAX expressions for RDDL calculations.

In short, it is a dictionary of per-node weight parameters that give fine control over the model relaxations used when discrete calculations are approximated by parameterized expressions, such as a sigmoid.

You can currently define and tune these per-node weights if you like (e.g. using Bayesian optimization). There is an unused function 'print_parameterized_exprs' in the JaxExample that you could call to retrieve the keys it expects as well as their current values.
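A minimal sketch of what a per-node parameter dictionary might look like and how a single weight can be overridden before evaluation. The node ids (`ge_0`, `ge_1`), the default value 10.0, the product used for the logical "and", and the helper `with_overrides` are all hypothetical; in pyRDDLGym the real keys and defaults are whatever 'print_parameterized_exprs' reports, and a tuner such as Bayesian optimization would search over those values.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical defaults: one weight per relaxed comparison (ids are made up).
default_params = {"ge_0": 10.0, "ge_1": 10.0}


def with_overrides(defaults: dict, **overrides: float) -> dict:
    """Return a copy of `defaults` with some entries replaced."""
    unknown = set(overrides) - set(defaults)
    if unknown:  # fail loudly on typos instead of silently adding a key
        raise KeyError(f"unknown parameter keys: {sorted(unknown)}")
    return {**defaults, **overrides}


def relaxed_in_range(x: dict, params: dict) -> float:
    # Relaxed (v >= lo) and (hi >= v): each comparison reads its own weight,
    # and the logical 'and' is approximated here by a product.
    lower = sigmoid(params["ge_0"] * (x["v"] - x["lo"]))
    upper = sigmoid(params["ge_1"] * (x["hi"] - x["v"]))
    return lower * upper


# Inspect keys and current values (roughly what print_parameterized_exprs reports).
for name, weight in default_params.items():
    print(name, weight)

x = {"v": 0.95, "lo": 0.0, "hi": 1.0}  # exact answer is 1.0: v is in range
soft = relaxed_in_range(x, default_params)
sharp = relaxed_in_range(x, with_overrides(default_params, ge_1=100.0))
```

Sharpening only `ge_1` moves the result toward the exact value 1.0 while leaving the other comparison untouched, which is the kind of local control a single global 'w' cannot provide.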

===

The story behind this is that, in the previous version, there was a single global tuning parameter 'w' to control the accuracy of the relaxations, e.g. x >= y -> sigmoid(w * (x - y)) in FuzzyLogic. In principle, however, it is possible to use per-node weight parameters, where each 'w' can be tuned locally, e.g. using some local error measure. We did not want to limit users who wish to adapt these parameters and have better control over the model approximation. The feature is not really used anywhere yet, nor is it currently clear how to tune these weights without better control over intermediate calculations in JAX, something we would like to work on in the future. (FuzzyLogic currently defines these parameters for some relaxations, so you can look there for the technical details of how they are propagated.)
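The relaxation x >= y -> sigmoid(w * (x - y)) can be evaluated directly to see why a tunable 'w' matters. The sketch below is plain Python with hypothetical helper names, not FuzzyLogic's code. It shows a general property of this relaxation: a larger w moves the relaxed value toward the exact 0/1 result, but the derivative with respect to x, w * s * (1 - s) with s = sigmoid(w * (x - y)), collapses toward zero once x - y is away from 0, which can leave gradient-based planning with little signal.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def relaxed_ge(x: float, y: float, w: float) -> float:
    # Smooth stand-in for the indicator [x >= y].
    return sigmoid(w * (x - y))


def relaxed_ge_grad_x(x: float, y: float, w: float) -> float:
    # Analytic derivative: d/dx sigmoid(w * (x - y)) = w * s * (1 - s).
    s = relaxed_ge(x, y, w)
    return w * s * (1.0 - s)


# The exact comparison 1.0 >= 0.0 is true (value 1.0).
for w in (1.0, 5.0, 20.0):
    value = relaxed_ge(1.0, 0.0, w)
    grad = relaxed_ge_grad_x(1.0, 0.0, w)
    print(f"w={w:5.1f}  value={value:.6f}  d/dx={grad:.2e}")
```

A per-node scheme replaces the shared `w` with a lookup such as `params[node_id]`, so each comparison can sit at its own point on this accuracy-versus-gradient trade-off.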

@mike-gimelfarb
Collaborator

In the future, 'params' could also be used to propagate other information or parameters through the computation graph that one does not want to bake in; it is really meant as a "catch-all" for passing information through JAX.
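As a rough illustration of 'params' as a catch-all, the hypothetical mini-compiler below builds closures that all share the signature (x, params, key), read only their own entries, and pass params unchanged to their children. The tuple-based expression format, the node-id scheme, the default values, and the non-relaxation 'scale' entry are assumptions for this sketch, not how JaxRDDLCompiler or FuzzyLogic are actually implemented.

```python
import math
from typing import Any, Callable, Dict, Tuple

# A compiled node: (variable values, params, key) -> (value, key).
Compiled = Callable[[Dict[str, float], Dict[str, Any], Any], Tuple[float, Any]]


def compile_expr(expr: tuple, defaults: Dict[str, Any]) -> Compiled:
    """Compile a tiny tuple-based expression tree (hypothetical format).

    Nodes that need a tunable value register a default under a fresh id in
    `defaults`; at call time they read it back from `params`, which is handed
    unchanged to every child.
    """
    op = expr[0]
    if op == "var":
        name = expr[1]
        return lambda x, params, key: (x[name], key)
    if op == ">=":
        node_id = f"ge_{len(defaults)}"
        defaults[node_id] = 10.0  # relaxation weight w for this node
        lhs = compile_expr(expr[1], defaults)
        rhs = compile_expr(expr[2], defaults)

        def relaxed_ge(x, params, key):
            a, key = lhs(x, params, key)
            b, key = rhs(x, params, key)
            w = params[node_id]
            return 1.0 / (1.0 + math.exp(-w * (a - b))), key

        return relaxed_ge
    if op == "and":
        left = compile_expr(expr[1], defaults)
        right = compile_expr(expr[2], defaults)

        def relaxed_and(x, params, key):
            p, key = left(x, params, key)
            q, key = right(x, params, key)
            return p * q, key  # product approximation of logical 'and'

        return relaxed_and
    if op == "scale":
        # Not a relaxation: a constant factor kept out of the compiled code,
        # showing that params can carry any value one does not want baked in.
        node_id = f"scale_{len(defaults)}"
        defaults[node_id] = 1.0
        inner = compile_expr(expr[1], defaults)

        def scaled(x, params, key):
            v, key = inner(x, params, key)
            return params[node_id] * v, key

        return scaled
    raise ValueError(f"unsupported op: {op!r}")


defaults: Dict[str, Any] = {}
expr = ("and",
        (">=", ("var", "v"), ("var", "lo")),
        ("scale", (">=", ("var", "hi"), ("var", "v"))))
f = compile_expr(expr, defaults)
x = {"v": 0.5, "lo": 0.0, "hi": 1.0}
value, _ = f(x, defaults, None)  # None stands in for a PRNG key
```

Because `params` is an ordinary argument rather than a constant captured at compile time, the same compiled function can be re-evaluated with different weights or factors without rebuilding it. In JAX, a dict of arrays passed this way is a pytree, so its values can also be traced and differentiated like any other input.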

@pecey
Contributor Author

pecey commented Mar 31, 2023

Thank you for the pointers @mike-gimelfarb. I will have a look at them.

pecey closed this as completed Mar 31, 2023