In [3]:
import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import neuro_shooting.generic_integrator as generic_integrator
from neuro_shooting.shooting_models import AutoShootingIntegrandModelUpDown as UpDown
from neuro_shooting.shooting_models import AutoShootingIntegrandModelSimple as Simple
from neuro_shooting.shooting_models import AutoShootingIntegrandModelSecondOrder as SecOrder
from neuro_shooting.shooting_blocks import ShootingBlockBase as Base

In [4]:
# Index of the CUDA device to use when one is available; otherwise fall back to the CPU.
gpu = 0
device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu')

In [5]:
# End of the observation window and number of samples taken on it.
t_max = 100
data_size = 250
# Selects the linear system (True) or its cubic variant (False); read by Lambda.forward.
linear = True

# System matrix and initial state of the ground-truth dynamics, placed on the active device.
true_A = torch.tensor([[-0.1, -1.0], [1.0, -0.1]]).to(device)
true_y0 = torch.tensor([[0.6, 0.3]]).to(device)
# Evenly spaced observation times covering [0, t_max].
t = torch.linspace(start=0., end=t_max, steps=data_size).to(device)

class Lambda(nn.Module):
    """Right-hand side of the ground-truth ODE used to synthesise the data.

    Evaluates ``y @ true_A`` when the module-level flag ``linear`` is set and
    ``(y**3) @ true_A`` otherwise.
    """

    def forward(self, t, y):
        # `t` is accepted to match the integrator's call signature but is not used.
        state = y if linear else y**3
        return torch.mm(state, true_A)

# Settings handed to the generic integrator below.
stepsize = 0.5
integrator_options = dict(step_size=stepsize)
rtol = 1e-8
atol = 1e-10
adjoint = False

integrator = generic_integrator.GenericIntegrator(
    integrator_library='odeint',
    integrator_name='rk4',
    use_adjoint_integration=adjoint,
    integrator_options=integrator_options,
    rtol=rtol,
    atol=atol,
)

# Integrate the ground-truth system once to obtain the reference trajectory;
# no gradients are needed for data generation.
with torch.no_grad():
    true_y = integrator.integrate(func=Lambda(), x0=true_y0, t=t)


def get_batch(batch_size, batch_time):
    """Sample random windows of the reference trajectory.

    Draws ``batch_size`` distinct start indices from
    ``[0, data_size - batch_time)`` and cuts a window of ``batch_time``
    consecutive samples of ``true_y`` at each of them.

    Returns:
        tuple ``(batch_y0, batch_t, batch_y)`` where ``batch_y0`` is
        ``true_y`` at the start indices, ``batch_t`` is the first
        ``batch_time`` entries of ``t`` and ``batch_y`` stacks the windows
        along a new leading (time) axis.
    """
    n_valid_starts = data_size - batch_time
    picked = np.random.choice(np.arange(n_valid_starts, dtype=np.int64), batch_size, replace=False)
    start_idx = torch.from_numpy(picked).to(device)

    # One slice of true_y per time offset inside the window.
    window = [true_y[start_idx + offset] for offset in range(batch_time)]
    return true_y[start_idx], t[:batch_time], torch.stack(window, dim=0)

def to_np(x):
    """Convert tensor ``x`` to a NumPy array, detached from autograd and on the CPU."""
    detached = x.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

In [6]:
def makedirs(dirname):
    """Create directory ``dirname`` (including parents) if it does not exist yet.

    Uses ``exist_ok=True`` instead of a separate existence check, so there is
    no window between checking and creating in which another process could
    create the directory and make the call fail.
    """
    os.makedirs(dirname, exist_ok=True)
        
# Output directory for the per-epoch trajectory figures written by the training loop.
makedirs('png/spiral')
    
def plot_trajectories(true_y, pred_y, sim_time, save=None, figsize=(16, 8)):
    """Plot observed and predicted trajectories in the (y_1, y_2) plane.

    Args:
        true_y: iterable of observed trajectories (tensors) or None. Each is
            converted with ``to_np`` and indexed as ``[:, :, 0]`` / ``[:, :, 1]``.
        pred_y: iterable of predicted trajectories (tensors) or None, indexed
            the same way and drawn as lines.
        sim_time: iterable with one time tensor per observed trajectory, used
            to colour the scatter points; None (or a None entry) draws the
            observations without time colouring.
        save: optional path; when given, the finished figure is written there.
        figsize: figure size forwarded to ``plt.figure``.
    """
    plt.figure(figsize=figsize)

    if true_y is not None:
        if sim_time is None:
            sim_time = [None] * len(true_y)
        for obs, obs_time in zip(true_y, sim_time):
            obs = to_np(obs)
            if obs_time is None:
                # No time information: to_np(None) would fail, so plot uncoloured.
                plt.scatter(obs[:, :, 0], obs[:, :, 1], label='observations')
            else:
                plt.scatter(obs[:, :, 0], obs[:, :, 1], c=to_np(obs_time), cmap=cm.plasma,
                            label='observations (colored by time)')

    if pred_y is not None:
        for pred in pred_y:
            pred = to_np(pred)
            plt.plot(pred[:, :, 0], pred[:, :, 1], lw=1.5, label="prediction")

    plt.legend()
    plt.title('Trajectory: observed versus predicted')
    plt.xlabel('y_1')
    plt.ylabel('y_2')

    # Save only after legend, title and labels exist so the file matches what is shown,
    # and regardless of whether predictions were supplied.
    if save is not None:
        plt.savefig(save)

    plt.show()


