Skip to content

Commit 9a9fa03

Browse files
committed
Fix state plotting
1 parent 9f65f62 commit 9a9fa03

File tree

5 files changed

+27
-16
lines changed

5 files changed

+27
-16
lines changed

main.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from argparse import ArgumentParser
2-
32
from utils.args import parse_args
43
from systems import System
54
import matplotlib.pyplot as plt
5+
import jax.numpy as jnp
6+
import jax
7+
from flax.training import train_state
8+
import optax
9+
import model.diagonal as diagonal
10+
from model.train import train
611

712
parser = ArgumentParser()
813
parser.add_argument('--out', type=str, default=None, help="Specify a path where the data will be stored.")
@@ -59,18 +64,9 @@
5964
raise NotImplementedError
6065
# system = System.from_forcefield(args.start, args.target)
6166

62-
import jax.numpy as jnp
63-
import jax
64-
from tqdm import trange
65-
from flax.training import train_state
66-
import optax
67-
import model.diagonal as diagonal
68-
from model.train import train
69-
from model import MLPq
70-
71-
N = int(args.T / args.dt)
72-
7367
# You can play around with any model here
68+
# The chosen setup will append a final layer so that the output is mu, sigma, and weights
69+
from model import MLPq
7470
model = MLPq([128, 128, 128])
7571

7672
# TODO: parameterize mixtures, weights, and base_sigma
@@ -113,3 +109,12 @@
113109

114110
key, path_key = jax.random.split(key)
115111
x_t_stoch = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, args.xi, path_key)
112+
113+
if system.plot:
114+
system.plot(title='Deterministic Paths', trajectories=x_t_det)
115+
plt.show()
116+
plt.clf()
117+
118+
system.plot(title='Stochastic Paths', trajectories=x_t_stoch)
119+
plt.show()
120+
plt.clf()

model/qsetup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def sample_paths(self, state_q: TrainState, x_0: ArrayLike, dt: float, T: float,
4444

4545
for i in trange(N):
4646
for j in range(0, num_paths, BS):
47+
# If the BS does not divide the number of paths, we need to pad the last batch
4748
if j + BS > num_paths:
4849
j_end = num_paths
4950
cur_x_t = jnp.pad(x_t[j:, i], pad_width=((0, BS - (num_paths - j)), (0, 0)))
@@ -56,6 +57,7 @@ def sample_paths(self, state_q: TrainState, x_0: ArrayLike, dt: float, T: float,
5657
if key is None:
5758
noise = 0
5859
else:
60+
# For stochastic sampling we compute the noise
5961
key, iter_key = jax.random.split(key)
6062
noise = xi * jax.random.normal(iter_key, shape=(BS, ndim))
6163

systems.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def from_name(cls, name: str) -> Self:
3737
else:
3838
raise ValueError(f"Unknown system: {name}")
3939

40-
plot = partial(toy.plot_energy_surface, U=U, states=zip(['A', 'B'], [A, B]), xlim=xlim, ylim=ylim, alpha=1.0)
40+
plot = partial(toy.plot_energy_surface, U=U, states=list(zip(['A', 'B'], [A, B])), xlim=xlim, ylim=ylim, alpha=1.0)
4141
mass = jnp.array([1.0, 1.0])
4242
return cls(U, A, B, mass, plot)
4343

tps_baseline_mueller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def interpolate_two_points(start, stop, steps):
4848
return interpolation
4949

5050

51-
plot_energy_surface = partial(toy.plot_energy_surface, U=U, states=zip(['A', 'B'], minima_points),
51+
plot_energy_surface = partial(toy.plot_energy_surface, U=U, states=list(zip(['A', 'B'], minima_points)),
5252
xlim=jnp.array((-1.5, 0.9)), ylim=jnp.array((-0.5, 1.7)), alpha=1.0)
5353

5454
if __name__ == '__main__':

utils/toy_plot_helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import matplotlib.pyplot as plt
33

44

5-
def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=150, levels=30, alpha=0.7, radius=0.1):
5+
def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[],
6+
title=None, bins=150, levels=30, alpha=0.7, radius=0.1):
7+
if title:
8+
plt.title(title)
9+
610
x, y = jnp.linspace(xlim[0], xlim[1], bins), jnp.linspace(ylim[0], ylim[1], bins)
711
x, y = jnp.meshgrid(x, y, indexing='ij')
812
z = U(jnp.stack([x, y], -1).reshape(-1, 2)).reshape([bins, bins])
@@ -42,7 +46,7 @@ def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=
4246
plt.scatter(p[0], p[1], marker='*')
4347

4448
for name, pos in states:
45-
pos = pos.reshape(2,)
49+
pos = pos.reshape(2, )
4650
c = plt.Circle(pos, radius=radius, edgecolor='gray', alpha=alpha, facecolor='white', ls='--', lw=0.7, zorder=10)
4751
plt.gca().add_patch(c)
4852
plt.gca().annotate(name, xy=pos, ha="center", va="center", fontsize=14, zorder=11)

0 commit comments

Comments
 (0)