Reweighting utilities #654

Merged: 56 commits into master from differentiable-reweighting-redux, Mar 14, 2022

Conversation

maxentile (Collaborator)

Adds functions for two kinds of differentiable reweighting (construct_endpoint_reweighting_estimator, construct_mixture_reweighting_estimator).

These are tested for correctness on a 1D system (where comparisons to exact free energies and gradients are possible).

Less stringent tests confirm that the same implementation is compatible with custom_ops in an absolute hydration free energy test system.
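
For orientation, here is a minimal self-contained sketch of the underlying idea (one-sided Zwanzig reweighting of samples drawn at reference parameters, made differentiable in the trial parameters via jax). The names construct_endpoint_estimate, batched_u_fxn, samples, and ref_params are illustrative placeholders, not the signatures in this PR:

from jax import numpy as np
from jax.scipy.special import logsumexp


def construct_endpoint_estimate(batched_u_fxn, samples, ref_params):
    """Return a function params -> estimate of f(params) - f(ref_params),
    by one-sided (Zwanzig) reweighting of samples drawn at ref_params."""
    ref_u = batched_u_fxn(samples, ref_params)

    def delta_f(params):
        delta_us = batched_u_fxn(samples, params) - ref_u
        # -log < exp(-delta_u) >_ref, computed stably via logsumexp
        return -(logsumexp(-delta_us) - np.log(len(delta_us)))

    return delta_f


# Because delta_f is a pure function of params, jax.value_and_grad applies directly:
#   f_hat, g_hat = jax.value_and_grad(delta_f)(trial_params)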


Notes to reviewers:

  • Length -- Although the raw number of lines added is pretty large, some consolation is that there are only ~40-50 lines of Python statements in reweighting.py (most of the total line count is in tests, and most of the lines in reweighting.py are documentation / white space)
  • Comparison to Differentiable reweighting #404 -- Structured differently (free functions rather than classes), omits some features (no effective sample size diagnostic), is less jax-specific (assumes access to a batched_u_fxn, rather than a u_fxn which can be transformed by jax.vmap), and has a more generic interface.
    • The more generic interface attempts to clarify the relation to MBAR:
      • Both of the reweighting estimators in this PR are compatible with inputs that are not necessarily processed by MBAR.
      • The log_weights required by one of these approaches can be computed from a collection of sampled states using MBAR via interpret_as_mixture_potential (see the sketch after this list).
      • The tests assert that the approach is compatible with other estimators in principle (although if we have the full u_kn matrix in hand, there would be no reason to prefer, say, TI's estimate of f_k over mbar.f_k).
      • A reference is included about other ways to get higher-variance estimates of these log_weights that would not require the full u_kn matrix.
      • In the future, if we sampled from a mixture distribution with an exactly known logpdf (rather than inferring the mixture logpdf from multiple states via MBAR), that would also supply us with these log_weights.
  • Test structure -- there's some avoidable repetition in the tests, happy to re-organize this
  • Type annotations -- added in this commit: bff4ae5 -- not sure if these clarify or just add noise, happy to pare down
  • Docstrings -- didn't fully expand parameter descriptions; ASCII-art diagrams are currently a bit ugly
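
For concreteness, here is a sketch of how such mixture log_weights could be assembled from a u_kn matrix, MBAR reduced free energies f_k (e.g. mbar.f_k), and per-state sample counts N_k. This uses the standard sample-count-weighted mixture identity; mixture_log_weights is a placeholder name, not necessarily the exact behavior of interpret_as_mixture_potential in this PR:

from jax import numpy as np
from jax.scipy.special import logsumexp


def mixture_log_weights(u_kn, f_k, N_k):
    """Log-density (up to an additive constant) of the sample-count-weighted
    mixture of the K sampled states, evaluated at each of the N pooled samples.

    u_kn: (K, N) reduced potentials, f_k: (K,) reduced free energies,
    N_k: (K,) number of samples drawn from each state."""
    log_fractions = np.log(N_k) - np.log(np.sum(N_k))
    # log sum_k (N_k / N) * exp(f_k - u_k(x_n)), for each pooled sample x_n
    return logsumexp(log_fractions[:, None] + f_k[:, None] - u_kn, axis=0)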

maxentile and others added 2 commits February 28, 2022 14:41
Co-authored-by: Matt Wittmann <mwittmann@relaytx.com>
Applying suggestion from #654 (comment)

Co-Authored-By: Matt Wittmann <mcwitt@gmail.com>

@proteneer (Owner) left a comment:

Did a first pass of the tests - will resume adding more comments on the actual implementation after dinner.

In tests/test_reweighting.py:
f_ref, g_ref = value_and_grad(analytical_delta_f)(trial_params)

onp.testing.assert_allclose(f_hat, f_ref, atol=atol)
onp.testing.assert_allclose(g_hat, g_ref, atol=atol)

proteneer (Owner):

How should we think about precision of dDeltaF/dParams both for the gaussian case, and in general? For the gaussian case, I'd imagine that dDeltaF/dlog_sigma might require more precision than dDeltaF/dmean. Presumably, an acceptable level of error is proportional to the internal learning rate one might take when optimizing the parameters, so maybe rtol is more appropriate here?

maxentile (Collaborator, Author):

Will need to think about this.

maxentile (Collaborator, Author):

> For the gaussian case, I'd imagine that dDeltaF/dlog_sigma might require more precision than dDeltaF/dmean.

I'd imagine so.

In a documentation notebook of #404, cell [46], note that estimates of the different gradient components have different variance.

maxentile (Collaborator, Author), Mar 2, 2022:

> How should we think about precision of dDeltaF/dParams both for the gaussian case, and in general?

Preferred interpretation: Exact gradient of a randomized approximation of DeltaF(params).

delta_f_approx_fxn = construct_approximation(ref_params, random_seed)

> Presumably, an acceptable level of error is proportional to the internal learning rate one might take when optimizing the parameters

In the case that you created a fresh random approximation every optimization step

delta_f_approx_fxn_0 = construct_approximation(params_0, seed_0)
params_1 = params_0 + stepsize * grad(delta_f_approx_fxn_0)(params_0)
delta_f_approx_fxn_1 = construct_approximation(params_1, seed_1)
params_2 = params_1 + stepsize * grad(delta_f_approx_fxn_1)(params_1)
...

then the appropriate step size would be a function of both the geometry of delta_f_exact_fxn and the precision of the gradient estimate.

Instead, I think we would want to perform iterates like

delta_f_approx_fxn_0 = construct_approximation(params_0, seed_0)
params_1 = local_optimize(fun=delta_f_approx_fxn_0, x0=params_0, ...)
delta_f_approx_fxn_1 = construct_approximation(params_1, seed_1)
params_2 = local_optimize(fun=delta_f_approx_fxn_1, x0=params_1, ...)
...

where local_optimize(fun, x0, ...) returns the local optimum of fun (+ some regularization as a function of x0), optionally restricted to some trust region around x0. (Glossing over how that trust region is defined.)

Then whatever step sizes etc. are used inside local_optimize do NOT depend on the precision of the estimate, only on "the geometry of the problem."

The amount of progress you make in each outer-loop step (construct_approximation, then local_optimize) would be influenced by how that trust region is defined.

(There's a similar way to view gradient descent: each step of SGD is exactly optimizing a random local approximation of f + some regularization.)
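
A minimal sketch of what that outer loop could look like, assuming for concreteness that the approximation is a scalar, jax-differentiable function to be minimized. Here local_optimize uses scipy with a quadratic penalty toward x0 as a stand-in for a real trust region; construct_approximation, params_0, seed_0, ... remain the hypothetical names from this comment, not an API in this PR:

import numpy as onp
from jax import value_and_grad
from scipy.optimize import minimize


def local_optimize(fun, x0, regularization=10.0):
    """Locally optimize fun(x) + (regularization / 2) * ||x - x0||^2, starting from x0."""

    def penalized_value_and_grad(x):
        v, g = value_and_grad(fun)(x)
        v = float(v) + 0.5 * regularization * float(onp.sum((x - x0) ** 2))
        g = onp.asarray(g, dtype=onp.float64) + regularization * (x - x0)
        return v, g

    return minimize(penalized_value_and_grad, x0, jac=True, method="L-BFGS-B").x


# Usage, with the placeholder names above:
#   params = params_0
#   for seed in [seed_0, seed_1, ...]:
#       delta_f_approx_fxn = construct_approximation(params, seed)
#       params = local_optimize(delta_f_approx_fxn, params)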

maxentile (Collaborator, Author), Mar 2, 2022:

Sorry, missed replying to this part:

> so maybe rtol is more appropriate here?

I don't think so.

isclose(estimated_gradient, exact_gradient, rtol=...) might be appropriate if the noise magnitude of each component estimated_gradient[i] were proportional to the magnitude of the corresponding gradient component abs(exact_gradient[i]).

(Counterexample: 1D Lennard-Jones test system: exact gradient w.r.t. [sigma, epsilon] parameters ~= [-0.15, -0.3] (cell [43]), but estimates of these gradient components have marginal variance ~= [0.1, 0.01] (cell [41]).)
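
As a rough numerical illustration of that point (a hypothetical snippet, treating the reported marginal variances loosely as per-component error scales):

import numpy as onp

exact_gradient = onp.array([-0.15, -0.3])
error_scale = onp.array([0.1, 0.01])  # not proportional to abs(exact_gradient)
estimated_gradient = exact_gradient + error_scale  # a typical-sized deviation

# rtol budgets error in proportion to abs(exact_gradient), which doesn't match the noise:
onp.isclose(estimated_gradient, exact_gradient, rtol=0.5, atol=0.0)   # -> [False, True]
# a single atol chosen from the expected noise level accepts both components:
onp.isclose(estimated_gradient, exact_gradient, rtol=0.0, atol=0.15)  # -> [True, True]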

proteneer (Owner):

I'm going to leave this "unresolved" so I can find the comments above more easily. But this mostly addresses my concerns.

maxentile (Collaborator, Author):

Opened #667 for continued discussion after this PR.

Comment on timemachine/fe/reweighting.py, lines +76 to +77:
batched_u_0_fxn: BatchedReducedPotentialFxn,
batched_u_1_fxn: BatchedReducedPotentialFxn,

proteneer (Owner):

we use batched_ pretty extensively in other parts of the code too


def endpoint_correction_0(params) -> float:
    """estimate f(ref, 0) -> f(params, 0) by reweighting"""
    delta_us = batched_u_0_fxn(samples_0, params) - ref_u_0

proteneer (Owner):

We'd probably need to pin samples to a DeviceBuffer or something to truly see any meaningful performance gains here.

Previous change to make ref_params the same for both 1D tests (475a399) causes:

>       onp.testing.assert_allclose(g_hat, g_ref, atol=atol)
E       AssertionError:
E       Not equal to tolerance rtol=1e-07, atol=0.001
E
E       Mismatched elements: 1 / 2 (50%)
E       Max absolute difference: 0.0011654208166751
E       Max relative difference: 0.0011654208166751
E        x: array([-2.282977e-04, -9.988346e-01])
E        y: array([ 0., -1.])

@proteneer (Owner) left a comment:

This PR lgtm - especially after some in-person discussions last week.

@mcwitt (Collaborator) left a comment:

Looks great! The docstrings, comments, and type annotations are all very useful for understanding.

@maxentile merged commit b5f1302 into master on Mar 14, 2022
@maxentile deleted the differentiable-reweighting-redux branch on June 15, 2022 at 13:05