## Simulation as Optimization: reuse code

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch, time

from celluloid import Camera
from IPython.display import HTML
from base64 import b64encode

In [None]:
def make_video(xs, path, interval=60, ms=10, **kwargs): # xs: [time, N, 2]
    """Render a trajectory as an animation and save it to ``path``.

    Args:
        xs: positions indexed as [time, N, 2]; column 0 is drawn on the
            x-axis and column 1 on the y-axis.
        path: output file handed to the animation's ``save`` method.
        interval: forwarded to ``camera.animate`` as ``interval``.
        ms: marker size of the plotted points.
        **kwargs: extra keyword arguments forwarded to ``camera.animate``.
    """
    # Draw onto the current figure, configured as a 3x3 inch canvas at 100 dpi.
    fig = plt.gcf()
    fig.set_dpi(100)
    fig.set_size_inches(3, 3)
    camera = Camera(fig)

    # One snapshot per time step; the view is pinned to the unit square.
    for frame in xs:
        plt.plot(frame[..., 0], frame[..., 1], 'k.', markersize=ms)
        plt.axis('equal')
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xticks([], [])
        plt.yticks([], [])
        camera.snap()

    anim = camera.animate(blit=True, interval=interval, **kwargs)
    anim.save(path)
    plt.close()

## Get a baseline simulation working

In [None]:
def V_gas(xs, eps=1e-6, overlap_radius=0.05, scale_coeff=1e-6): # 1e-5
    """Pairwise potential of the gas particles, summed over all unique pairs.

    Args:
        xs: positions, either [N, 2] for a single frame or a stack of such
            frames (more than two dimensions). A stack is reduced by summing
            the per-frame values.
        eps: small constant added under the square root and to the
            denominators.
        overlap_radius: distance below which the 1/r term is replaced by a
            linear cap.
        scale_coeff: multiplier applied to the summed pair terms.

    Returns:
        A scalar: ``-scale_coeff`` times the sum of the pair terms.
    """
    if len(xs.shape) > 2:
        # Forward the settings by keyword. Passing them positionally would
        # shift overlap_radius into `eps` and scale_coeff into
        # `overlap_radius`, so batched and single-frame results would differ.
        return sum([V_gas(_xs, eps=eps, overlap_radius=overlap_radius,
                          scale_coeff=scale_coeff) for _xs in xs]) # broadcast

    n = xs.shape[0]
    dx = xs[:, 0:1] - xs[:, 0:1].T
    dy = xs[:, 1:2] - xs[:, 1:2].T
    dist_matrix = (dx.pow(2) + dy.pow(2) + eps).sqrt()
    # Strict upper triangle: every unordered pair exactly once, no self-pairs.
    rows, cols = torch.triu_indices(n, n, 1)
    dists = dist_matrix[rows, cols]

    # Pairs with dists >= 1 - overlap_radius (or exactly on a threshold) add
    # nothing; the cap below for that range is left disabled, as in the original.
    # potentials  = (dists > 1-overlap_radius) * (5e2*(overlap_radius - (1-dists)) + 1/overlap_radius) # cap
    potentials = (dists > 0.5) * (dists < 1-overlap_radius) * 1/(1-dists + eps)  # 1/(1-r) (wraparound)
    potentials += (dists > overlap_radius) * (dists < 0.5) * 1/(dists + eps)  # 1/r
    potentials += (dists < overlap_radius) * (5e2*(overlap_radius - dists) + 1/overlap_radius)  # cap
    return -potentials.sum() * scale_coeff
    
