In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import copy
import os
import pickle
import sys

import gpytorch
import torch
from IPython.display import HTML




# Allow imports from the parent directory
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from main import tools, ode, model, simulate, net

In [2]:

def vel(t, xy, A=1.0, omega= 2 * 3.14159, epsilon=0.25, time_dependent=False):
    """Velocity field of a shifted, rescaled double-gyre flow.

    By default the gyre is frozen at phase ``sin(0) = 0``, so ``t`` and
    ``omega`` have no effect. This reproduces the original behaviour of this
    cell. Pass ``time_dependent=True`` to use the oscillating double gyre
    ``a(t) = epsilon * sin(omega * t)``, ``b(t) = 1 - 2 a(t)``.

    Parameters
    ----------
    t : float or torch.Tensor
        Time. Only read when ``time_dependent=True``.
    xy : torch.Tensor
        Points with (x, y) coordinates in the last dimension, shape ``(..., 2)``.
    A : float
        Velocity amplitude.
    omega : float
        Angular frequency of the oscillation (time-dependent mode only).
    epsilon : float
        Oscillation amplitude (time-dependent mode only).
    time_dependent : bool
        If True, use ``a = epsilon * sin(omega * t)``. Otherwise ``a = 0``.

    Returns
    -------
    torch.Tensor
        Velocities ``(u, v)`` stacked in the last dimension, shape ``(..., 2)``,
        divided by 7.5.
    """
    # Truncated pi, kept so that the default output matches earlier runs exactly.
    pi = 3.14159

    # The gyre is evaluated at x shifted by +0.5.
    x = xy[..., 0] + 0.5
    y = xy[..., 1]

    if time_dependent:
        phase = torch.sin(omega * torch.as_tensor(t, dtype=x.dtype, device=x.device))
    else:
        # Original behaviour: the phase is pinned at sin(0), so t/omega are unused.
        phase = torch.sin(torch.tensor(0))
    a = epsilon * phase
    b = 1 - 2 * a

    f = a * x**2 + b * x
    df_dx = 2 * a * x + b

    u = -pi * A * torch.sin(pi * f) * torch.cos(pi * y)
    v = pi * A * torch.cos(pi * f) * torch.sin(pi * y) * df_dx

    return torch.stack([u, v], dim=-1) / 7.5

# Spatial grid: 50 x 50 points on [-1, 1]^2, with y listed from +1 down to -1,
# plus 5 evenly spaced time snapshots on [0, 1].
x = torch.linspace(-1, 1, steps=50)
y = torch.linspace(1, -1, steps=50)
t = torch.linspace(0, 1, steps=5)

# With indexing='xy', X and Y have shape (len(y), len(x)), so after flattening
# x varies fastest and each row of XY is an (x, y) point: XY has shape (2500, 2).
X, Y = torch.meshgrid(x, y, indexing='xy')
XY = torch.stack((X, Y), dim=-1).reshape(-1, 2)
T = t  # alias used by the later cells
  

In [3]:
# Space-time inputs: pair every spatial point in XY with every time in T.
# Time varies slowest (all of XY for T[0], then all of XY for T[1], ...),
# and each row of TXY is (t, x, y).
T_repeated = T.repeat_interleave(XY.shape[0])
XY_tiled = XY.repeat(len(T), 1)
TXY = torch.column_stack((T_repeated, XY_tiled))

In [4]:
# Reference (data-generating) model: an ODE flow driven by the analytic `vel`
# field, combined with a space-time kernel. It is used to draw the synthetic
# prior sample in the next cell.
flow = ode.Flow(vel)
spaceTimeKernel = model.SpaceTimeKernel(
    l0=2,
    l1=0.1,
    l2=0.1,
)
# NOTE(review): the SpaceTimeKernel object itself is passed here, whereas the
# training cell passes `.kernel` — confirm which one GPFlow expects.
gpFlow = model.GPFlow(spaceTimeKernel, flow)
gpFlow.eval();

In [5]:
# Seed torch's global RNG so that the synthetic observations, and everything
# random that runs after this cell, are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Draw one sample from the GP prior at every space-time point in TXY.
with gpytorch.settings.prior_mode(True):
    Z = gpFlow(TXY).sample()
# TXY is ordered time-slowest, so this gives one row per time in T
# (assumes the sample holds one value per TXY row).
Z = Z.reshape(T.shape[0], -1)

In [40]:
# Wrap times, spatial points and sampled observations in the project's data
# container. NOTE(review): the meaning of grid_size/k0/k1 is defined in
# main.model; document it here once confirmed.
data = model.data(
    T, XY, Z,
    grid_size=1,
    k0=0,
    k1=1,
)

In [41]:
# Animate the observations. The returned object is expected to expose
# `to_html5_video()`; its HTML is the cell's displayed output.
f = data.plot_observations(
    frame=3,
    gif=True,
)
HTML(f.to_html5_video())

In [42]:
# Learnable model: a neural flow (net.Flow) combined with a space-time kernel.
flow = net.Flow(L=4)
# NOTE(review): this passes `.kernel`, whereas the data-generating cell passes the
# SpaceTimeKernel object itself — confirm which one GPFlow expects.
spaceTimeKernel = model.SpaceTimeKernel(l0=2, l1=0.1, l2=0.1).kernel
gpFlow = model.GPFlow(spaceTimeKernel, flow)
gpFlow.eval();

