## Simulation as Optimization: Finding Paths of Least Action with Gradient Descent
Tim Strang and Sam Greydanus | 2023 | MIT License

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch, time

from celluloid import Camera
from IPython.display import HTML
from base64 import b64encode

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from main import * # SimOpt code

In [None]:
def plot_action_stats(ax, S, T, V, S_ode, T_ode, V_ode):
    """Plot optimizer action curves next to their constant ODE reference levels.

    Args:
        ax: Axes-like object; only ``ax.plot`` is called on it.
        S, T, V: Per-optimizer-step sequences. The x-range is taken from
            ``len(S)``. ``V`` is negated before plotting.
        S_ode, T_ode, V_ode: Reference values drawn as horizontal dotted lines
            from x=0 to x=len(S). ``V_ode`` is negated before plotting.
    """
    curve_alpha = .7
    n_steps = len(S)
    step_axis = np.arange(n_steps)

    # Raw strings keep the LaTeX backslashes without relying on invalid
    # escape sequences such as '\s'.
    curves = [(S, 'k', r'$S$'),
              (T, 'r', r'$\sum_i T_i$'),
              (-np.asarray(V), 'b', r'$-\sum_i V_i$')]
    for series, fmt, label in curves:
        ax.plot(step_axis, series, fmt, alpha=curve_alpha, linewidth=4, label=label)

    # T_ode is plotted un-negated, so its label must not carry a minus sign.
    references = [(S_ode, 'k:', r'$S$ (ODE)'),
                  (T_ode, 'r:', r'$\sum_i T_i$ (ODE)'),
                  (-V_ode, 'b:', r'$-\sum_i V_i$ (ODE)')]
    for level, fmt, label in references:
        ax.plot([0, n_steps], [level] * 2, fmt, alpha=1, linewidth=4, label=label)
            
def plot_helper(ax, ax_labels, fontsz, legend=False):
    """Apply the shared finishing touches to a plot.

    Args:
        ax: Axes whose tick marks are restyled.
        ax_labels: Accepted for interface compatibility; not read here
            (title and axis labels are not drawn).
        fontsz: Dict of font sizes; only ``fontsz['legend']`` is read, and
            only when ``legend`` is truthy.
        legend: When truthy, draw a two-column legend on the current axes.
    """
    if legend:
        legend_options = {'ncol': 2, 'fontsize': fontsz['legend']}
        plt.legend(**legend_options)

    tick_style = dict(axis='both', length=9, width=3, labelsize=15)
    ax.tick_params(**tick_style)

    # Re-pack the layout so labels and ticks are less likely to be clipped.
    plt.tight_layout()
    
def action_plot(info, x_sim, name, fontsz, ax_labels, L_fn, legend=False, fig=None, ylim=(None, None)):
    """Plot S, T and -V over optimizer steps against the ODE baseline and save it.

    Args:
        info: Mapping with keys 'S', 'T' and 'V' (per-step series).
        x_sim: Simulated trajectory; wrapped in ``torch.tensor`` and passed to
            ``action`` to obtain the baseline values.
        name: Used for the output file name (lower-cased, spaces removed).
        fontsz, ax_labels, legend: Forwarded to ``plot_helper``.
        L_fn: Lagrangian forwarded to ``action``.
        fig: Optional existing figure to draw into. Previously this argument
            was silently overwritten; a new figure is now created only when
            it is None.
        ylim: (bottom, top) pair passed to ``plt.ylim``.

    Returns:
        None. Callers that assign the result receive None.
    """
    if fig is None:
        fig = plt.figure(figsize=(6, 5.5), dpi=DPI)
        ax = fig.add_subplot(111)
    else:
        ax = fig.add_subplot(111)
        # The pyplot calls below act on the *current* axes, so make the
        # caller-supplied figure's axes current first.
        plt.sca(ax)

    # NOTE: `dt` is not a parameter; it is read from the enclosing namespace.
    S_ode, T_ode, V_ode = action(torch.tensor(x_sim), L_fn=L_fn, dt=dt)
    plot_action_stats(ax, info['S'], info['T'], info['V'],
                      S_ode.sum(), T_ode.sum(), V_ode.sum())
    plt.ylim(ylim[0], ylim[1])

    path = './static/{}_action.png'.format(name.lower().replace(' ', ''))
    plot_helper(ax, ax_labels, fontsz, legend=legend)
    plt.show()
    fig.savefig(path)

