Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch TIBDExchangeMove #1283

Merged
merged 40 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
84afa21
Remove unused copy
badisa Mar 18, 2024
aa92ecb
Remove logsumexp of initial weights from compute_initial_weights
badisa Mar 15, 2024
6814507
Add --batch_size to water sampling example
badisa Mar 26, 2024
9bf027b
Batch Targeted Insertion
badisa Mar 19, 2024
96b1228
Put back the neutral case
badisa Mar 27, 2024
251d6b3
PR Feedback: Remove the changes from Sampler
badisa Mar 27, 2024
c6cddfc
PR Feedback: Clean up comments, remove prints
badisa Mar 27, 2024
23d9f99
PR Feedback: Add additional validation tests
badisa Mar 27, 2024
18d218c
PR Feedback: Clean up deterministic tests
badisa Mar 30, 2024
05248f9
PR Feedback: Fixup Metropolis-Hasting wording
badisa Mar 30, 2024
f8a5fa5
PR Feedback: Clarify description of TIBD algorithm
badisa Mar 30, 2024
2a033d6
PR Feedback: Add determinism test for buckyball
badisa Mar 30, 2024
dd1752e
PR Feedback: Last raw log probability
badisa Mar 30, 2024
abd731e
PR Feedback: Fix comment
badisa Apr 2, 2024
442697e
PR Feedback: Expose initial weights computation and weights
badisa Apr 2, 2024
1b20d26
Reduce memory usage of mol energy buffer used for initial weights
badisa Apr 2, 2024
f55cfe7
Speed up slow kernel
badisa Apr 2, 2024
87e77f6
Modify BDExchangeMover to store before weights at end
badisa Apr 2, 2024
a21e7c7
PR Feedback: Be more explicit about template types
badisa Apr 2, 2024
ae1ac07
PR Feedback: Missing comments
badisa Apr 3, 2024
c50f382
WIP: Improves precision
badisa Apr 2, 2024
ba5048e
PR Feedback: Move inv_kT for the final sample to last computation
badisa Apr 3, 2024
239271d
Keep all energies in fixed point to ensure bitwise identical log
badisa Apr 4, 2024
a0452bb
PR Feedback: Make test more robust
badisa Apr 4, 2024
678a80b
PR Feedback: Reduce threshold in test
badisa Apr 4, 2024
7e2e218
PR Feedback: Increase iterations of determinism test
badisa Apr 4, 2024
1248de9
PR Feedback: Be consistent about using volatile keyword
badisa Apr 5, 2024
eda228f
PR Feedback: Remove volatile from atomic shared values
badisa Apr 5, 2024
3933dba
PR Feedback: Undo volatile since we have __syncthreads which has memory
badisa Apr 5, 2024
45c49cb
PR Feedback: Fix up spelling of Metropolis-Hastings
badisa Apr 5, 2024
836f052
PR Feedback: Rename energy_accumulator
badisa Apr 5, 2024
d1aaf67
PR Feedback: Proposals per move wording
badisa Apr 5, 2024
78a4b39
PR Feedback: Use beta instead of inv_kT
badisa Apr 5, 2024
f22aad9
PR Feedback: Rename weights -> energies where applicable
badisa Apr 5, 2024
4b9bd1d
PR Feedback: Rewrite confusing sentence
badisa Apr 5, 2024
5b6ee8b
PR Feedback: Be explicit that it is indexes and not values
badisa Apr 5, 2024
6204453
PR Feedback: beta -> inv_beta to be explicit
badisa Apr 5, 2024
261f510
PR Feedback: Don't allow positive inf weights
badisa Apr 5, 2024
fc4c5f5
PR Feedback: Clarify variable name, log_weight
badisa Apr 5, 2024
638b242
PR Feedback: Be more careful about log_weights vs weights
badisa Apr 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions examples/water_sampling_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def test_exchange():
parser.add_argument(
"--save_last_frame", type=str, help="Store last frame as a npz file, used to verify bitwise determinism"
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
proteneer marked this conversation as resolved.
Show resolved Hide resolved
help="Batch size to generate proposals for the MC moves, not used for reference case",
)

