Skip to content

Commit

Permalink
Minor cleanup to verlet integrator (#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored and fritzo committed Mar 2, 2018
1 parent bfd0e54 commit bc7e984
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions pyro/ops/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def velocity_verlet(z, r, potential_fn, step_size, num_steps=1):
:param int num_steps: number of discrete time steps over which to integrate.
:return tuple (z_next, r_next): final position and momenta, having same types as (z, r).
"""
z_next = {key: val.data.clone() for key, val in z.items()}
r_next = {key: val.data.clone() for key, val in r.items()}
z_next = z.copy()
r_next = r.copy()
grads, _ = _grad(potential_fn, z_next)

for _ in range(num_steps):
Expand All @@ -33,8 +33,6 @@ def velocity_verlet(z, r, potential_fn, step_size, num_steps=1):
for site_name in r_next:
# r(n+1)
r_next[site_name] = r_next[site_name] + 0.5 * step_size * (-grads[site_name])
z_next = {key: Variable(val) for key, val in z_next.items()}
r_next = {key: Variable(val) for key, val in r_next.items()}
return z_next, r_next


Expand All @@ -47,8 +45,8 @@ def single_step_velocity_verlet(z, r, potential_fn, step_size, z_grads=None):
:return tuple (z_next, r_next, z_grads, potential_energy): next position and momenta,
together with the potential energy and its gradient w.r.t. ``z_next``.
"""
z_next = {key: val.data.clone() for key, val in z.items()}
r_next = {key: val.data.clone() for key, val in r.items()}
z_next = z.copy()
r_next = r.copy()
grads = _grad(potential_fn, z_next)[0] if z_grads is None else z_grads

for site_name in z_next:
Expand All @@ -57,8 +55,6 @@ def single_step_velocity_verlet(z, r, potential_fn, step_size, z_grads=None):
grads, potential_energy = _grad(potential_fn, z_next)
for site_name in r_next:
r_next[site_name] = r_next[site_name] + 0.5 * step_size * (-grads[site_name])
z_next = {key: Variable(val) for key, val in z_next.items()}
r_next = {key: Variable(val) for key, val in r_next.items()}
return z_next, r_next, grads, potential_energy


Expand Down

0 comments on commit bc7e984

Please sign in to comment.