|
2 | 2 | import matplotlib.pyplot as plt |
3 | 3 |
|
4 | 4 |
|
5 | | -def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=150, levels=30, alpha=0.7, radius=0.1): |
| 5 | +def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], |
| 6 | + title=None, bins=150, levels=30, alpha=0.7, radius=0.1): |
| 7 | + if title: |
| 8 | + plt.title(title) |
| 9 | + |
6 | 10 | x, y = jnp.linspace(xlim[0], xlim[1], bins), jnp.linspace(ylim[0], ylim[1], bins) |
7 | 11 | x, y = jnp.meshgrid(x, y, indexing='ij') |
8 | 12 | z = U(jnp.stack([x, y], -1).reshape(-1, 2)).reshape([bins, bins]) |
@@ -42,7 +46,7 @@ def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins= |
42 | 46 | plt.scatter(p[0], p[1], marker='*') |
43 | 47 |
|
44 | 48 | for name, pos in states: |
45 | | - pos = pos.reshape(2,) |
| 49 | + pos = pos.reshape(2, ) |
46 | 50 | c = plt.Circle(pos, radius=radius, edgecolor='gray', alpha=alpha, facecolor='white', ls='--', lw=0.7, zorder=10) |
47 | 51 | plt.gca().add_patch(c) |
48 | 52 | plt.gca().annotate(name, xy=pos, ha="center", va="center", fontsize=14, zorder=11) |
0 commit comments