In [7]:
class Model(nn.Module):
    """Shooting-block model built from an UpDown integrand with 2-D state.

    ``forward`` returns the first element of the block's output for a batch of
    initial conditions; ``trajectory`` evaluates the block on an explicit
    vector of time points and then restores the horizon used by ``forward``.
    """

    def __init__(self, pw=1.0, nr_of_particles=5, integration_time=1.0):
        """
        Args:
            pw: parameter weight forwarded to the integrand.
            nr_of_particles: number of particles forwarded to the integrand.
            integration_time: end time used by ``forward``.
                NOTE(review): the default of 1.0 is assumed to match the
                shooting block's own default horizon - confirm against
                ShootingBlockBase.
        """
        super(Model, self).__init__()

        # Remembered so trajectory() can put the block back into this state.
        self.integration_time = integration_time

        self.int = UpDown(2, 'tanh', nr_of_particles=nr_of_particles, parameter_weight=pw)
        self.blk = Base('shooting_block', shooting_integrand=self.int)
        self.blk.set_integration_time(self.integration_time)

    def trajectory(self, batch_y0, batch_t):
        """Run the block on the time points ``batch_t`` and return its raw output.

        The block's integration time is restored to ``self.integration_time``
        afterwards (even if the block raises), so later ``forward`` calls
        integrate over the same horizon as before. Previously it was reset to
        the unrelated global ``t_max``, which changed what ``forward`` computed
        after the first call to this method.
        """
        # batch_t defines the time steps at which the trajectory is evaluated
        self.blk.set_integration_time_vector(batch_t, suppress_warning=True)
        try:
            out = self.blk(batch_y0)
        finally:
            self.blk.set_integration_time(self.integration_time)
        return out

    def forward(self, batch_y0):
        """Return the block's prediction (first of its four outputs) for ``batch_y0``."""
        pred_y, _, _, _ = self.blk(batch_y0)
        return pred_y


In [9]:
# --- Training configuration -------------------------------------------------
N_epochs = 100
N_particles = 5
pw = 0.5
batch_size = 10
batch_time = 25
test_freq = 100  # report and plot every `test_freq` epochs
seed = 0

# Seed every RNG in play so that re-running this cell reproduces the same
# model initialisation and the same batches (get_batch samples via np.random).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Keep the model on the same device as the batches returned by get_batch.
model = Model(pw=pw, nr_of_particles=N_particles).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.1)

for epoch in range(1, N_epochs + 1):

    batch_y0, batch_t, batch_y = get_batch(batch_size, batch_time)

    # zero-out gradients
    model.zero_grad()
    # predict the state at the end of the window
    # NOTE(review): forward() integrates over the block's configured horizon and
    # does not receive batch_t - confirm this time scaling is intended.
    pred_batch_terminal = model(batch_y0)
    # mean Euclidean distance to the last observed sample of each window
    loss = torch.mean(torch.norm(pred_batch_terminal - batch_y[batch_time-1,:,:], dim=2))
    # backprop
    loss.backward()
    # take gradient step
    opt.step()

    # TODO: known issue - a forward pass issued after model.trajectory() has been
    # observed to hang without raising; check the integration time that
    # Model.trajectory leaves on the block.
    if epoch % test_freq == 0:
        print('Epoch {} | Total Loss {:.2f}'.format(epoch, loss.item()))
        viz_index = 0
        pred_batch_trajectory, _, _, _ = model.trajectory(batch_y0, batch_t)
        plot_trajectories([batch_y[:,viz_index,:,:]],[pred_batch_trajectory[:,viz_index,:,:]],[batch_t],
                                      save="./png/spiral/epoch{}.png".format(epoch), figsize=(16, 8))

Epoch 100 | Total Loss 0.01
