Skip to content

Commit 87f8ea3

Browse files
committed
Add double_well to toys
1 parent 4198103 commit 87f8ea3

File tree

6 files changed

+43
-2
lines changed

6 files changed

+43
-2
lines changed

configs/toy/double_well.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
save_dir: ./out/toy/double_well
2+
3+
test_system: double_well
4+
ode: first_order
5+
parameterization: diagonal
6+
T: 1.0
7+
xi: 1.0
8+
gamma: 1.0
9+
10+
num_gaussians: 1
11+
trainable_weights: False
12+
base_sigma: 0.1
13+
14+
epochs: 2500
15+
BS: 512
16+
17+
num_paths: 100
18+
dt: 5e-4

configs/toy/double_well_hard.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
save_dir: ./out/toy/double_well_hard
2+
3+
test_system: double_well_hard
4+
ode: first_order
5+
parameterization: diagonal
6+
T: 1.0
7+
xi: 1.0
8+
gamma: 1.0
9+
10+
num_gaussians: 1
11+
trainable_weights: False
12+
base_sigma: 0.1
13+
14+
epochs: 2500
15+
BS: 512
16+
17+
num_paths: 100
18+
dt: 5e-4

plot_test.py

Whitespace-only changes.

potentials.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def U_mueller_brown(xs, beta=1.0):
3333
return beta * (e1 + e2 + e3 + e4)
3434

3535

36-
double_well = (U_double_well,)
37-
double_well_hard = (U_double_well_hard,)
36+
double_well = (U_double_well, jnp.array([-jnp.sqrt(2), 0]), jnp.array([jnp.sqrt(2), 0]))
37+
double_well_hard = (U_double_well_hard, jnp.array([-3, 0]), jnp.array([3, 0]))
3838
double_well_dual_channel = (U_double_well_dual_channel, jnp.array([-0.5, 0]), jnp.array([0.5, 0]))
3939
mueller_brown = (U_mueller_brown, jnp.array([-0.55828035, 1.44169]), jnp.array([0.62361133, 0.02804632]))

systems.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@ def __init__(self, U: Callable[[ArrayLike], ArrayLike], A: ArrayLike, B: ArrayLi
3434
def from_name(cls, name: str, force_clip: float) -> Self:
3535
if name == 'double_well':
3636
U, A, B = potentials.double_well
37+
xlim = jnp.array((-2.0, 2.0))
38+
ylim = jnp.array((-2.5, 2.5))
3739
elif name == 'double_well_hard':
3840
U, A, B = potentials.double_well_hard
41+
xlim = jnp.array((-6.0, 6.0))
42+
ylim = jnp.array((-4.0, 4.0))
3943
elif name == 'double_well_dual_channel':
4044
U, A, B = potentials.double_well_dual_channel
4145
xlim = jnp.array((-1.0, 1.0))

tps_baseline_mueller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def step(_x, _key):
8383
with open(f'{save_dir}/stats-{name}.json', 'r') as fp:
8484
statistics = json.load(fp)
8585
else:
86+
print('Generating paths for', name)
8687
paths, statistics = tps1.mcmc_shooting(tps_config, method, initial_trajectory, num_paths,
8788
jax.random.PRNGKey(1), warmup=0, fixed_length=N)
8889

0 commit comments

Comments
 (0)