def V_3body(xs, eps=1e-6, overlap_radius=0.05, scale_coeff=1e-4): # 1e-5
    """Pairwise potential for the three-body setup, summed over unique pairs.

    Args:
        xs: positions, either [N, 2] for a single frame or a stack of such
            frames (more than two dimensions). A stack is reduced by summing
            the per-frame values.
        eps: small constant added under the square root and to the
            denominator.
        overlap_radius: distance below which the 1/r term is replaced by a
            linear cap.
        scale_coeff: multiplier applied to the summed pair terms.

    Returns:
        A scalar: ``scale_coeff`` times the sum of the pair terms.
    """
    if len(xs.shape) > 2:
        # Forward the settings by keyword. Passing them positionally would
        # shift overlap_radius into `eps` and scale_coeff into
        # `overlap_radius`, so batched and single-frame results would differ.
        return sum([V_3body(_xs, eps=eps, overlap_radius=overlap_radius,
                            scale_coeff=scale_coeff) for _xs in xs]) # broadcast

    n = xs.shape[0]
    dx = xs[:, 0:1] - xs[:, 0:1].T
    dy = xs[:, 1:2] - xs[:, 1:2].T
    dist_matrix = (dx.pow(2) + dy.pow(2) + eps).sqrt()
    # Strict upper triangle: every unordered pair exactly once, no self-pairs.
    rows, cols = torch.triu_indices(n, n, 1)
    dists = dist_matrix[rows, cols]

    potentials = (dists > overlap_radius) * 1/(dists + eps)  # 1/r
    potentials += (dists < overlap_radius) * (5e2*(overlap_radius - dists) + 1/overlap_radius)  # cap
    return potentials.sum() * scale_coeff
    
def forces(xs, potential_fn, **kwargs):
    """Return the gradient of ``potential_fn`` with respect to ``xs``.

    The result is the plain autograd gradient (no sign flip is applied here).

    Args:
        xs: tensor of positions at which to differentiate.
        potential_fn: callable mapping positions to a scalar tensor.
        **kwargs: extra keyword arguments forwarded to ``potential_fn``.

    Returns:
        Tensor with the same shape as ``xs`` holding d potential / d xs.
    """
    # Differentiate w.r.t. a private copy so the caller's tensor is not
    # switched to requires_grad=True as a side effect.
    xs = xs.detach().clone().requires_grad_(True)
    return torch.autograd.grad(potential_fn(xs, **kwargs), xs)[0]

In [None]:
def solve_ode_euler(x0, x1, dt, steps=100, box_width=1, potential_fn=None):
    """Integrate particle positions forward in time inside a periodic box.

    The initial velocity is taken as ``(x1 - x0) / dt``. Each iteration
    updates the velocity from ``forces`` first and then advances the position
    with the updated velocity, wrapping positions modulo ``box_width``.

    Args:
        x0: positions at time 0 (numpy array).
        x1: positions at time ``dt`` (numpy array).
        dt: time step.
        steps: total number of stored time steps, including ``x0`` and ``x1``.
        box_width: period used to wrap the positions.
        potential_fn: potential handed to ``forces``.

    Returns:
        ``(ts, xs)``: the array of times and the stacked positions.
    """
    positions = [x0, x1]
    times = [0, dt]
    velocity = (x1 - x0) / dt
    current = x1
    for _ in range(steps - 2):
        accel = forces(torch.tensor(current), potential_fn).numpy()  # get forces/accelerations
        velocity = velocity + accel * dt
        current = (current + velocity * dt) % box_width
        positions.append(current)
        times.append(times[-1] + dt)
    return np.asarray(times), np.stack(positions)

def simulate_3body(dt=0.5, steps=100):
    """Simulate three bodies from fixed initial positions and velocities.

    Args:
        dt: time step passed to the integrator.
        steps: number of time steps passed to the integrator.

    Returns:
        Whatever ``solve_ode_euler`` returns for the ``V_3body`` potential.
    """
    np.random.seed(1)  # reseeds numpy's global RNG
    positions = np.array([[0.4, 0.3], [0.4, 0.7], [0.7, 0.5]])
    velocities = np.array([[0.017, -0.006], [-.012, -.012], [0.0, 0.017]])
    # The integrator takes the first two positions rather than a velocity.
    next_positions = positions + dt * velocities
    return solve_ode_euler(positions, next_positions, dt,
                           steps=steps, potential_fn=V_3body)