args = parser.parse_args()

Expand Down Expand Up @@ -156,6 +162,7 @@ def run_mc_proposals(mover, state: CoordsVelBox, steps: int) -> CoordsVelBox:
seed,
args.mc_steps_per_batch,
exchange_interval,
batch_size=args.batch_size,
)
elif args.insertion_type == "untargeted":
if args.use_reference:
Expand All @@ -171,6 +178,7 @@ def run_mc_proposals(mover, state: CoordsVelBox, steps: int) -> CoordsVelBox:
seed,
args.mc_steps_per_batch,
exchange_interval,
batch_size=args.batch_size,
)

cur_box = initial_state.box0
Expand Down Expand Up @@ -206,7 +214,8 @@ def run_mc_proposals(mover, state: CoordsVelBox, steps: int) -> CoordsVelBox:
equilibration_steps = args.equilibration_steps
# equilibrate using the npt mover
npt_mover.n_steps = equilibration_steps
xvb_t = npt_mover.move(xvb_t)
if equilibration_steps > 0:
xvb_t = npt_mover.move(xvb_t)
print("done")

# TBD: cache the minimized and equilibrated initial structure later on to iterate faster.
Expand Down Expand Up @@ -245,7 +254,8 @@ def run_mc_proposals(mover, state: CoordsVelBox, steps: int) -> CoordsVelBox:
writer.write_frame(xvb_t.coords * 10)

# run MD
xvb_t = npt_mover.move(xvb_t)
if args.md_steps_per_batch > 0:
xvb_t = npt_mover.move(xvb_t)

