<a href="https://colab.research.google.com/github/paddy-la/LagrangianNN/blob/main/PL_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.36-cp36-none-linux_x86_64.whl
!pip install --upgrade jax  # install jax

Collecting jaxlib==0.1.36
[?25l  Downloading https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.36-cp36-none-linux_x86_64.whl (48.2MB)
[K     |████████████████████████████████| 48.2MB 66kB/s 
Installing collected packages: jaxlib
  Found existing installation: jaxlib 0.1.57+cuda101
    Uninstalling jaxlib-0.1.57+cuda101:
      Successfully uninstalled jaxlib-0.1.57+cuda101
Successfully installed jaxlib-0.1.36
Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.2.6)


In [1]:
!pip install --upgrade -q jax==0.1.55 jaxlib==0.1.36
!pip install -U -q Pillow moviepy proglog

[?25l[K     |█▎                              | 10kB 24.7MB/s eta 0:00:01[K     |██▋                             | 20kB 14.7MB/s eta 0:00:01[K     |███▉                            | 30kB 12.9MB/s eta 0:00:01[K     |█████▏                          | 40kB 12.1MB/s eta 0:00:01[K     |██████▍                         | 51kB 8.2MB/s eta 0:00:01[K     |███████▊                        | 61kB 8.5MB/s eta 0:00:01[K     |█████████                       | 71kB 8.7MB/s eta 0:00:01[K     |██████████▎                     | 81kB 8.9MB/s eta 0:00:01[K     |███████████▋                    | 92kB 9.0MB/s eta 0:00:01[K     |████████████▉                   | 102kB 7.9MB/s eta 0:00:01[K     |██████████████▏                 | 112kB 7.9MB/s eta 0:00:01[K     |███████████████▍                | 122kB 7.9MB/s eta 0:00:01[K     |████████████████▊               | 133kB 7.9MB/s eta 0:00:01[K     |██████████████████              | 143kB 7.9MB/s eta 0:00:01[K     |███████████████████▎   

In [3]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [2]:
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from functools import partial
import numpy as np
import matplotlib.pyplot as plt

def lagrangian(q, q_dot, m1, m2, l1, l2, g):
   
    t1, t2 = q     # theta 1 and theta 2
    w1, w2 = q_dot # omega 1 and omega 2

    # kinetic energy (T)
    T1 = 0.5 * m1 * (l1 * w1)**2
    T2 = 0.5 * m2 * ((l1 * w1)**2 + (l2 * w2)**2 +
                    2 * l1 * l2 * w1 * w2 * jnp.cos(t1 - t2))
    T = T1 + T2
  
    # potential energy (V)
    y1 = -l1 * jnp.cos(t1)
    y2 = y1 - l2 * jnp.cos(t2)
    V = m1 * g * y1 + m2 * g * y2

    return T - V


def f_analytical(state, m1=1, m2=1, l1=1, l2=1, g=9.8):
    t1, t2, w1, w2 = state

    a1 = (l2 / l1) * (m2 / (m1 + m2)) * jnp.cos(t1 - t2)
    a2 = (l1 / l2) * jnp.cos(t1 - t2)

    f1 = -(l2 / l1) * (m2 / (m1 + m2)) * (w2 ** 2) * jnp.sin(t1 - t2) - (g / l1) * jnp.sin(t1)
    f2 = (l1 / l2) * (w1 ** 2) * jnp.sin(t1 - t2) - (g / l2) * jnp.sin(t2)

    g1 = (f1 - a1 * f2) / (1 - a1 * a2)
    g2 = (f2 - a2 * f1) / (1 - a1 * a2)

    return jnp.stack([w1, w2, g1, g2])  # returns column vector of time derivatives of theta 1 and 2


def equation_of_motion(lagrangian, state, t=None):
    q, q_t = jnp.split(state, 2)  # the initial configuration
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t)
               - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    return jnp.concatenate([q_t, q_tt]) # so what happens in essence is it takes in the intial arguments and using the E-L
    # equations produces the time derivatives 


def solve_lagrangian(lagrangian, initial_state, **kwargs):
    @partial(jax.jit, backend='cpu')
    def f(initial_state):
        return odeint(partial(equation_of_motion, lagrangian), initial_state, **kwargs)
    return f(initial_state)

@partial(jax.jit, backend='cpu')
def solve_autograd(initial_state, times, m1=1, m2=1, l1=1, l2=1, g=9.8):
    L = partial(lagrangian, m1=m1, m2=m2, l1=l1, l2=l2, g=g)
    # Specifies the conditions of the system onto the lagrangian then solves for times
    return solve_lagrangian(L, initial_state, t=times, rtol=1e-10, atol=1e-10)

@partial(jax.jit, backend='cpu')
def solve_analytical(initial_state, times):
    # Just solves the exact analytical equations for different time steps
    return odeint(f_analytical, initial_state, t=times, rtol=1e-10, atol=1e-10)


In [4]:
# choose an initial state
x0 = np.array([3*np.pi/7, 3*np.pi/4, 0, 0], dtype=np.float32)
noise = np.random.RandomState(0).randn(x0.size)
t = np.linspace(0, 40, num=401, dtype=np.float32) # evenly spaced time steps 

In [5]:
# compute dynamics analytically
x_analytical = solve_analytical(x0, t)
noise = np.random.RandomState(0).randn(x0.size)
noise_1, noise_2 = 1e-10, 1e-11
x_perturbed_1 = solve_analytical(x0 + noise_1 * noise, t)
x_perturbed_2 = solve_analytical(x0 + noise_2 * noise, t)

In [6]:
# compute dynamics with the lagrangian 
x_autograd = jax.device_get(solve_autograd(x0, t))

In [15]:
# generating training data 

time_step = 1
N=1500
# Use the x0 generated earlier 
t1=np.arange(N, dtype=np.float32) # t1 is an array evenly spaved from 0 to 1499
t2= np.arange(N,2*N,dtype=np.float32)

x_train = solve_analytical(x0,t1)
xt_train = jax.vmap(f_analytical)(x_train) # finds the derivative of the inital state using the analytical method 

x_test = solve_analytical(x0,t2)
xt_test = jax.vmap(f_analytical)(x_test)



In [14]:
# setting up to build the NN 

def normalize_dp(state):
  # rescales the coordinates to the minus pi to pi range 
  return jnp.concatenate([(state[:2] + np.pi) % (2 * np.pi) - np.pi, state[2:]])

def lagrangian_paramteric(params): # we replace the lagrangian we know with the black box lagrangian 
  def lagrangian(q,q_t):
    assert q.shape == (2,)
    state = normalise_dp(jnp.concatenate([q,q_t])) 
    return jnp.squeeze(nn_forward_fn(params, state), axis=-1) # look at the idea of np.squeeze online 
  return lagrangian

# create the loss function for the problem 

@jax.jit
def loss(params, batch, time_step=None):
  state, targets = batch # the batch consists of state = x_train and targets = xt_train
  if time_step is not None:
    f = partial(equation_of_motion, learned_lagrangian(params))
    preds = jax.vmap(partial(rk4_step, f, t=0.0, h=time_step))(state) # we integrate the time derivates from the E-L 
    # equation according to the black box lagranian and obtain the next step state 
  else:  # time_step is asserted = 0.001 and never changed so surely we always got through the if statement ? 
    preds = jax.vmap(partial(equation_of_motion, learned_lagrangian(params)))(state)
  return jnp.mean((preds - targets) ** 2)


# build the neural network model 5 layers starting with 128 nodes 
init_random_params, nn_forward_fn = stax.serial(
    stax.Dense(128),
    stax.Softplus,
    stax.Dense(128),
    stax.Softplus,
    stax.Dense(1),
)



In [None]:
@jax.jit
def update_derivative(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, jax.grad(loss)(params, batch, None), opt_state)

# I omitted update_timestep as we don't use it 


In [16]:
%%time

rng = jax.random.PRNGKey(0)
init_params = init_random_params(rng, (-1, 4)) # chose some random parameters 

batch_size = 100
test_every = 10
num_batches = 1500

train_losses = []
test_losses = []

# adam w learn rate decay
opt_init, opt_update, get_params = optimizers.adam(
    lambda t: jnp.select([t < batch_size*(num_batches//3),
                          t < batch_size*(2*num_batches//3),
                          t > batch_size*(2*num_batches//3)],
                         [1e-3, 3e-4, 1e-4])) # select just looks at the condition list and selects from the choices 
opt_state = opt_init(init_params)

for iteration in range(batch_size*num_batches + 1):
  if iteration % batch_size == 0:
    params = get_params(opt_state)
    train_loss = loss(params, (x_train, xt_train))
    train_losses.append(train_loss)
    test_loss = loss(params, (x_test, xt_test))
    test_losses.append(test_loss)
    if iteration % (batch_size*test_every) == 0:
      print(f"iteration={iteration}, train_loss={train_loss:.6f}, test_loss={test_loss:.6f}")
  opt_state = update_derivative(iteration, opt_state, (x_train, xt_train))

params = get_params(opt_state)

NameError: ignored