def unity(x):
    """Identity transform: hand `x` back unchanged."""
    return x

def dynamic_plot(t_sim, x_sim, t_min, xs_min, name, fontsz, ax_labels, plt_fn=unity, ylim=(None, None)):
    """Overlay the ODE trajectory with the optimizer's path snapshots and save it.

    Args:
        t_sim, x_sim: Time axis and trajectory of the ODE solution.
        t_min: Time axis of the optimized paths; indexed with ``[[0, -1]]``.
        xs_min: Sequence of path snapshots (first = initial, last = final).
        name: Used for the output file name (lower-cased, spaces removed);
            the legend is only drawn when it equals 'Free body'.
        fontsz, ax_labels: Forwarded to ``plot_helper``.
        plt_fn: Transform applied to every trajectory before plotting.
        ylim: (bottom, top) pair passed to ``plt.ylim``.
    """
    fig = plt.figure(figsize=(7, 5), dpi=DPI)
    ax = fig.add_subplot(111)

    ode_curve = plt_fn(x_sim)
    ax.plot(t_sim, ode_curve, color='purple', ls='-', label='ODE solution', linewidth=2)
    ax.scatter(t_sim, ode_curve, color='purple', s=10)
    ax.plot(t_min, plt_fn(xs_min[0]), 'y.-', alpha=.3, label='Initial (random) path')

    # Guard against a single snapshot: len(xs_min) - 1 would be zero and the
    # progress fraction below would raise ZeroDivisionError.
    last = max(len(xs_min) - 1, 1)
    for i, xi in enumerate(xs_min):
        # Only one snapshot carries a label so the legend gets a single entry.
        label = 'During optimization' if i == 10 else None
        ax.plot(t_min, plt_fn(xi), alpha=.3 + .7 * i/last,
                color=plt.cm.viridis(1-i/last), label=label)

    ax.plot(t_min, plt_fn(xs_min[-1]), 'g.-', label='Final (optimized) path')
    ax.plot(t_min[[0,-1]], plt_fn(xs_min[0])[[0,-1]], 'g+', markersize=17,
            label='Points held constant')

    plt.ylim(ylim[0], ylim[1])

    plot_helper(ax, ax_labels, fontsz, legend = name=='Free body')
    path = './dynamic/{}.png'.format(name.lower().replace(' ', ''))
    plt.show()
    fig.savefig(path)
    

# Figure resolution read at call time by action_plot, dynamic_plot, simplot and minplot.
DPI=300

## Free body

In [None]:
# Settings shared by the simulator and the optimizer below.
dt = 0.25
N = 1
steps = 60

# Reference trajectory from the simulator.
t_sim, x_sim = simulate_freebody(dt=dt, steps=steps)

# Starting path for the optimizer, then the action minimization itself.
init_path = PerturbedPath(x_sim, N=N, coords=1, sigma=1.5e0, zero_basepath=True)
t_min, path, xs_min, info = minimize_action(
    init_path,
    steps=550,
    step_size=1e0,
    L_fn=lagrangian_freebody,
    dt=dt,
    opt='adam',
    verbose=False,
)

In [None]:
name = 'Free body'
fontsz = {
    'title': 23,
    'x_label': 23,
    'y_label': 23,
    'legend': 10,
}
ax_labels = {
    'title': 'Free Body Height vs Time',
    'x_label': 'Time (s)',
    'y_label': 'Height (m)',
}

# ODE solution vs. the optimizer's snapshots, with a fixed vertical range.
dynamic_plot(t_sim, x_sim, t_min, xs_min, name, fontsz, ax_labels, ylim=(-5, 40))

