In [53]:
%load_ext autoreload
%autoreload 2

import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [54]:
# Discretisation of the space-time domain [-1, 1]^2 x [0, 1].
n = 5        # number of time points
n_grid = 25  # grid points per spatial axis
x = torch.linspace(-1, 1, n_grid)  # increasing left -> right
# Decreasing, so row 0 of a grid built from y corresponds to y = +1
# (presumably so imshow's default origin='upper' shows +y at the top).
y = torch.linspace(1, -1, n_grid)
t = torch.linspace(0, 1, n)

In [55]:
# Full space-time grid with 'ij' indexing over (t, y, x):
# coords[k, i, j] = (x_j, y_i, t_k), shape (len(t), len(y), len(x), 3).
T, Y, X = torch.meshgrid(t, y, x, indexing="ij")
coords = torch.stack((X, Y, T), dim=-1)

# Flattened list of (x, y, t) query points, one row per grid node.
spacetime = coords.reshape(-1, 3)

In [56]:
def rot_flow(xyt):
    """Rotate the spatial part of each space-time point by its own time value.

    A point (x1, x2, t) is mapped to R(t) @ (x1, x2), where R(t) is the
    counter-clockwise rotation by angle t (radians):

        x1' = cos(t) * x1 - sin(t) * x2
        x2' = sin(t) * x1 + cos(t) * x2

    Args:
        xyt: tensor whose last dimension holds (x1, x2, t).

    Returns:
        Tensor with the same leading dimensions as ``xyt`` and a last
        dimension of size 2 holding the rotated (x1', x2').
    """
    px, py, angle = xyt[..., 0], xyt[..., 1], xyt[..., 2]
    c = torch.cos(angle)
    s = torch.sin(angle)
    return torch.stack([c * px - s * py, s * px + c * py], dim=-1)

In [57]:
import model, gpytorch
# Carry every space-time node to its rotated spatial position. Only the 2-D
# output of rot_flow reaches the kernel, so covariances depend on where the
# nodes land after rotation, not on their time stamps directly.
points = rot_flow(spacetime)

# Matern-1/2 (exponential) kernel with gpytorch's default hyper-parameters.
ker = gpytorch.kernels.MaternKernel(nu=0.5)

# Zero prior mean and the prior covariance between all rotated nodes.
# NOTE(review): nodes that rotate onto the same location at different times
# (e.g. a node at the origin) give identical rows, so C is singular; sampling
# relies on gpytorch's jitter handling — confirm this is intended.
m = torch.zeros(len(spacetime))
C = ker(points, points)

In [58]:
from gpytorch.distributions import MultivariateNormal
# One joint draw of the GP prior at every node, in the flattened order of
# `spacetime`; this synthetic field serves as the observed data below.
z = MultivariateNormal(mean=m, covariance_matrix=C).sample()

In [59]:
# Restore the grid layout: obs[k, i, j] is the sample at (x_j, y_i, t_k).
obs = z.reshape(*coords.shape[:-1])

In [60]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

fig, ax = plt.subplots(figsize=(9, 6))

# Fix the colour limits to the range over *all* frames so one colour bar is
# valid for the whole animation (imshow would otherwise autoscale to frame 0).
obs_min, obs_max = obs.min().item(), obs.max().item()

mesh = ax.imshow(obs[0, :, :], cmap="viridis", vmin=obs_min, vmax=obs_max)
ax.set_title("observations at time point 0", fontsize=14)

cbar = fig.colorbar(mesh, ax=ax, orientation="vertical", pad=0.02, aspect=16, shrink=0.8)
# The field is a synthetic GP sample, not a physical measurement.
cbar.set_label("GP sample value", fontsize=12)


def update_image(frame):
    """Show observation frame `frame` in the existing image artist.

    Updating the data in place keeps the same artist, so the colour bar and the
    fixed colour limits stay valid, instead of creating a new autoscaled image
    each frame.
    """
    mesh.set_data(obs[frame, :, :])
    ax.set_title(f"observations at time point {frame}", fontsize=14)
    return (mesh,)


ani = animation.FuncAnimation(
    fig,
    update_image,
    frames=t.shape[0],
    interval=100,
    blit=False,
)

plt.close(fig)

HTML(ani.to_jshtml())

In [98]:
import net
# Learnable components from the project module `net`, combined into the
# GP-flow model from `model`. The tuple is evaluated left to right, so Flow is
# still constructed (and initialised) before Vel.
flow, vel = net.Flow(L=3), net.Vel(L=2)
gp_flow = model.GP_FLOW(flow, vel)

In [99]:
import optimize
# Train `gp_flow` on the sampled observations `z` at the `spacetime` nodes; the
# training loop lives in the project module `optimize` (not shown here).
# NOTE(review): the logged "Likelihood" decreases over epochs, so it is
# presumably a loss (e.g. negative log-likelihood) — confirm in `optimize`.
# NOTE(review): the saved output stops mid-line at epoch 80 of 150, so the run
# appears to have been interrupted; re-run before trusting later cells.
optimize.flow(spacetime, z, gp_flow, num_epochs=150)