# Use the first GPU when one is available. Otherwise fall back to the CPU
# (a hard-coded 'cuda:0' crashes on machines without CUDA).
data.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gpFlow = gpFlow.to(data.device)

# Build every (conditioning, prediction) pair once, already on the target
# device, so the training loop does not re-copy the tensors every epoch.
# i starts at 1, presumably so each prediction step has an earlier frame to
# condition on. `conditining` (sic) is presumably the spelling used in main.model.
precomputed_data = []
for i in range(1, data.n):
    for cell in data.cells:
        TXY0, Z0 = data.conditining(cell, i)
        TXY1, Z1 = data.prediction(cell, i)
        precomputed_data.append(tuple(
            tensor.to(data.device) for tensor in (TXY0, Z0, TXY1, Z1)
        ))

# Separate learning rates: slower for the flow network, faster for the
# kernel and likelihood hyperparameters.
optimizer = torch.optim.Adam([
    {'params': gpFlow.flow.parameters(), 'lr': 0.01},
    {'params': gpFlow.kernel.parameters(), 'lr': 0.1},
    {'params': gpFlow.likelihood.parameters(), 'lr': 0.1},
])
mll = gpytorch.mlls.ExactMarginalLogLikelihood(gpFlow.likelihood, gpFlow)



In [None]:
N_EPOCHS = 100

if not precomputed_data:
    # Without any pairs, `loss` would stay a float and loss.backward() would fail.
    raise ValueError("precomputed_data is empty: no (conditioning, prediction) pairs to train on")

# Exact (non-approximate) linear algebra. detach_test_caches(False) keeps the
# prediction caches in the autograd graph, so gradients can flow through the
# posterior built from (TXY0, Z0).
with gpytorch.settings.detach_test_caches(state=False), \
        gpytorch.settings.cholesky_max_tries(7), \
        gpytorch.settings.fast_computations(log_prob=False,
                                            covar_root_decomposition=False,
                                            solves=False):

    n_pairs = len(precomputed_data)
    for epoch in range(1, N_EPOCHS + 1):
        optimizer.zero_grad()
        gpFlow.train()
        gpFlow.likelihood.train()

        # Average negative MLL of each target Z1 under the GP conditioned on (TXY0, Z0).
        loss = 0.0
        for TXY0, Z0, TXY1, Z1 in precomputed_data:
            # .to() is a no-op when the tensors already live on data.device.
            TXY0, Z0 = TXY0.to(data.device), Z0.to(data.device)
            TXY1, Z1 = TXY1.to(data.device), Z1.to(data.device)
            gpFlow.set_train_data(TXY0, Z0, strict=False)
            # In eval mode a GPyTorch ExactGP returns the posterior given the
            # training data that was just set.
            gpFlow.eval()
            gpFlow.likelihood.eval()
            loss += -mll(gpFlow(TXY1), Z1) / n_pairs

        loss.backward()
        optimizer.step()

        # `loss` is the negative marginal log-likelihood, so lower is better.
        print(f"Epoch: {epoch} - Loss (neg. log-likelihood): {loss.item():.3f}")


Epoch: 1 - Likelihood: 0.899
Epoch: 2 - Likelihood: 0.848
Epoch: 3 - Likelihood: 0.792
Epoch: 4 - Likelihood: 0.744
Epoch: 5 - Likelihood: 0.703
Epoch: 6 - Likelihood: 0.661
Epoch: 7 - Likelihood: 0.623
Epoch: 8 - Likelihood: 0.583
Epoch: 9 - Likelihood: 0.543
Epoch: 10 - Likelihood: 0.505
Epoch: 11 - Likelihood: 0.465
Epoch: 12 - Likelihood: 0.423
Epoch: 13 - Likelihood: 0.379
Epoch: 14 - Likelihood: 0.338
Epoch: 15 - Likelihood: 0.295
Epoch: 16 - Likelihood: 0.254
Epoch: 17 - Likelihood: 0.212
Epoch: 18 - Likelihood: 0.170
Epoch: 19 - Likelihood: 0.125
Epoch: 20 - Likelihood: 0.082
Epoch: 21 - Likelihood: 0.039
Epoch: 22 - Likelihood: -0.006
Epoch: 23 - Likelihood: -0.051
Epoch: 24 - Likelihood: -0.096
Epoch: 25 - Likelihood: -0.141
Epoch: 26 - Likelihood: -0.188
Epoch: 27 - Likelihood: -0.235
Epoch: 28 - Likelihood: -0.280
Epoch: 29 - Likelihood: -0.318
Epoch: 30 - Likelihood: -0.347
Epoch: 31 - Likelihood: -0.404
Epoch: 32 - Likelihood: -0.449
Epoch: 33 - Likelihood: -0.488
Epoch: 

In [44]:
# Give `data` a CPU copy of the trained flow for plotting. nn.Module.cpu() moves
# a module in place, so calling it on gpFlow.flow directly would leave gpFlow
# split across devices (kernel/likelihood on data.device, flow on the CPU).
# Deep-copying first keeps gpFlow usable for further training or evaluation.
data.flow = copy.deepcopy(gpFlow.flow).cpu()

In [47]:
# Fixed-seed local generator, so the random index ordering passed to plot_vel
# is reproducible without disturbing torch's global RNG state.
perm_generator = torch.Generator().manual_seed(0)
indices = torch.randperm(data.m, generator=perm_generator)
f = data.plot_vel(indices, frame=0, scale=8, color="blue", gif=True)
HTML(f.to_html5_video())