def simulate_gas(dt=1, N=500, steps=100):
    """Simulate ``N`` gas particles starting at rest from random positions.

    Args:
        dt: time step passed to the integrator.
        N: number of particles.
        steps: number of time steps passed to the integrator (default 100,
            matching the integrator's own default).

    Returns:
        Whatever ``solve_ode_euler`` returns for the ``V_gas`` potential.
    """
    np.random.seed(1)  # reseeds numpy's global RNG
    x0 = np.random.rand(N, 2)*.8 + 0.1  # uniform in [0.1, 0.9)
    # Velocities are all zero; the draw is kept so the global RNG stream seen
    # by later code is unchanged.
    v0 = np.random.randn(N, 2)*0
    x1 = x0 + dt*v0
    return solve_ode_euler(x0, x1, dt, potential_fn=V_gas, steps=steps)

In [None]:
# Run the baseline gas simulation and render its trajectory to an MP4.
t, x = simulate_gas(dt=1, N=500)

make_video(x, path='sim.mp4', interval=60)
# Read the file inside a context manager so the handle is closed right away.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
# Embed the video as a base64 data URL; HTML(...) stays the last expression.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
# Run the baseline three-body simulation and render its trajectory to an MP4.
t, x = simulate_3body()
make_video(x, path='sim.mp4', interval=60, ms=20)
# Read the file inside a context manager so the handle is closed right away.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
# Embed the video as a base64 data URL; HTML(...) stays the last expression.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

## Recover the same dynamics by minimizing the action

In [None]:
def lagrangian(q, V_fn, m=1, g=1):
    """Kinetic minus potential energy, normalised by time steps times bodies.

    Args:
        q: pair ``(x, xdot)`` of positions and velocities; the last axis of
            ``x`` is split into bodies of two coordinates each.
        V_fn: potential, called on ``x`` reshaped to ``(-1, N, 2)``.
        m: mass used in the kinetic term.
        g: accepted but not used by this function.

    Returns:
        ``(sum(T) - sum(V)) / (x.shape[0] * N)``.
    """
    positions, velocities = q
    kinetic = (.5 * m * velocities ** 2).sum()
    n_bodies = positions.shape[-1] // 2
    potential = V_fn(positions.reshape(-1, n_bodies, 2)).sum()
    normaliser = positions.shape[0] * n_bodies
    return (kinetic - potential) / normaliser
  
def action(x, V_fn, dt=1):
    """Evaluate the (normalised) action of a discretised path.

    Velocities are finite differences of consecutive rows of ``x``. Steps
    whose absolute size exceeds 0.9 are treated as box wrap-arounds and
    replaced by a step of magnitude ``1 - |dx|``. The last velocity is
    repeated so the velocity array has as many rows as ``x``.

    Args:
        x: path tensor whose first axis is time.
        V_fn: potential forwarded to ``lagrangian``.
        dt: time step used to turn differences into velocities.

    Returns:
        The summed output of ``lagrangian`` for this path.
    """
    step = x[1:] - x[:-1]
    wrapped = step.abs() > 0.9
    # Mask arithmetic: wrapped entries take the corrected value, others keep step.
    corrected = -step.sign() * (step.abs() - 1) * wrapped
    untouched = step * (~wrapped)
    velocity = (corrected + untouched) / dt
    velocity = torch.cat([velocity, velocity[-1:]], dim=0)
    return lagrangian(q=(x, velocity), V_fn=V_fn).sum()