In [None]:
ax_labels = {
    'title': 'Action and associated quantities ({})'.format(name),
    'x_label': 'Optimizer Steps',
    'y_label': 'J * s',
}
fontsz = {
    'title': 17,
    'x_label': 23,
    'y_label': 28,
    'legend': 12,
}

# NOTE: action_plot has no return statement, so `fig` is bound to None here.
fig = action_plot(
    info, x_sim, name, fontsz, ax_labels,
    L_fn=lagrangian_freebody,
    legend=True,
    ylim=(-25, 25),
)

In [None]:
# Directory prefix that simplot/minplot prepend to the frame file names they save.
# NOTE(review): this rebinds the global `path`, which the minimize_action cells
# also assign — confirm the name reuse is intended.
path = 'hero_plot'

def simplot(i, x_sim):
    """Save one small frame showing the first (i+1) fifths of the ODE rollout.

    Reads the globals `t_sim`, `path` and `DPI`. The frame is written to
    `path + '/sim_<count>.png'`, where <count> is the number of samples shown.
    Axis ticks are kept only on frame 0.
    """
    chunk = int(len(x_sim)/5)
    shown = chunk * (i + 1)

    fig = plt.figure(figsize=(2, 2), dpi=DPI)
    fig.add_subplot(111)

    t_part, x_part = t_sim[:shown], x_sim[:shown]
    plt.scatter(t_part, x_part, color='purple', s=4)
    plt.plot(t_part, x_part, color='purple', linewidth=1)

    plt.xlim(-1, t_sim.max()+1)
    plt.ylim(-5, 40)
    if i != 0:
        # Clearing both tick sets on every frame but the first.
        plt.xticks([], [])
        plt.yticks([], [])

    plt.tight_layout()
    fig.savefig(path + f'/sim_{shown}.png')

def minplot(i, min_slice):
    """Save one small frame of the optimizer path at snapshot `min_slice[i]`.

    Reads the globals `t_min`, `xs_min`, `path` and `DPI`. Earlier snapshots
    (`min_slice[0..i-1]`) are overlaid with a viridis colour/alpha ramp.
    The frame is written to `path + '/min_<min_slice[i]>.png'`.
    """
    current = int(min_slice[i])

    fig = plt.figure(figsize=(2, 2), dpi=DPI)
    plt.scatter(t_min, xs_min[current], color='g', s=4)
    plt.plot(t_min, xs_min[current], 'g', linewidth=1)
    plt.xlim(-1, t_min.max()+1)
    plt.ylim(-2, 40)

    # Mark the first and last points of the initial snapshot.
    plt.plot(t_min[[0,-1]], xs_min[0].data[[0,-1]], 'g+', markersize=7)
    plt.xticks([0, 5, 10, 15])

    if i != 0:
        plt.xticks([], [])
        plt.yticks([], [])

    # Walk back through the earlier snapshots, most recent first.
    for j in range(i - 1, -1, -1):
        earlier = int(min_slice[j])
        plt.plot(t_min, xs_min[earlier], alpha=.3 + .7*j/i, color=plt.cm.viridis(1-j/i))

    plt.tight_layout()
    fig.savefig(path + f'/min_{min_slice[i]}.png')
    

# Snapshot indices into xs_min used for the five optimizer frames.
min_slice = [0, 3, 6, 9, -1]

# One optimizer frame and one simulation frame per index.
for i in range(len(min_slice)):
    minplot(i, min_slice)
    simplot(i, x_sim)

## Single pendulum

In [None]:
dt = 1
N = 1

# Reference trajectory from the simulator.
t_sim, x_sim = simulate_pend(dt=dt)

# Starting path for the optimizer.
init_path = PerturbedPath(x_sim, N=N, coords=1, sigma=3.0e-1, zero_basepath=False)
# Halve every entry except the first three and last three.
init_path.x.data[3:-3] = init_path.x.data[3:-3] * 0.5

t_min, path, xs_min, info = minimize_action(
    init_path,
    steps=2000,
    step_size=2e-2,
    L_fn=lagrangian_pend,
    dt=dt,
    opt='adam',
    print_updates=3,
    e_coeff=1e2,
    verbose=False,
)

