@@ -137,31 +137,37 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
137137 if fixed_length > 0 :
138138 statistics ['fixed_length' ] = fixed_length
139139
140- with tqdm (total = num_paths + warmup , desc = 'warming up' if warmup > 0 else '' ) as pbar :
141- while len (trajectories ) <= num_paths + warmup :
142- statistics ['num_tries' ] += 1
143- if len (trajectories ) > warmup :
144- pbar .set_description ('' )
145-
146- key , traj_idx_key , iter_key , accept_key = jax .random .split (key , 4 )
147- traj_idx = jax .random .randint (traj_idx_key , (1 ,), warmup + 1 , len (trajectories ))[0 ]
148- # during warmup, we want an iterative scheme
149- traj_idx = traj_idx if traj_idx < len (trajectories ) else - 1
150-
151- found , new_trajectory , new_velocities = proposal (system , trajectories [traj_idx ], fixed_length , iter_key )
152- statistics ['num_force_evaluations' ] += len (new_trajectory ) - 1
153-
154- if not found :
155- continue
156-
157- ratio = len (trajectories [- 1 ]) / len (new_trajectory )
158- # The first trajectory might have a very unreasonable length, so we skip it
159- if len (trajectories ) == 1 or jax .random .uniform (accept_key , shape = (1 ,)) < ratio :
160- trajectories .append (new_trajectory )
161- velocities .append (new_velocities )
162- pbar .update (1 )
163- else :
164- statistics ['num_metropolis_rejected' ] += 1
140+ try :
141+ with tqdm (total = num_paths + warmup , desc = 'warming up' if warmup > 0 else '' ) as pbar :
142+ while len (trajectories ) <= num_paths + warmup :
143+ statistics ['num_tries' ] += 1
144+ if len (trajectories ) > warmup :
145+ pbar .set_description ('' )
146+
147+ key , traj_idx_key , iter_key , accept_key = jax .random .split (key , 4 )
148+ traj_idx = jax .random .randint (traj_idx_key , (1 ,), warmup + 1 , len (trajectories ))[0 ]
149+ # during warmup, we want an iterative scheme
150+ traj_idx = traj_idx if traj_idx < len (trajectories ) else - 1
151+
152+ found , new_trajectory , new_velocities = proposal (system , trajectories [traj_idx ], fixed_length , iter_key )
153+ statistics ['num_force_evaluations' ] += len (new_trajectory ) - 1
154+
155+ if not found :
156+ continue
157+
158+ ratio = len (trajectories [- 1 ]) / len (new_trajectory )
159+ # The first trajectory might have a very unreasonable length, so we skip it
160+ if len (trajectories ) == 1 or jax .random .uniform (accept_key , shape = (1 ,)) < ratio :
161+ trajectories .append (new_trajectory )
162+ velocities .append (new_velocities )
163+ pbar .update (1 )
164+ else :
165+ statistics ['num_metropolis_rejected' ] += 1
166+ except KeyboardInterrupt :
167+ print ('SIGINT received, stopping early' )
168+ # Fix in case we stop when adding a trajectory
169+ if len (trajectories ) > len (velocities ):
170+ velocities .append (new_velocities )
165171
166172 return trajectories [warmup + 1 :], velocities [warmup :], statistics
167173
0 commit comments