22import jax .numpy as jnp
33import jax
44from eval .path_metrics import plot_path_energy
5+ from systems import System
56from tps .paths import decorrelated
6- from tps_baseline_mueller import U , dUdx_fn
77from scipy .optimize import minimize
88import matplotlib .pyplot as plt
99import os
1515T = 275e-4
1616N = int (T / dt )
1717
18+ system = System .from_name ('mueller_brown' , float ('inf' ))
19+
1820minima_points = jnp .array ([[- 0.55828035 , 1.44169 ],
1921 [- 0.05004308 , 0.46666032 ],
2022 [0.62361133 , 0.02804632 ]])
@@ -27,8 +29,17 @@ def load(path):
2729
2830@jax .jit
2931def log_path_likelihood (path ):
30- rand = path [1 :] - path [:- 1 ] + dt * dUdx_fn (path [:- 1 ])
31- return (- U (path [0 ]) / kbT ).sum () + jax .scipy .stats .norm .logpdf (rand , scale = jnp .sqrt (dt ) * xi ).sum ()
32+ rand = path [1 :] - path [:- 1 ] + dt * system .dUdx (path [:- 1 ])
33+ return (- system .U (path [0 ]) / kbT ).sum () + jax .scipy .stats .norm .logpdf (rand , scale = jnp .sqrt (dt ) * xi ).sum ()
34+
35+
36+ def plot_hist (system , paths , trajectories_to_plot , seed = 1 ):
37+ system .plot (trajectories = paths )
38+ colors = plt .rcParams ['axes.prop_cycle' ].by_key ()['color' ]
39+ idx = jax .random .permutation (jax .random .PRNGKey (seed ), len (paths ))[:trajectories_to_plot ]
40+ for i , c in zip (idx , colors [1 :]):
41+ cur_paths = jnp .array (paths [i ])
42+ plt .plot (cur_paths [:, 0 ].T , cur_paths [:, 1 ].T , c = c )
3243
3344
3445if __name__ == '__main__' :
@@ -43,19 +54,29 @@ def log_path_likelihood(path):
4354 ('var-doobs' , './out/var_doobs/mueller/paths.npy' , 0 ),
4455 ]
4556
46- global_minimum_energy = U (minima_points [ 0 ] )
57+ global_minimum_energy = min ( system . U (minima_points ) )
4758 for point in minima_points :
48- global_minimum_energy = min (global_minimum_energy , minimize (U , point ).fun )
59+ global_minimum_energy = min (global_minimum_energy , minimize (system . U , point ).fun )
4960 print ("Global minimum energy" , global_minimum_energy )
5061
5162 all_paths = [(name , load (path )[warmup :],) for name , path , warmup in all_paths ]
5263 [print (name , len (path )) for name , path in all_paths ]
5364
65+ for name , paths in all_paths :
66+ # for this plot we limit ourselves to 250 paths
67+ plot_hist (system , paths [:250 ], 2 )
68+ plt .savefig (f'{ savedir } /{ name } -histogram.pdf' , bbox_inches = 'tight' )
69+ plt .show ()
70+
71+ plot_hist (system , decorrelated (paths )[:250 ], 2 )
72+ plt .savefig (f'{ savedir } /{ name } -decorrelated-histogram.pdf' , bbox_inches = 'tight' )
73+ plt .show ()
74+
5475 for name , paths in all_paths :
5576 print (name , 'decorrelated trajectories:' , jnp .round (100 * len (decorrelated (paths )) / len (paths ), 2 ), '%' )
5677
5778 for name , paths in all_paths :
58- max_energy = plot_path_energy (paths , U , add = - global_minimum_energy , label = name ) + global_minimum_energy
79+ max_energy = plot_path_energy (paths , system . U , add = - global_minimum_energy , label = name ) + global_minimum_energy
5980 print (name , 'max energy mean:' , jnp .round (jnp .mean (max_energy ), 2 ), 'std:' , jnp .round (jnp .std (max_energy ), 2 ))
6081 print (name , 'min max energy: ' , jnp .round (jnp .min (max_energy ), 2 ))
6182
@@ -65,7 +86,7 @@ def log_path_likelihood(path):
6586 plt .show ()
6687
6788 for name , paths in all_paths :
68- plot_path_energy (paths , U , add = - global_minimum_energy , reduce = jnp .median , label = name )
89+ plot_path_energy (paths , system . U , add = - global_minimum_energy , reduce = jnp .median , label = name )
6990
7091 plt .legend ()
7192 plt .ylabel ('Median energy' )
0 commit comments