In [None]:
def pend_height(xs):
    """Return index 1 along the last axis of `radial2cartesian_pend(xs)`."""
    cartesian = radial2cartesian_pend(xs)
    return cartesian[..., 1]


name = 'Pendulum'
fontsz = {
    'title': 19,
    'x_label': 23,
    'y_label': 23,
    'legend': 10,
}
ax_labels = {
    'title': 'Pendulum Height vs Time',
    'x_label': 'Time (s)',
    'y_label': 'Height (m)',
}

dynamic_plot(t_sim, x_sim, t_min, xs_min, name, fontsz, ax_labels, plt_fn=pend_height)

In [None]:
ax_labels = {
    'title': 'Action and associated quantities ({})'.format(name),
    'x_label': 'Optimizer Steps',
    'y_label': 'J * s',
}
fontsz = {
    'title': 18,
    'x_label': 23,
    'y_label': 28,
    'legend': 10,
}

# The baseline uses the same Lagrangian the optimizer was run with.
action_plot(info, x_sim, name, fontsz, ax_labels,
            L_fn=lagrangian_pend, ylim=(-200, 200))

## Double pendulum

In [None]:
dt = 0.06
N = 2

# Reference trajectory from the simulator.
t_sim, x_sim = simulate_dblpend(dt=dt)

# Starting path, then the action minimization.
init_path = PerturbedPath(x_sim, N=N, coords=1, sigma=1e0, zero_basepath=False)
t_min, path, xs_min, info = minimize_action(
    init_path,
    steps=200,
    step_size=1e-1,
    L_fn=lagrangian_dblpend,
    dt=dt,
    opt='adam',
    verbose=False,
)

In [None]:
def dblpend_height(xs):
    """Return element [:, 1, 1] of `radial2cartesian_dblpend(xs)`."""
    cartesian = radial2cartesian_dblpend(xs)
    return cartesian[:, 1, 1]


name = 'Double Pendulum'
fontsz = {
    'title': 19,
    'x_label': 23,
    'y_label': 23,
    'legend': 10,
}
ax_labels = {
    'title': 'Second Pendulum Height vs Time',
    'x_label': 'Time (s)',
    'y_label': 'Height (m)',
}

dynamic_plot(t_sim, x_sim, t_min, xs_min, name, fontsz, ax_labels, plt_fn=dblpend_height)

In [None]:
ax_labels = {
    'title': 'Action and associated quantities ({})'.format(name),
    'x_label': 'Optimizer Steps',
    'y_label': 'J * s',
}
fontsz = {
    'title': 16,
    'x_label': 23,
    'y_label': 28,
    'legend': 10,
}

# The baseline uses the same Lagrangian the optimizer was run with.
action_plot(info, x_sim, name, fontsz, ax_labels,
            L_fn=lagrangian_dblpend, ylim=(0, 4))

## Three body problem

In [None]:
dt = 0.5
N = 3

# Reference trajectory from the simulator.
t_sim, x_sim = simulate_3body(dt=dt, stable_config=False)

# Starting path, then the action minimization.
init_path = PerturbedPath(x_sim, N=N, sigma=3e-2)
t_min, path, xs_min, info = minimize_action(
    init_path,
    steps=200,
    step_size=1e-3,
    L_fn=lagrangian_3body,
    dt=dt,
    opt='adam',
    print_updates=10,
    e_coeff=0,
)

In [None]:
N = x_sim.shape[-2]
k = 0


def ball_xcoord(xs, shape=x_sim.shape, k=k):
    """Reshape `xs` to `shape` and return element [:, 0, k].

    The defaults for `shape` and `k` are captured when this cell runs.
    """
    reshaped = xs.reshape(shape)
    return reshaped[:, 0, k]


