Reweighting utilities #654
Conversation
…ential aside: nice way to express MBAR
Co-authored-by: Matt Wittmann <mwittmann@relaytx.com>
Applying suggestion from #654 (comment) Co-Authored-By: Matt Wittmann <mcwitt@gmail.com>
step through sign flips, log conversion, ignored constant prefactor, and implicit rank-promotion (initially inspired by avoiding a mypy complaint about a transpose #654 (comment) )
Addressing #654 (comment)
Did a first pass of the tests - will resume adding more comments on the actual implementation after dinner.
```python
f_ref, g_ref = value_and_grad(analytical_delta_f)(trial_params)

onp.testing.assert_allclose(f_hat, f_ref, atol=atol)
onp.testing.assert_allclose(g_hat, g_ref, atol=atol)
```
How should we think about precision of `dDeltaF/dParams`, both for the gaussian case and in general? For the gaussian case, I'd imagine that `dDeltaF/dlog_sigma` might require more precision than `dDeltaF/dmean`. Presumably, an acceptable level of error is proportional to the internal learning rate one might take when optimizing the parameters, so maybe `rtol` is more appropriate here?
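For concreteness, here's a hedged sketch of the gaussian case, assuming the parameterization is `(mean, log_sigma)` (this `analytical_delta_f` is my reconstruction, not necessarily the one in the tests). Since the gaussian partition function satisfies Z ∝ sigma, the exact `DeltaF` depends only on `log_sigma`:

```python
import jax.numpy as jnp
from jax import grad

def analytical_delta_f(params, ref_log_sigma=0.0):
    """Exact DeltaF between 1D gaussians, assuming params = (mean, log_sigma).

    For u(x) = (x - mean)**2 / (2 * sigma**2), Z = sigma * sqrt(2 * pi), so
    f = -log Z = -log_sigma + const, and DeltaF = -(log_sigma - ref_log_sigma).
    """
    mean, log_sigma = params
    return -(log_sigma - ref_log_sigma)

print(grad(analytical_delta_f)(jnp.array([0.0, 0.0])))  # ~ [0., -1.]
```

Under that assumption `dDeltaF/dmean` is exactly 0 and `dDeltaF/dlog_sigma` is exactly -1, consistent with the `g_ref = [0., -1.]` that appears in a test failure quoted later in this thread.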
Will need to think about this.
> For the gaussian case, I'd imagine that `dDeltaF/dlog_sigma` might require more precision than `dDeltaF/dmean`.

I'd imagine so. In a documentation notebook from #404 (cell [46]), note that estimates of the different gradient components have different variance.
> How should we think about precision of `dDeltaF/dParams`, both for the gaussian case and in general?

Preferred interpretation: the exact gradient of a randomized approximation of `DeltaF(params)`:

```python
delta_f_approx_fxn = construct_approximation(ref_params, random_seed)
```
> Presumably, an acceptable level of error is proportional to the internal learning rate one might take when optimizing the parameters

In the case that you created a fresh random approximation every optimization step,

```python
delta_f_approx_fxn_0 = construct_approximation(params_0, seed_0)
params_1 = params_0 + stepsize * grad(delta_f_approx_fxn_0)(params_0)

delta_f_approx_fxn_1 = construct_approximation(params_1, seed_1)
params_2 = params_1 + stepsize * grad(delta_f_approx_fxn_1)(params_1)
...
```

then the appropriate step size would be a function of both the geometry of `delta_f_exact_fxn` and the precision of the gradient estimate.
Instead, I think we would want to perform iterates like

```python
delta_f_approx_fxn_0 = construct_approximation(params_0, seed_0)
params_1 = local_optimize(fun=delta_f_approx_fxn_0, x0=params_0, ...)

delta_f_approx_fxn_1 = construct_approximation(params_1, seed_1)
params_2 = local_optimize(fun=delta_f_approx_fxn_1, x0=params_1, ...)
...
```
where `local_optimize(fun, x0, ...)` returns the local optimum of `fun` (+ some regularization as a function of `x0`), optionally restricted to some trust region around `x0`. (Glossing over how that trust region is defined.) Then whatever step sizes etc. are used inside `local_optimize` do NOT depend on the precision of the estimate, only on "the geometry of the problem." The amount of progress you make in each outer-loop step (construct an approximation, call `local_optimize`) would be influenced by how that trust region is defined.
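As a hedged sketch of what `local_optimize` could look like (the proximal-style regularizer and the plain gradient-descent inner loop are my choices for illustration, not a spec):

```python
import jax.numpy as jnp
from jax import grad

def local_optimize(fun, x0, reg=1.0, stepsize=1e-2, n_steps=1000):
    """Minimize fun(x) + (reg / 2) * ||x - x0||**2 by gradient descent.

    The quadratic term acts as a soft trust region around x0; a real
    implementation might instead use L-BFGS or an explicit trust-region
    method, and would choose the regularization strength more carefully.
    """
    g = grad(lambda x: fun(x) + 0.5 * reg * jnp.sum((x - x0) ** 2))
    x = x0
    for _ in range(n_steps):
        x = x - stepsize * g(x)
    return x
```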
(There's a similar way to view gradient descent: each step of SGD is exactly optimizing a random local approximation of f + some regularization.)
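To unpack that parenthetical (a sketch; the names and `eta` are illustrative): an SGD step with learning rate `eta` is the exact minimizer of a linearized random objective plus a quadratic proximity term.

```python
from jax import grad

def sgd_step(f_approx, x0, eta):
    """One SGD step, viewed as exact local optimization.

    x0 - eta * grad(f_approx)(x0) is the unique minimizer of
        f_approx(x0) + g @ (x - x0) + (1 / (2 * eta)) * ||x - x0||**2,
    i.e. a random linear approximation of f plus regularization toward x0.
    """
    g = grad(f_approx)(x0)
    return x0 - eta * g
```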
Sorry, missed replying to this part:

> so maybe rtol is more appropriate here?

I don't think so. `isclose(estimated_gradient, exact_gradient, rtol=...)` might be appropriate if the noise magnitude of each component `estimated_gradient[i]` were proportional to the magnitude of the corresponding gradient component `abs(exact_gradient[i])`.
(Counterexample: in the 1D Lennard-Jones test system, the exact gradient w.r.t. the `[sigma, epsilon]` parameters is ~= `[-0.15, -0.3]` (cell [43]), but estimates of these gradient components have marginal variance ~= `[0.1, 0.01]` (cell [41]).)
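A quick numeric check using those cited values (a sketch; the arrays are copied from the notebook cells referenced above):

```python
import numpy as onp

exact_gradient = onp.array([-0.15, -0.3])   # cell [43]
marginal_variance = onp.array([0.1, 0.01])  # cell [41]

# Noise scale relative to component magnitude differs by ~6x across
# components, so no single rtol is a natural fit for both.
print(onp.sqrt(marginal_variance) / onp.abs(exact_gradient))  # ~ [2.11, 0.33]
```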
I'm going to leave this "unresolved" so I can find the comments above more easily. But this mostly addresses my concerns.
Opened #667 for continued discussion after this PR
```python
batched_u_0_fxn: BatchedReducedPotentialFxn,
batched_u_1_fxn: BatchedReducedPotentialFxn,
```
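For intuition (hedged: the type alias `BatchedReducedPotentialFxn` is from this diff, but this helper is my illustration), a batched reduced-potential function maps `(samples, params)` to an array of reduced potentials, and could be obtained from a per-sample `u_fxn` with `jax.vmap`:

```python
from jax import vmap

def batch(u_fxn):
    """Lift u_fxn(x, params) -> float into batched_u_fxn(xs, params) -> array,
    vectorizing over the leading samples axis only."""
    return vmap(u_fxn, in_axes=(0, None))
```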
We use the `batched_` prefix pretty extensively in other parts of the code too.
```python
def endpoint_correction_0(params) -> float:
    """estimate f(ref, 0) -> f(params, 0) by reweighting"""
    delta_us = batched_u_0_fxn(samples_0, params) - ref_u_0
```
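(The diff excerpt is truncated here. For context, a one-sided exponential-averaging (Zwanzig) estimate would complete the calculation roughly as below; this is a hedged sketch of the standard estimator, not necessarily the exact body in this PR.)

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def exp_reweighting_estimate(delta_us):
    """Zwanzig estimate: delta_f = -log mean(exp(-delta_us)),
    computed stably in log space via logsumexp."""
    return -(logsumexp(-delta_us) - jnp.log(len(delta_us)))
```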
We'd probably need to pin samples to a `DeviceBuffer` or something to truly see any meaningful performance gains here.
Addressing #654 (comment)
Addressing #654 (comment)
Previous change to make ref_params the same for both 1D tests (475a399) causes:

```
>       onp.testing.assert_allclose(g_hat, g_ref, atol=atol)
E       AssertionError:
E       Not equal to tolerance rtol=1e-07, atol=0.001
E
E       Mismatched elements: 1 / 2 (50%)
E       Max absolute difference: 0.0011654208166751
E       Max relative difference: 0.0011654208166751
E        x: array([-2.282977e-04, -9.988346e-01])
E        y: array([ 0., -1.])
```
This PR LGTM, especially after some in-person discussions last week.
Looks great! The docstrings, comments, and type annotations are all very useful for understanding.
Adds functions for two kinds of differentiable reweighting (`construct_endpoint_reweighting_estimator`, `construct_mixture_reweighting_estimator`). These are tested for correctness on a 1D system (where comparisons to exact free energies and gradients are possible). Less stringent tests confirm that the same implementation is compatible with custom_ops in an absolute hydration free energy test system.

Notes to reviewers:
- … `reweighting.py` (most of the total line count is in tests, and most of the lines in `reweighting.py` are documentation / white space)
- … `batched_u_fxn`, rather than a `u_fxn` which can be transformed by `jax.vmap`), and has a more generic interface.
- … `log_weights` required by one of these approaches can be computed from a collection of sampled states using MBAR via `interpret_as_mixture_potential`.
- … `u_kn` matrix in hand, there would be no reason to prefer, say, TI's estimate of `f_k` over `mbar.f_k`).
- … `log_weights`, that would not require the full `u_kn` matrix.
- … `log_weights`.
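A hedged usage sketch (the constructor name is from this PR, but its argument list here is guessed from identifiers appearing in the diff and may not match the real signature):

```python
from jax import value_and_grad

# Hypothetical wiring: samples_0/samples_1, ref_u_0/ref_u_1, and the batched
# potential functions appear in the diff; the exact constructor signature is
# an assumption.
estimate_delta_f = construct_endpoint_reweighting_estimator(
    samples_0, samples_1,               # samples from the two reference endpoints
    ref_u_0, ref_u_1,                   # reduced potentials of those samples
    batched_u_0_fxn, batched_u_1_fxn,   # batched reduced-potential functions
    ref_delta_f,                        # free energy difference at ref_params
)

# The estimator is differentiable, so value and gradient come from one call.
f_hat, g_hat = value_and_grad(estimate_delta_f)(trial_params)
```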