Epoch: 5 - Likelihood: 0.812 - Learning Rates: [0.001, 0.001]
Epoch: 10 - Likelihood: 0.614 - Learning Rates: [0.001, 0.001]
Epoch: 15 - Likelihood: 0.441 - Learning Rates: [0.001, 0.001]
Epoch: 20 - Likelihood: 0.268 - Learning Rates: [0.001, 0.001]
Epoch: 25 - Likelihood: 0.158 - Learning Rates: [0.001, 0.001]
Epoch: 30 - Likelihood: 0.082 - Learning Rates: [0.001, 0.001]
Epoch: 35 - Likelihood: 0.014 - Learning Rates: [0.001, 0.001]
Epoch: 40 - Likelihood: -0.067 - Learning Rates: [0.001, 0.001]
Epoch: 45 - Likelihood: -0.155 - Learning Rates: [0.001, 0.001]
Epoch: 50 - Likelihood: -0.294 - Learning Rates: [0.001, 0.001]
Epoch: 55 - Likelihood: -0.451 - Learning Rates: [0.001, 0.001]
Epoch: 60 - Likelihood: -0.598 - Learning Rates: [0.001, 0.001]
Epoch: 65 - Likelihood: -0.797 - Learning Rates: [0.001, 0.001]
Epoch: 70 - Likelihood: -0.975 - Learning Rates: [0.001, 0.001]
Epoch: 75 - Likelihood: -1.160 - Learning Rates: [0.001, 0.001]
Epoch: 80 - Likelihood: -1.337 - Learning Rates:

In [100]:
# Push every grid node through the learned flow and restore the grid layout:
# flows[k, i, j] is the flow output at (x_j, y_i, t_k). The result is only
# plotted, so it is computed without autograd rather than built then detached.
# NOTE(review): the velocity cell calls gp_flow.flow.eval() first, but this cell
# runs in whatever mode training left the network in; if it has mode-dependent
# layers (dropout/batch-norm) the two plots may disagree — consider eval() here.
with torch.no_grad():
    flows = gp_flow.flow(coords.reshape(-1, 3)).reshape(*coords.shape[:-1], 2)

In [101]:
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_xlim(-2.5, 2.5)
ax.set_ylim(-2.5, 2.5)

frame = 0
mesh = ax.scatter(
    flows[frame, :, :, 0],
    flows[frame, :, :, 1],
    s=10,
    color="red",
)
ax.set_title(f"learned flow of the grid at time point {frame}", fontsize=14)


def update_scatter(frame):
    """Move the existing scatter points to the flowed grid of time point `frame`.

    Updating offsets in place reuses one artist instead of removing and
    re-creating a scatter on every frame. Point order matches the original
    scatter call (row-major over the grid).
    """
    # set_offsets expects an (N, 2) array of (x, y) positions.
    mesh.set_offsets(flows[frame].reshape(-1, 2).numpy())
    ax.set_title(f"learned flow of the grid at time point {frame}", fontsize=14)
    return (mesh,)


ani = animation.FuncAnimation(
    fig,
    update_scatter,
    frames=t.shape[0],
    interval=100,
    blit=False,
)

plt.close(fig)

HTML(ani.to_jshtml())

In [102]:
def D_phi(xyt, flow):
    """Per-point Jacobian of ``flow`` with respect to its input.

    ``flow`` is differentiated with reverse-mode autodiff
    (``torch.func.jacrev``), and the result is vectorised over the leading
    dimension of ``xyt`` with ``torch.vmap``. ``flow`` therefore sees single
    points of shape ``xyt.shape[1:]``.

    Returns:
        Tensor of shape ``(xyt.shape[0], *out_shape, *xyt.shape[1:])``, where
        ``out_shape`` is the shape of ``flow`` applied to one point.
    """
    per_point_jacobian = torch.func.jacrev(flow)
    return torch.vmap(per_point_jacobian)(xyt)

def v_hat(xyt, flow=None):
    """Velocity field implied by a space-time flow.

    If ``phi(x, t)`` stays constant along a trajectory ``x(t)``, then
    differentiating gives ``J_x @ dx/dt + J_t = 0``, hence
    ``dx/dt = -J_x^{-1} J_t``. Here ``J_x`` is the Jacobian with respect to
    the two spatial inputs and ``J_t`` is the derivative with respect to time.

    Args:
        xyt: (N, 3) tensor of (x, y, t) points.
        flow: map from one (x, y, t) point to a 2-vector. Defaults to the
            global ``gp_flow.flow``, looked up at call time (previous behaviour).

    Returns:
        (N, 2) tensor of velocities at the given points.
    """
    if flow is None:
        flow = gp_flow.flow
    jacobians = D_phi(xyt, flow)       # (N, 2, 3)
    jacobians_x = jacobians[..., 0:2]  # (N, 2, 2): spatial block
    jacobians_t = jacobians[..., 2]    # (N, 2): time derivative
    # Batched solve of (-J_x) v = J_t, i.e. v = -J_x^{-1} J_t, without forming the inverse.
    return torch.linalg.solve(-jacobians_x, jacobians_t)
# Put the flow network in eval mode, then evaluate the velocity field at every
# grid node and restore the grid layout: vels[k, i, j] is the velocity at
# (x_j, y_i, t_k). The grid shape comes from `coords` instead of literal sizes.
gp_flow.flow.eval()
vels = v_hat(coords.reshape(-1, 3)).reshape(*coords.shape[:-1], 2).detach()

In [103]:
fig, ax = plt.subplots(figsize=(9, 6))

# vels[k, i, j] was evaluated at the physical grid node (x_j, y_i) at time t_k,
# so the arrows are anchored on that grid (identical for every frame). flows[0]
# lives in the flow's output space and need not coincide with those nodes.
frame = 0
Q = ax.quiver(
    coords[0, :, :, 0],
    coords[0, :, :, 1],
    vels[frame, :, :, 0],
    vels[frame, :, :, 1],
)
ax.set_title(f"estimated velocity field at time point {frame}")


def update_quiver(frame):
    """Update the arrows in place with the velocities of time point `frame`."""
    Q.set_UVC(vels[frame, :, :, 0], vels[frame, :, :, 1])
    ax.set_title(f"estimated velocity field at time point {frame}")
    return (Q,)


ani = animation.FuncAnimation(
    fig,
    update_quiver,
    frames=n,
)

plt.close(fig)

HTML(ani.to_jshtml())