def get_path_between(path, steps, step_size, V_fn, dt=1, box_width=1, opt='sgd'):
    """Optimise a path by gradient descent on its action, endpoints fixed.

    Args:
        path: module exposing the path as a parameter ``path.x``.
        steps: number of optimisation steps.
        step_size: learning rate for the optimiser.
        V_fn: potential forwarded to ``action``.
        dt: time step of the path.
        box_width: period used to wrap the path coordinates after each step.
        opt: ``'sgd'`` selects SGD with momentum 0.8; any other value
            selects Adam.

    Returns:
        ``(t, path, xs)``: the time grid, the optimised module, and a list of
        snapshots of ``path.x``. The list holds the initial path plus one
        snapshot per logged step, so its last entry comes from the last
        logged step, which need not be the final step.
    """
    t = np.linspace(0, len(path.x)-1, len(path.x)) * dt
    if opt == 'sgd':
        optimizer = torch.optim.SGD(path.parameters(), lr=step_size, momentum=0.8)
    else:
        optimizer = torch.optim.Adam(path.parameters(), lr=step_size)

    # Log roughly 15 times. Clamp to 1 so steps < 15 does not produce a
    # modulo-by-zero (steps // 15 == 0).
    log_every = max(1, steps // 15)

    xs = [path.x.clone().data]
    t0 = time.time()
    for i in range(steps):
        S = action(path.x, V_fn, dt)
        S.backward()
        path.x.grad.data[[0, -1]] *= 0  # zero the gradient on the first and last time step
        optimizer.step()
        path.zero_grad()
        path.x.data = path.x.data % box_width # x is subject to modulo arithmetic

        if i % log_every == 0:
            xs.append(path.x.clone().data)
            # 'dt' in the message is wall-clock time since the previous log line.
            print('step={:04d}, S={:.3e} J*s, dt={:.1f}s'.format(i, S.item(), time.time()-t0))
            t0 = time.time()
    return t, path, xs

class PerturbedPath(torch.nn.Module):
    """Trainable path initialised from a reference trajectory plus noise.

    Attributes set in ``__init__``:
        x_true: the reference trajectory as passed in.
        x_pert: the perturbed trajectory, reshaped to ``(-1, N*2)``.
        x: ``torch.nn.Parameter`` built from ``x_pert``  # [time, N*2]
    """

    def __init__(self, x_true, N, sigma=0, shift=False):
        """Build the perturbed path.

        Args:
            x_true: numpy array with time as the first axis.
            N: number of bodies; each row of the flattened path has N*2 entries.
            sigma: scale of the Gaussian noise (drawn with ``np.random.randn``
                and clipped to [-1, 1] before scaling).
            shift: if True, rows 5..-5 have their columns rotated left by two
                and the flattened shape is printed.
        """
        super().__init__()
        self.x_true = x_true

        clipped = np.random.randn(*x_true.shape).clip(-1, 1)
        noise = sigma * clipped
        # The first and last time steps are left unperturbed.
        noise[:1] = 0
        noise[-1:] = 0

        flat = (x_true + noise).reshape(-1, N * 2)
        self.x_pert = flat
        if shift:
            # Move the first two columns to the end for the interior rows.
            flat[5:-5] = np.roll(flat[5:-5], -2, axis=-1)
            print(self.x_pert.shape)
        self.x = torch.nn.Parameter(torch.tensor(flat)) # [time, N*2]

## Gas simulation

In [None]:
# Gas experiment: simulate a reference trajectory, add a small perturbation,
# then optimise the action starting from the perturbed path (SGD by default).
dt = 1
N = 500
t_sim, x_sim = simulate_gas(dt=dt, N=N)
init_path = PerturbedPath(x_sim, N=N, sigma=1e-2)  # [time, N*2]
t_min, path, xs_min = get_path_between(
    init_path,
    steps=100,
    step_size=5e2,
    V_fn=V_gas,
    dt=dt,
)

In [None]:
# Compare one velocity component of a single ball across the initial path,
# the optimised path and the simulator output.
N = x_sim.shape[-2]
xs_before = xs_min[0].detach().numpy().reshape(-1,N,2)
xs_after = xs_min[-1].detach().numpy().reshape(-1,N,2)

# k indexes the flattened [time, N*2] layout: ball k//2, component k%2
# (0 = horizontal, 1 = vertical). Derive the label from k so an odd k is not
# reported as "horiz.".
k = 25
component = 'horiz.' if k % 2 == 0 else 'vert.'
plt.figure(dpi=100)
plt.title('Ball {} {} velocity vs. time'.format(1 + k//2, component))
# Finite differences are divided by dt so the y-axis is a velocity rather than
# a per-step displacement.
for traj, style, label in ((xs_before, '.-', 'Initial path'),
                           (xs_after, '.-', 'Minimum action'),
                           (x_sim, 'k-', 'Simulator')):
    plt.plot(np.diff(traj, axis=0).reshape(-1, N*2)[..., k] / dt, style, label=label)
plt.legend()
plt.show()

In [None]:
# Render the initial (perturbed) gas path.
xs = xs_min[0].detach().numpy().reshape(-1,N,2)
make_video(xs, path='sim.mp4', interval=60, ms=10)
# Read the file inside a context manager so the handle is closed right away.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
# Embed the video as a base64 data URL; HTML(...) stays the last expression.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
# Render the last recorded snapshot of the optimised gas path.
xs = xs_min[-1].detach().numpy().reshape(-1,N,2)
make_video(xs, path='sim.mp4', interval=60, ms=10)
# Read the file inside a context manager so the handle is closed right away.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
# Embed the video as a base64 data URL; HTML(...) stays the last expression.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

## 3 body simulation

In [None]:
# Three-body experiment: simulate a reference trajectory, perturb it, then
# optimise the action starting from the perturbed path with SGD.
dt = 0.5
N = 3

t_sim, x_sim = simulate_3body(dt=dt)
init_path = PerturbedPath(x_sim, N=N, sigma=2e-2)  # [time, N*2]
# NOTE(review): the reference trajectory comes from simulate_3body, but the
# action below is built with V_gas rather than V_3body; confirm this is intended.
t_min, path, xs_min = get_path_between(
    init_path,
    steps=500,
    step_size=1e1,
    V_fn=V_gas,
    dt=dt,
    opt='sgd',
)

In [None]:
# Compare one velocity component of a single ball across the initial path,
# the optimised path and the simulator output.
N = x_sim.shape[-2]
xs_before = xs_min[0].detach().numpy().reshape(-1,N,2)
xs_after = xs_min[-1].detach().numpy().reshape(-1,N,2)

# k indexes the flattened [time, N*2] layout: ball k//2, component k%2
# (0 = horizontal, 1 = vertical); the title label is derived from k.
k = 0
component = 'horiz.' if k % 2 == 0 else 'vert.'
plt.figure(dpi=100)
plt.title('Ball {} {} velocity vs. time'.format(1 + k//2, component))
# Finite differences are divided by dt (0.5 in this section) so the y-axis is
# a velocity rather than a per-step displacement.
for traj, style, label in ((xs_before, '.-', 'Initial path'),
                           (xs_after, '.-', 'Minimum action'),
                           (x_sim, 'k-', 'Simulator')):
    plt.plot(np.diff(traj, axis=0).reshape(-1, N*2)[..., k] / dt, style, label=label)
plt.legend()
plt.show()

In [None]:
# Render the initial (perturbed) three-body path.
xs = xs_min[0].detach().numpy().reshape(-1,N,2)
make_video(xs, path='sim.mp4', interval=60, ms=20)
# Read the file inside a context manager so the handle is closed right away.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
# Embed the video as a base64 data URL; HTML(...) stays the last expression.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
# Render the last recorded snapshot of the optimised three-body path.
xs = xs_min[-1].detach().numpy().reshape(-1,N,2)
make_video(xs, path='sim.mp4', interval=60, ms=20)
# Read the file inside a context manager so the handle is closed right away.
with open('sim.mp4', 'rb') as video_file:
    mp4 = video_file.read()
# Embed the video as a base64 data URL; HTML(...) stays the last expression.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML('<video width=300 controls><source src="{}" type="video/mp4"></video>'.format(data_url))