## Simulation as Optimization: Finding Paths of Least Action with Gradient Descent
Tim Strang and Sam Greydanus | 2023 | MIT License

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch, time

from celluloid import Camera
from IPython.display import HTML
from base64 import b64encode

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from main import * # SimOpt code

In [None]:
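# NOTE: legacy reference copies, kept for comparison only. The cells below use the
# versions imported from `main`, which are unpacked differently (see the Free body cells).
# def action(x, L_fn, dt):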
#     xdot = (x[1:] - x[:-1]) / dt
#     xdot = torch.cat([xdot, xdot[-1:]], axis=0)
#     return L_fn(x, xdot).sum()

# def minimize_action(path, steps, step_size, L_fn, dt, opt='sgd', print_every=15):
#     t = np.linspace(0, len(path.x)-1, len(path.x)) * dt
#     optimizer = torch.optim.SGD(path.parameters(), lr=step_size, momentum=0) if opt=='sgd' else \
#                 torch.optim.Adam(path.parameters(), lr=step_size)
#     xs = [path.x.clone().data]
#     t0 = time.time()
#     for i in range(steps):
#         S = action(path.x, L_fn, dt)
#         S.backward() ; path.x.grad.data[[0,-1]] *= 0
#         optimizer.step() ; path.zero_grad()

#         if i % (steps//print_every) == 0:
#             xs.append(path.x.clone().data)
#             print('step={:04d}, S={:.3e} J*s, dt={:.1f}s'.format(i, S.item(), time.time()-t0))
#             t0 = time.time()
#     return t, path, xs

## Free body

In [None]:
# Run the free-body simulator and show the resulting trajectory.
t_sim, x_sim = simulate_freebody(steps=60, dt=0.25)
plt.plot(t_sim, x_sim, 'k.-')
plt.title('Free body')
plt.show()

In [None]:
# Simulation settings: timestep, number of bodies, number of simulated steps.
dt, N, steps = 0.25, 1, 60
t_sim, x_sim = simulate_freebody(dt=dt, steps=steps)

# Build the starting path from the simulated one, then minimize its action.
init_path = PerturbedPath(x_sim, N=N, coords=1, sigma=1.5e0, zero_basepath=True)  # [time, N*2]
t_min, path, xs_min, info = minimize_action(
    init_path,
    steps=500,
    step_size=1e0,
    L_fn=lagrangian_freebody,
    dt=dt,
    opt='adam',
)

In [None]:
# Compare the ODE solution with the path snapshots saved during optimization.
plt.figure(dpi=90)
plt.title('Particle in freefall')
plt.plot(t_sim, x_sim, 'r-', label='ODE solution')
plt.plot(t_min, xs_min[0], 'y.-', label='Initial (random) path')

# One faint curve per saved snapshot, coloured by its position in the sequence.
# Only snapshot 10 carries a label so the legend gets a single entry for them.
n_snapshots = len(xs_min)
for idx, snapshot in enumerate(xs_min):
    shade = plt.cm.viridis(1 - idx / (n_snapshots - 1))
    legend_entry = 'During optimization' if idx == 10 else None
    plt.plot(t_min, snapshot, alpha=0.3, color=shade, label=legend_entry)

plt.plot(t_min, xs_min[-1], 'b.-', label='Final (optimized) path')
# Mark the first and last points of the initial path.
plt.plot(t_min[[0, -1]], xs_min[0].data[[0, -1]], 'b+', markersize=15, label='Points held constant')

plt.ylim(-5, 40)
plt.xlabel('Time (s)')
plt.ylabel('Height (m)')
plt.legend(fontsize=8, ncol=3)
plt.tight_layout()
plt.show()

In [None]:
def action_plot(ax, S, T, V, S_ode, T_ode, V_ode):
    """Plot action terms over optimizer steps with their ODE reference values.

    Args:
        ax: Axes-like object exposing ``plot``.
        S, T, V: Equal-length sequences with one value per recorded step;
            drawn as solid curves against ``0..len(S)-1``.
        S_ode, T_ode, V_ode: Single reference values; each is drawn as a
            dashed horizontal line from ``x=0`` to ``x=len(S)``.
    """
    # Raw strings keep the LaTeX backslashes without relying on Python passing
    # the unrecognised "\s" escape through unchanged.
    curves = [(S, 'k', r'$S$'),
              (T, 'r', r'$\sum_i T_i$'),
              (V, 'b', r'$-\sum_i V_i$')]
    # The reference labels mirror the curve labels exactly (no extra sign on T).
    references = [(S_ode, 'k--', r'$S$ (ODE)'),
                  (T_ode, 'r--', r'$\sum_i T_i$ (ODE)'),
                  (V_ode, 'b--', r'$-\sum_i V_i$ (ODE)')]

    N = len(S)
    step_axis = np.arange(N)
    for series, fmt, label in curves:
        ax.plot(step_axis, series, fmt, linewidth=4, label=label)
    for value, fmt, label in references:
        ax.plot([0, N], [value] * 2, fmt, linewidth=4, label=label)
            
def plot_help(ax, ax_labels, steps=None):
    """Apply title, legend, axis labels and font/tick styling to ``ax``.

    Everything is set on the axes passed in rather than on pyplot's current
    axes, so the helper also works when ``ax`` is not the active one.

    Args:
        ax: Axes to style.
        ax_labels: Dict with the keys ``'title'``, ``'x_label'`` and ``'y_label'``.
        steps: Accepted for call compatibility; currently unused.
    """
    ax.set_title(ax_labels['title'], fontweight="bold")
    ax.legend(ncol=2)
    ax.set_xlabel(ax_labels['x_label'])
    ax.set_ylabel(ax_labels['y_label'])
    for item in (ax.xaxis.label, ax.yaxis.label):
        item.set_fontsize(23)
    ax.title.set_fontsize(19)
    # Disabled tick placement that would use `steps`:
    #     if steps:
    #         plt.xticks(np.linspace(0, steps, num=5).astype(int))
    ax.tick_params(axis='both', length=9, width=3, labelsize=15)

In [None]:
# Action terms of the simulated (ODE) path, used as dashed reference lines.
S_ode, T_ode, V_ode = action(torch.from_numpy(x_sim), L_fn=lagrangian_freebody, dt=dt)

ax_labels = dict(
    title='SimOpt Freebody Action Over Optimization',
    x_label='Optimizer Steps',
    y_label='J * s',
)

# Single-axes figure showing the recorded S/T/V histories from `info`.
fig, ax = plt.subplots(figsize=(8.3333, 6.25), dpi=80)
action_plot(ax, info['S'], info['T'], info['V'], S_ode, T_ode, V_ode)
plot_help(ax, ax_labels, steps=steps)
plt.show()
# fig.savefig(project_dir + 'freebody_action.pdf')

## Single pendulum

In [None]:
# Run the pendulum simulator and show the resulting trajectory.
t_sim, x_sim = simulate_pend(dt=1)
plt.plot(t_sim, x_sim, 'k.-')
plt.title('Pendulum')
plt.show()

In [None]:
# Pendulum: simulate a reference path, then minimize the action of a perturbed copy.
dt = 1
N = 1
t_sim, x_sim = simulate_pend(dt=dt)

init_path = PerturbedPath(x_sim, N=N, coords=1, sigma=1.5e0, zero_basepath=True) # [time, N*2]
# The free-body cell unpacks a fourth value (`info`) from minimize_action; the
# starred target absorbs any values after the first three so this cell runs
# whether or not that extra value is returned.
t_min, path, xs_min, *_ = minimize_action(init_path, steps=1000, step_size=1e0,
                                          L_fn=lagrangian_pend, dt=dt, opt='adam')

plt.title('Pendulum')
plt.plot(t_sim, x_sim, 'k.-', label='Simulator')
plt.plot(t_min, xs_min[-1], 'b.-', label='Path of least action')
plt.legend()
plt.tight_layout()
plt.show()

## Double pendulum

In [None]:
# Double pendulum: plot the simulated trajectory, then embed it as a video.
t_sim, x_sim = simulate_dblpend(dt=0.06)
plt.title('Double pendulum')
plt.plot(t_sim, x_sim, 'k.-')
plt.show()

make_video(radial2cartesian(x_sim), path='sim.mp4', interval=60, ms=20)
# Read the rendered file inside a context manager so the handle is closed.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Must stay the last expression of the cell so the notebook displays it.
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
# Double pendulum: minimize the action of a perturbed path and compare with the simulator.
dt = 0.06
N = 2
t_sim, x_sim = simulate_dblpend(dt=dt)

init_path = PerturbedPath(x_sim, N=N, coords=1, sigma=1e0, zero_basepath=False) # [time, N*2]
# The free-body cell unpacks a fourth value (`info`) from minimize_action; the
# starred target absorbs any values after the first three.
t_min, path, xs_min, *_ = minimize_action(init_path, steps=100, step_size=1e-1,
                                          L_fn=lagrangian_dblpend, dt=dt, opt='adam')

# Column 1 of each path is plotted as the second angle.
plt.title('Double Pendulum (theta 2)')
plt.plot(t_sim, x_sim[:,1], 'k.-', label='Simulator')
plt.plot(t_min, xs_min[0][:,1], 'y-', label='Initial path')
plt.plot(t_min, xs_min[-1][:,1], 'b.-', label='Path of least action')
plt.legend()
plt.tight_layout()
plt.show()

make_video(radial2cartesian(xs_min[-1]), path='sim.mp4', interval=60, ms=20)
# Read the rendered file inside a context manager so the handle is closed.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Must stay the last expression of the cell so the notebook displays it.
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

## Three body problem

In [None]:
# Three-body problem: render the simulated trajectory and embed it as a video.
t, x = simulate_3body()
make_video(x, path='sim.mp4', interval=60, ms=20)
# Read the rendered file inside a context manager so the handle is closed.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Must stay the last expression of the cell so the notebook displays it.
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
# Three-body problem: minimize the action of a perturbed copy of the simulated path.
dt = 0.5
N = 3

t_sim, x_sim = simulate_3body(dt=dt)
init_path = PerturbedPath(x_sim, N=N, sigma=2e-2) # [time, N*2]
# The free-body cell unpacks a fourth value (`info`) from minimize_action; the
# starred target absorbs any values after the first three.
t_min, path, xs_min, *_ = minimize_action(init_path, steps=500, step_size=1e1,
                                          L_fn=lagrangian_3body, dt=dt, opt='sgd')

In [None]:
# Compare per-step coordinate differences of one body before/after optimization.
N = x_sim.shape[-2]
xs_before = xs_min[0].detach().numpy().reshape(-1,N,2)
xs_after = xs_min[-1].detach().numpy().reshape(-1,N,2)

# k indexes the flattened [N*2] coordinate axis: body k//2, coordinate k%2.
k = 1
# Name the axis from k instead of hard-coding it.
# NOTE(review): assumes coordinate 0 is x (horizontal) and 1 is y — confirm.
axis_name = 'horiz.' if k % 2 == 0 else 'vert.'

plt.figure(dpi=100)
plt.title('Ball {} {} velocity vs. time'.format(1 + k//2, axis_name))
# Differences between consecutive time steps (not divided by dt).
plt.plot((xs_before[1:] - xs_before[:-1]).reshape(-1,N*2)[...,k], '.-', label='Initial path')
plt.plot((xs_after[1:] - xs_after[:-1]).reshape(-1,N*2)[...,k], '.-', label='Minimum action')
plt.plot((x_sim[1:] - x_sim[:-1]).reshape(-1,N*2)[...,k], 'k-', label='Simulator')
plt.legend()
plt.show()

In [None]:
# Render the first saved path snapshot as a video.
xs = xs_min[0].detach().numpy().reshape(-1,N,2)
make_video(xs, path='sim.mp4', interval=60, ms=20)
# Read the rendered file inside a context manager so the handle is closed.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Must stay the last expression of the cell so the notebook displays it.
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
# Render the last saved path snapshot as a video.
xs = xs_min[-1].detach().numpy().reshape(-1,N,2)
make_video(xs, path='sim.mp4', interval=30, ms=20)
# Read the rendered file inside a context manager so the handle is closed.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Must stay the last expression of the cell so the notebook displays it.
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

## Gas simulation

In [None]:
# Gas: render the simulated trajectory and embed it as a video.
t, x = simulate_gas(dt=.5, N=50)

make_video(x, path='sim.mp4', interval=30)
# Read the rendered file inside a context manager so the handle is closed.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Must stay the last expression of the cell so the notebook displays it.
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
# Gas: minimize the action of a perturbed copy of the simulated path.
dt = 0.5
N = 50
t_sim, x_sim = simulate_gas(dt=dt, N=N)
init_path = PerturbedPath(x_sim, N=N, sigma=1e-2) # [time, N*2]
# The free-body cell unpacks a fourth value (`info`) from minimize_action; the
# starred target absorbs any values after the first three.
t_min, path, xs_min, *_ = minimize_action(init_path, steps=500, step_size=1e1,
                                          L_fn=lagrangian_gas, dt=dt, opt='sgd')

In [None]:
# Compare per-step coordinate differences of one particle before/after optimization.
N = x_sim.shape[-2]
xs_before = xs_min[0].detach().numpy().reshape(-1,N,2)
xs_after = xs_min[-1].detach().numpy().reshape(-1,N,2)

# k indexes the flattened [N*2] coordinate axis: body k//2, coordinate k%2.
k = 25
# Name the axis from k instead of hard-coding it.
# NOTE(review): assumes coordinate 0 is x (horizontal) and 1 is y — confirm.
axis_name = 'horiz.' if k % 2 == 0 else 'vert.'

plt.figure(dpi=100)
plt.title('Ball {} {} velocity vs. time'.format(1 + k//2, axis_name))
# Differences between consecutive time steps (not divided by dt).
plt.plot((xs_before[1:] - xs_before[:-1]).reshape(-1,N*2)[...,k], '.-', label='Initial path')
plt.plot((xs_after[1:] - xs_after[:-1]).reshape(-1,N*2)[...,k], '.-', label='Minimum action')
plt.plot((x_sim[1:] - x_sim[:-1]).reshape(-1,N*2)[...,k], 'k-', label='Simulator')
plt.legend()
plt.show()

In [None]:
# Render the first saved path snapshot as a video.
xs = xs_min[0].detach().numpy().reshape(-1,N,2)
make_video(xs, path='sim.mp4', interval=30, ms=10)
# Read the rendered file inside a context manager so the handle is closed.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Must stay the last expression of the cell so the notebook displays it.
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

## Ephemeris dataset and simulation

In [None]:
# Load the ephemeris data, simulate the same bodies, and overlay both.
planets = ['sun', 'mercury', 'venus', 'earth', 'mars']
data_dir = './data/'
df = process_raw_ephemeris(planets, data_dir, last_n_days=365) #365

t_sim, x_sim = simulate_planets(df, planets)
plot_planets(df, planets)

colors = get_planet_colors()
# After the transpose each item is one body's (x, y) pair of time series,
# in the same order as `planets`.
for planet, (x, y) in zip(planets, x_sim.transpose(1,2,0)):
    planet_color = colors[planet]
    plt.plot(x, y, ':', alpha=0.5, color=planet_color, label=planet + ' (sim)')
    plt.plot(x[0], y[0], '+', color=planet_color)    # first point
    plt.plot(x[-1], y[-1], 'x', color=planet_color)  # last point
plt.axis('equal')
plt.legend(fontsize=6,  loc='upper right', ncol=2)
plt.show()

In [None]:
# `partial` is not imported anywhere else in the notebook; import it here so
# the cell does not depend on `from main import *` re-exporting it.
from functools import partial

# Planets: minimize the action of a perturbed copy of the simulated orbits.
dt = 24*60*60
N = len(planets)
df = process_raw_ephemeris(planets, data_dir, last_n_days=365)
t_sim, x_sim = simulate_planets(df, planets, dt=dt)
init_path = PerturbedPath(x_sim, N=N, sigma=2e10, is_ephemeris=True) # [time, N*2]

# Bind the masses so the Lagrangian is passed as L_fn like in the other cells.
L_planets = partial(lagrangian_planets, masses=get_masses(planets))

# The free-body cell unpacks a fourth value (`info`) from minimize_action; the
# starred target absorbs any values after the first three.
t_min, path, xs_min, *_ = minimize_action(init_path, steps=500, step_size=1e9,
                                          L_fn=L_planets, dt=dt, opt='adam')

In [None]:
# Plot Earth's y coordinate for the simulated, initial and optimized paths.
plt.figure(figsize=[5,3], dpi=120)
plt.title('Earth y coordinate')
# Look Earth up by name: the body axis follows the order of `planets`, where
# index 2 is 'venus', not 'earth'.
earth = planets.index('earth')
xs_sim = init_path.x_true
xs_init = xs_min[0].detach().numpy().reshape(-1,N,2)
xs_final = xs_min[-1].detach().numpy().reshape(-1,N,2)
plt.plot(xs_sim[:,earth,1], '--', label='sim')
plt.plot(xs_init[:,earth,1], alpha=0.5, label='init')
plt.plot(xs_final[:,earth,1], alpha=0.5, label='final')
plt.legend()
plt.show()

In [None]:
# Overlay the initial and optimized planet paths on the ephemeris plot.
fig = plt.figure(figsize=[5,5], dpi=140)
plot_planets(df, planets, fig=fig)
colors = get_planet_colors()

# First saved snapshot: dots, with '+' / 'x' marking each body's first / last point.
xs = xs_min[0].detach().numpy().reshape(-1,N,2)
for planet, (x, y) in zip(planets, xs.transpose(1,2,0)):
    planet_color = colors[planet]
    plt.plot(x, y, '.', alpha=0.3, color=planet_color, label=planet + ' (init)')
    plt.plot(x[0], y[0], '+', color=planet_color)
    plt.plot(x[-1], y[-1], 'x', color=planet_color)

# Last saved snapshot: dotted lines.
xs = xs_min[-1].detach().numpy().reshape(-1,N,2)
for planet, (x, y) in zip(planets, xs.transpose(1,2,0)):
    plt.plot(x, y, ':', alpha=0.5, color=colors[planet], label=planet + ' (path)')

plt.axis('equal')
plt.legend(fontsize=6,  loc='upper right', ncol=3)
plt.show()