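Re-implements the spiral demo from the *Neural Ordinary Differential Equations* seminar notes ([1.Demo_spiral.ipynb](https://github.com/kmkolasinski/deep-learning-notes/blob/master/seminars/2019-03-Neural-Ordinary-Differential-Equations/1.Demo_spiral.ipynb)) using the solvers from the `node` package.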

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from time import time
from node.fix_grid import FixGridODESolver, FixGridODESolverWithTrajectory, rk4_step_fn, euler_step_fn
from node.base import reverse_mode_derivative

In [2]:
# --- Experiment configuration ---------------------------------------------
seed = 0          # seeds mini-batch sampling and weight initialisation
data_size = 1000  # number of points on the reference trajectory
batch_time = 20   # index offset between a batch's start and target points; this seems to work best
n_iters = 3000    # number of training iterations
batch_size = 16   # trajectory segments sampled per training step

# Seed both RNGs up front so that "Restart & Run All" reproduces the same run.
np.random.seed(seed)
tf.random.set_seed(seed)

# Ground-truth system: initial state and the matrix used by the true dynamics.
true_y0 = tf.constant([[2., 0.]])
true_A = tf.constant([[-0.1, 2.0], [-2.0, -0.1]])

In [3]:
@tf.function
def f(x, t):
    """Ground-truth dynamics: the element-wise cube of `x` multiplied by `true_A`.

    `t` is accepted but not used.
    """
    cubed = tf.pow(x, 3)
    return tf.matmul(cubed, true_A)

In [4]:
# Two RK4 solvers configured with the same `data_size`: one whose forward
# callable returns a (final state, trajectory) pair, and one that returns the
# final state only.
ode_solver_with_traj = FixGridODESolverWithTrajectory(rk4_step_fn, data_size)
traj_forward = ode_solver_with_traj(f)

ode_solver = FixGridODESolver(rk4_step_fn, data_size)
forward = ode_solver(f)

In [5]:
# Integrate the true dynamics from t0 to t1 and keep the whole trajectory.
t0 = tf.constant(0.)
t1 = tf.constant(25.)
# `true_y0` is already a constant tensor, so it is passed through unchanged.
_final_state, true_traj = traj_forward(t0, t1, true_y0)
# Flatten to a plain (data_size, 2) array for plotting and batch sampling.
true_y = true_traj.numpy().reshape([data_size, 2])
true_y

array([[ 2.        ,  0.        ],
       [ 1.979506  ,  0.39437142],
       [ 1.9493771 ,  0.77418923],
       ...,
       [-0.44173265,  0.28836665],
       [-0.44268936,  0.28397956],
       [-0.44359136,  0.27956814]], dtype=float32)

In [26]:
# Time one forward solve that returns only the final state.
# The time bounds are float tensors (matching the trajectory cell) and are the
# values actually handed to the solver.
t0 = tf.constant(0.)
t1 = tf.constant(25.)
t00 = time()
yN = forward(t0, t1, true_y0)
t01 = time()
print(t01 - t00)
yN

0.07180643081665039


<tf.Tensor: id=327, shape=(1, 2), dtype=float32, numpy=array([[-0.44359136,  0.27956814]], dtype=float32)>

In [None]:
def plot_spiral(trajectories):
    """Draw each 2-D trajectory as a line on the current matplotlib axes.

    Args:
        trajectories: iterable of array-likes of shape (n_points, 2); column 0
            is plotted on the x axis and column 1 on the y axis.
    """
    for path in trajectories:
        # Coerce to an array so lists are accepted as well as ndarrays.
        xs, ys = np.asarray(path).T
        plt.plot(xs, ys)
    plt.axis("equal")
    plt.xlabel("x")
    plt.ylabel("y")

plot_spiral([true_y])

In [None]:
# Small network used to learn the trajectory
class ODEModel:
    """Callable wrapper around a two-layer Keras network.

    The network is a 50-unit tanh `Dense` layer followed by a 2-unit `Dense`
    layer, built for inputs whose last dimension is 2.
    """

    def __init__(self):
        hidden = tf.keras.layers.Dense(50, activation="tanh")
        readout = tf.keras.layers.Dense(2)
        self.model = tf.keras.Sequential([hidden, readout])
        self.model.build([None, 2])

    def __call__(self, x, t):
        """Feed the element-wise cube of `x` through the network; `t` is unused."""
        return self.model(tf.pow(x, 3))

In [None]:
# The optimizer does not depend on the model, so it is created first; the
# model and the list of variables to be trained follow.
optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=1e-4)
network = ODEModel()
var_list = network.model.trainable_variables

In [None]:
# Integration window for one training segment. The reference trajectory holds
# `data_size` points covering t = 0 ... 25 (its first row equals the initial
# state and its last row equals the t = 25 solution), and `get_batch` pairs
# points that lie `batch_time` indices apart.
# NOTE(review): assumes the trajectory points are uniformly spaced in time --
# confirm against FixGridODESolverWithTrajectory.
segment_t0 = tf.constant(0.)
segment_t1 = tf.constant(batch_time * 25. / (data_size - 1))

# Forward solve driven by the trainable network, built with the same
# `solver(fn) -> forward(t0, t1, y0)` convention used for the true dynamics.
network_forward = ode_solver(network)


@tf.function
def compute_gradients_and_update(batch_y0, batch_yN):
    """Run one training step and return the loss.

    Args:
        batch_y0: start states of the sampled trajectory segments.
        batch_yN: target states that the integrated start states should reach.

    Returns:
        Mean absolute error between the predicted and the target states.
    """
    # The prediction must come from the network (not the true dynamics `f`),
    # otherwise the loss would not depend on the trainable weights.
    pred_y = network_forward(segment_t0, segment_t1, batch_y0)
    with tf.GradientTape() as g:
        g.watch(pred_y)
        loss = tf.reduce_mean(tf.abs(pred_y - batch_yN))
    # Gradient of the loss with respect to the predicted final state only.
    dLoss = g.gradient(loss, pred_y)
    # NOTE(review): the signature of `reverse_mode_derivative` is not visible
    # in this notebook; the original positional call shape is kept -- confirm
    # that it accepts these arguments and returns one gradient per variable.
    dWeights = reverse_mode_derivative(
        ode_solver, network, var_list, segment_t0, segment_t1, pred_y, dLoss)
    optimizer.apply_gradients(zip(dWeights, var_list))
    return loss

In [None]:
def get_batch():
    """Sample a mini-batch of (start, target) point pairs from `true_y`.

    Draws `batch_size` distinct start indices and pairs every start point with
    the point `batch_time` indices later on the trajectory.

    Returns:
        Tuple of two constant tensors: the start points and the target points.
    """
    candidate_starts = np.arange(data_size - batch_time - 1, dtype=np.int64)
    start_idx = np.random.choice(candidate_starts, size=batch_size, replace=False)
    end_idx = start_idx + batch_time
    return tf.constant(true_y[start_idx]), tf.constant(true_y[end_idx])

In [None]:
from tqdm import tqdm

# Trajectory-recording forward solve of the learned dynamics, built with the
# same `solver(fn) -> forward(t0, t1, y0)` convention used for the true
# dynamics.
network_traj_forward = ode_solver_with_traj(network)

loss_history = []
for step in tqdm(range(n_iters)):
    batch_y0, batch_yN = get_batch()
    loss = compute_gradients_and_update(batch_y0, batch_yN)
    loss_history.append(loss.numpy())

    if step % 500 == 0:
        # Roll the current model out over the full time span from the true
        # initial state; the final state is not needed here, and binding it to
        # a throwaway name avoids clobbering the global `yN`.
        _final_state, states_history_model = network_traj_forward(
            tf.constant(0.), tf.constant(25.), true_y0)
        # Same conversion as used for the reference trajectory.
        model_y = states_history_model.numpy().reshape([data_size, 2])
        # Plot the reference and the learned trajectory together.
        plot_spiral([true_y, model_y])
        plt.show()