name = 'Three body'
ax_labels = {
    'title': 'Ball {} X-Coordinate vs. Time'.format(1 + k//2),
    'x_label': 'Time (s)',
    'y_label': 'Position (m)',
}
fontsz = {
    'title': 18,
    'x_label': 23,
    'y_label': 23,
    'legend': 10,
}

dynamic_plot(t_sim, x_sim, t_min, xs_min, name, fontsz, ax_labels, plt_fn=ball_xcoord)

In [None]:
ax_labels = {
    'title': 'Action and associated quantities ({})'.format(name),
    'x_label': 'Optimizer Steps',
    'y_label': 'J * s',
}
fontsz = {
    'title': 16,
    'x_label': 23,
    'y_label': 28,
    'legend': 10,
}

# The baseline uses the same Lagrangian the optimizer was run with.
action_plot(info, x_sim, name, fontsz, ax_labels,
            L_fn=lagrangian_3body, ylim=(-.0001, .001))

In [None]:
def plot_3body(x, do_bodies=False, colors=None, fig=None, fmt='-', **kwargs):
    """Draw one 2D track per colour from `x`, indexed as x[:, body, coord].

    Args:
        x: Trajectory array indexed as [sample, body, coordinate].
        do_bodies: When truthy, also mark each track's last sample with a dot.
        colors: One colour per body; defaults to red/blue/green when falsy.
        fig: Only checked against None — a fresh figure is opened when it is
            None; otherwise drawing goes to the current pyplot figure.
        fmt: Matplotlib format string for the tracks.
        **kwargs: Forwarded to `plt.plot` for the tracks.
    """
    if fig is None:
        plt.figure(figsize=[4,4], dpi=80)

    palette = colors if colors else ['red', 'blue', 'green']
    for body, shade in enumerate(palette):
        track_x = x[:, body, 0]
        track_y = x[:, body, 1]
        plt.plot(track_x, track_y, fmt, color=shade, **kwargs)
        if do_bodies:
            plt.plot(x[-1, body, 0], x[-1, body, 1], '.', color=shade, markersize=12)

    plt.xlim(0.2,0.85)
    plt.ylim(0.2,0.85)

fig = plt.figure(figsize=[4,4], dpi=80)

# NOTE(review): `colors` is assigned here but not passed to plot_3body below,
# so the calls fall back to plot_3body's default palette — confirm intent.
colors = ['black', 'black', 'black']

# ODE rollout (dashed), then the first and last optimizer snapshots.
plot_3body(x_sim, fig=fig, do_bodies=True, fmt='--', alpha=0.7)
plot_3body(xs_min[0].reshape(-1,3,2), fig=fig, do_bodies=True, fmt='-', alpha=0.3)
plot_3body(xs_min[-1].reshape(-1,3,2), fig=fig, do_bodies=True, fmt='-', alpha=1)

plt.tick_params(axis='both', length=9, width=3, labelsize=15)
fig.savefig('dynamic/3body_2d.png')

## Gas simulation

In [None]:
dt = 0.5
N = 50

# Reference trajectory from the simulator.
t_sim, x_sim = simulate_gas(dt=dt, N=N)

# Starting path, then the action minimization (plain SGD for this system).
init_path = PerturbedPath(x_sim, N=N, sigma=1e-2)
t_min, path, xs_min, info = minimize_action(
    init_path,
    steps=500,
    step_size=1e1,
    L_fn=lagrangian_gas,
    dt=dt,
    opt='sgd',
)

In [None]:
name = 'Gas'
N = x_sim.shape[-2]
k = 30


def gas_xcoord(xs, N=N, k=k):
    """Flatten `xs` to rows of width N*2 and return column `k`.

    The defaults for `N` and `k` are captured when this cell runs.
    """
    flat = xs.reshape(-1, N*2)
    return flat[..., k]


ax_labels = {
    'title': 'Ball {} X-Coordinate vs. Time'.format(1 + k//2),
    'x_label': 'Time (s)',
    'y_label': 'Position (m)',
}
fontsz = {
    'title': 18,
    'x_label': 23,
    'y_label': 23,
    'legend': 10,
}

dynamic_plot(t_sim, x_sim, t_min, xs_min, name, fontsz, ax_labels, plt_fn=gas_xcoord)

In [None]:
ax_labels = {
    'title': 'Action and associated quantities ({})'.format(name),
    'x_label': 'Optimizer Steps',
    'y_label': 'J * s',
}
fontsz = {
    'title': 16,
    'x_label': 23,
    'y_label': 28,
    'legend': 10,
}

# The gas optimization was run with L_fn=lagrangian_gas, so the ODE baseline
# must be evaluated with that same Lagrangian (this call previously passed
# lagrangian_3body, making the dotted reference lines incomparable).
# NOTE: action_plot has no return statement, so `fig` is bound to None here.
fig = action_plot(info, x_sim, name, fontsz, ax_labels, L_fn=lagrangian_gas)

In [None]:
def plot_particle(x, i, do_bodies=False, colors=None, fig=None, ls='-', color='k', **kwargs):
    """Draw the 2D track of particle `i` from `x`, indexed as x[:, i, coord].

    Args:
        x: Trajectory array indexed as [sample, particle, coordinate].
        i: Particle index along axis 1.
        do_bodies: When truthy, mark the last sample with a blue dot and the
            first sample with a blue plus.
        colors: Accepted but not read.
        fig: Only checked against None — a fresh figure is opened when it is
            None; otherwise drawing goes to the current pyplot figure.
        ls, color: Line style and colour of the track.
        **kwargs: Forwarded to `plt.plot` for the track.
    """
    if fig is None:
        plt.figure(figsize=[4,4], dpi=80)

    track_x = x[:, i, 0]
    track_y = x[:, i, 1]
    plt.plot(track_x, track_y, ls=ls, color=color, linewidth=2, **kwargs)

    if do_bodies:
        plt.plot(x[-1, i, 0], x[-1, i, 1], '.', color='b', markersize=20)
        plt.plot(x[0, i, 0], x[0, i, 1], '+', color='b', markersize=15, mew=2)

    plt.xlim(0.2,0.85)
    plt.ylim(0.2,0.85)

N = x_sim.shape[-2]
obj = 31  # index of the particle to draw

fig = plt.figure(figsize=[4,4], dpi=80)

# ODE rollout (dashed), then the first and last optimizer snapshots.
plot_particle(x_sim, obj, fig=fig, do_bodies=False,
              ls='--', color='purple', alpha=0.7)
plot_particle(xs_min[0].reshape(-1,N,2), obj, fig=fig, do_bodies=False,
              color='y', alpha=0.7)
plot_particle(xs_min[-1].reshape(-1,N,2), obj, fig=fig, do_bodies=True,
              color='g', alpha=1)

# Override the limits plot_particle set, zooming in on this particle.
plt.xlim(.5, .8)
plt.ylim(.3, .6)

plt.tick_params(axis='both', length=9, width=3, labelsize=15)
fig.savefig('dynamic/gas_p31.png')

## Ephemeris dataset and simulation

In [None]:
planets = ['sun', 'mercury', 'venus', 'earth', 'mars']
data_dir = './data/'
df = process_raw_ephemeris(planets, data_dir, last_n_days=365)

t_sim, x_sim = simulate_planets(df, planets)
plot_planets(df, planets)

# Overlay the simulated track of each body: dotted line, '+' at the first
# sample and 'x' at the last.
colors = get_planet_colors()
for i, (planet, coords) in enumerate(zip(planets, x_sim.transpose(1,2,0))):
    x, y = coords
    shade = colors[planet]
    plt.plot(x, y, ':', alpha=0.5, color=shade, label=planet + ' (sim)')
    plt.plot(x[0], y[0], '+', color=shade)
    plt.plot(x[-1], y[-1], 'x', color=shade)

plt.axis('equal')
plt.legend(fontsize=6,  loc='upper right', ncol=2)
plt.show()

In [None]:
# `partial` is not imported explicitly anywhere else in this notebook; without
# this line it would have to be supplied by `from main import *`.
from functools import partial

dt = 24*60*60
N = len(planets)

df = process_raw_ephemeris(planets, data_dir, last_n_days=365)
t_sim, x_sim = simulate_planets(df, planets, dt=dt)
init_path = PerturbedPath(x_sim, N=N, sigma=2e10, is_ephemeris=True)

# Bind the planet masses so the Lagrangian can be passed around as L_fn.
L_planets = partial(lagrangian_planets, masses=get_masses(planets))

t_min, path, xs_min, info = minimize_action(
    init_path,
    steps=500,
    step_size=1e9,
    L_fn=L_planets,
    dt=dt,
    opt='adam',
)

In [None]:
def earth_ycoord(xs, N=N):
    """Reshape `xs` to (-1, N, 2) and return element [:, 3, 1].

    The default for `N` is captured when this cell runs.
    """
    per_body = xs.reshape(-1, N, 2)
    return per_body[:, 3, 1]


ax_labels = {
    'title': 'Earth Y-Coordinate',
    'x_label': 'Time (hr)',
    'y_label': 'Position (m)',
}
fontsz = {
    'title': 18,
    'x_label': 23,
    'y_label': 23,
    'legend': 10,
}
name = 'Ephemeris'

dynamic_plot(t_sim, x_sim, t_min, xs_min, name, fontsz, ax_labels, plt_fn=earth_ycoord)

In [None]:
ax_labels = {
    'title': 'Action and associated quantities ({})'.format(name),
    'x_label': 'Optimizer Steps',
    'y_label': 'J * s',
}
fontsz = {
    'title': 16,
    'x_label': 23,
    'y_label': 28,
    'legend': 10,
}

# The ephemeris optimization was run with L_fn=L_planets, so the ODE baseline
# must be evaluated with that same Lagrangian (this call previously passed
# lagrangian_3body, making the dotted reference lines incomparable).
action_plot(info, x_sim, name, fontsz, ax_labels,
            L_fn=L_planets, ylim=(-1e33, .8e34))

In [None]:
fig = plt.figure(figsize=[5,5], dpi=140)

# First and last optimizer snapshots and the ODE rollout, each viewed as (-1, N, 2).
xsi = xs_min[0].detach().numpy().reshape(-1,N,2)
xsf = xs_min[-1].detach().numpy().reshape(-1,N,2)
sim = x_sim.reshape(-1,N,2)

# Body 0 and body 3 are plotted; presumably sun and earth, following the
# order of `planets` — verify against simulate_planets.
plt.plot(sim[:, 0, 0], sim[:, 0, 1], 'x', color='b', markersize=4)
plt.plot(sim[:, 3, 0], sim[:, 3, 1], '.-', color='purple', markersize=4)
plt.plot(xsi[:, 3, 0], xsi[:, 3, 1], '-.', color='y', alpha=0.3)
plt.plot(xsf[:, 3, 0], xsf[:, 3, 1], '.-', color='g', alpha=1, markersize=5)

# First and last samples of body 3 in the initial snapshot.
plt.plot(xsi[[0, -1], 3, 0], xsi[[0, -1], 3, 1], '+',
         color='b', markersize=10, mew=2)

plt.tick_params(axis='both', length=9, width=3, labelsize=15)

plt.axis('equal')
plt.show()
fig.savefig('dynamic/earth_xy.png')

# 

In [None]:
fig = plt.figure(figsize=[5,5], dpi=140)
plot_planets(df, planets, fig=fig)
colors = get_planet_colors()

# Initial snapshot: faint dots, '+' at the first sample and 'x' at the last.
xs = xs_min[0].detach().numpy().reshape(-1,N,2)
for i, (planet, coords) in enumerate(zip(planets, xs.transpose(1,2,0))):
    x, y = coords
    shade = colors[planet]
    plt.plot(x, y, '.', alpha=0.3, color=shade, label=planet + ' (init)')
    plt.plot(x[0], y[0], '+', color=shade)
    plt.plot(x[-1], y[-1], 'x', color=shade)

# Final snapshot: dotted line per body.
xs = xs_min[-1].detach().numpy().reshape(-1,N,2)
for i, (planet, coords) in enumerate(zip(planets, xs.transpose(1,2,0))):
    x, y = coords
    plt.plot(x, y, ':', alpha=0.5, color=colors[planet], label=planet + ' (path)')

plt.axis('equal')
plt.legend(fontsize=6,  loc='upper right', ncol=3)
plt.show()