writer.close()
if args.save_last_frame:
Expand Down
Binary file modified tests/data/reference_6_water_targeted.npz
Binary file not shown.
Binary file modified tests/data/reference_6_water_untargeted.npz
Binary file not shown.
72 changes: 66 additions & 6 deletions tests/test_cuda_bd_exchange_mover.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def verify_bias_deletion_moves(

last_conf = x_move
assert bdem.n_proposed() == total_num_proposals
print(f"Accepted { bdem.n_accepted()} of {total_num_proposals} moves")
print(f"Accepted {bdem.n_accepted()} of {total_num_proposals} moves")
assert accepted > 0, "No moves were made, nothing was tested"
if proposals_per_move == 1:
np.testing.assert_allclose(bdem.acceptance_fraction(), accepted / total_num_proposals)
Expand Down Expand Up @@ -438,7 +438,7 @@ def test_bd_exchange_deterministic_moves(proposals_per_move, batch_size, precisi
assert bdem_a.n_accepted() == bdem_b.n_accepted()
assert bdem_a.n_proposed() == bdem_b.n_proposed()

# Moves should be deterministic regardless the number of steps taken per move
# Moves should be deterministic regardless the number of proposals per move
np.testing.assert_array_equal(iterative_moved_coords, batch_moved_coords)


Expand Down Expand Up @@ -511,7 +511,7 @@ def test_bd_exchange_deterministic_batch_moves(proposals_per_move, batch_size, p
assert bdem_a.n_accepted() == bdem_b.n_accepted()
assert bdem_a.n_proposed() == bdem_b.n_proposed()

# Moves should be deterministic regardless the number of steps taken per move
# Moves should be deterministic regardless the number of proposals per move
np.testing.assert_array_equal(iterative_moved_coords, batch_moved_coords)


Expand Down Expand Up @@ -609,6 +609,63 @@ def test_moves_in_a_water_box(
)


@pytest.mark.parametrize("num_particles", [3])
@pytest.mark.parametrize("proposals_per_move", [100, 1000])
@pytest.mark.parametrize("precision", [np.float64, np.float32])
@pytest.mark.parametrize("seed", [2023, 2024, 2025])
def test_compute_incremental_log_weights_match_initial_log_weights_when_recomputed(
    num_particles, proposals_per_move, precision, seed
):
    """Check that the log weights the mover accumulates incrementally while
    generating proposals agree exactly with `compute_initial_log_weights`
    evaluated from scratch on the final coordinates.
    """
    assert (
        proposals_per_move > 1
    ), "If proposals per move is 1 then this isn't meaningful since the weights won't be incremental"
    generator = np.random.default_rng(seed)

    # Small cubic box populated with randomly placed particles.
    beta = 2.0
    cutoff = 1.2
    box_size = 1.0
    box = box_size * np.eye(3)
    conf = box_size * generator.random((num_particles, 3))

    params = generator.random((num_particles, 4))
    params[:, 3] = 0.0  # Put them in the same plane

    # Each "molecule" is a single particle.
    group_idxs = [[idx] for idx in range(num_particles)]

    mover_cls = (
        custom_ops.BDExchangeMove_f64 if precision == np.float64 else custom_ops.BDExchangeMove_f32
    )

    # Test version that makes all proposals in a single move
    bdem = mover_cls(
        conf.shape[0],
        group_idxs,
        params,
        DEFAULT_TEMP,
        beta,
        cutoff,
        seed,
        proposals_per_move,
        1,
    )

    updated_coords, _ = bdem.move(conf, box)
    # The move must have done something, otherwise nothing incremental was exercised.
    assert not np.all(updated_coords == conf)
    assert bdem.n_accepted() >= 1
    assert bdem.n_proposed() == proposals_per_move

    before_log_weights = bdem.get_before_log_weights()
    ref_log_weights = bdem.compute_initial_log_weights(updated_coords, box)
    # The before weights of the mover should identically match the weights if recomputed from scratch
    diff_idxs = np.argwhere(np.array(before_log_weights) != np.array(ref_log_weights))
    np.testing.assert_array_equal(before_log_weights, ref_log_weights, err_msg=f"idxs {diff_idxs} don't match")
proteneer marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"batch_size,samples,box_size",
[
Expand All @@ -619,7 +676,7 @@ def test_moves_in_a_water_box(
)
@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-5, 1e-5), (np.float32, 8e-4, 2e-3)])
@pytest.mark.parametrize("seed", [2023])
def test_compute_incremental_weights(batch_size, samples, box_size, precision, rtol, atol, seed):
def test_compute_incremental_log_weights(batch_size, samples, box_size, precision, rtol, atol, seed):
"""Verify that the incremental weights computed are valid for different collections of rotations/translations"""
proposals_per_move = batch_size # Number doesn't matter here, we aren't calling move
ff = Forcefield.load_default()
Expand Down Expand Up @@ -661,7 +718,7 @@ def test_compute_incremental_weights(batch_size, samples, box_size, precision, r
# Scale the translations
translations = rng.uniform(0, 1, size=(batch_size, 3)) * np.diagonal(box)

test_weight_batches = bdem.compute_incremental_weights(conf, box, selected_mols, quaternions, translations)
test_weight_batches = bdem.compute_incremental_log_weights(conf, box, selected_mols, quaternions, translations)
assert len(test_weight_batches) == batch_size
for test_weights, selected_mol, quat, translation in zip(
test_weight_batches, selected_mols, quaternions, translations
Expand All @@ -679,7 +736,10 @@ def test_compute_incremental_weights(batch_size, samples, box_size, precision, r
moved_conf[mol_idxs] = updated_mol_conf
np.testing.assert_equal(trial_conf, moved_conf)
# Janky re-use of assert_energy_arrays_match which is for energies, but functions for any fixed point
assert_energy_arrays_match(np.array(ref_final_weights), np.array(test_weights), atol=atol, rtol=rtol)
# Slightly reduced threshold to deal with these being weights
assert_energy_arrays_match(
np.array(ref_final_weights), np.array(test_weights), atol=atol, rtol=rtol, threshold=5e6
)


@pytest.fixture(scope="module")
Expand Down
Loading