In [76]:
%load_ext autoreload
%autoreload 2

import torch
import gpytorch
import model

from torchdiffeq import odeint

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [77]:
# Spatial axes on [-1, 1], 25 points each. y runs from 1 down to -1 so that, with
# imshow's default origin="upper", row 0 (y = 1) is drawn at the top.
x = torch.linspace(start=-1, end=1, steps=25)
y = torch.linspace(start=1, end=-1, steps=25)
# Ten equally spaced time points on [0, 1].
t = torch.linspace(start=0, end=1, steps=10)

In [78]:
# 'ij' indexing yields arrays of shape (n_t, n_y, n_x); stacking in (T, X, Y)
# order makes coords[i, j, k] == (t[i], x[k], y[j]).
T, Y, X = torch.meshgrid(t, y, x, indexing="ij")
coords = torch.stack((T, X, Y), dim=-1)
# One (t, x, y) row per space-time point: time varies slowest, then y, then x.
TXY = coords.flatten(start_dim=0, end_dim=-2)


In [79]:
def rot_vel(t, xy):
    """Rotational velocity field (dx/dt, dy/dt) = (y, -x).

    Clockwise rotation about the origin with unit angular speed. `t` is unused
    (the field is autonomous) but is required by odeint's `func(t, y)` signature.
    `xy` holds (x, y) on its last axis; the result has the same shape as `xy`.
    """
    px = xy[..., 0]
    py = xy[..., 1]
    return torch.stack([py, -px], dim=-1)


def rot_flow(t, xy):
    """Map points observed at time `t` back to their positions at time 0.

    Integrates `rot_vel` backward from `t` to exactly 0 with dopri5.

    t  : start time, a Python float, 0-d tensor or 1-element tensor.
    xy : tensor of points with (x, y) on the last axis.
    Returns a tensor shaped like `xy`.
    """
    t = torch.as_tensor(t, device=xy.device).reshape(-1)
    if bool((t == 0).all()):
        # Already at time 0. odeint would reject the non-monotone grid [0, 0],
        # so return the points unchanged (as a copy, like the ODE path).
        return xy.clone()
    # Decreasing grid [t, 0]: odeint accepts strictly decreasing time grids and
    # integrates backward. The end time is exactly 0 (previously 1e-6).
    t_span = torch.cat([t, t.new_zeros(1)])
    return odeint(rot_vel, xy, t_span, method='dopri5')[-1]

In [80]:
# Frame times: the distinct values of the time column (i.e. `t`). Note this
# rebinds `T`, which held the 3-D time meshgrid in an earlier cell.
T = torch.unique(TXY[:, 0])
# Spatial grid as (x, y) rows, taken from frame 0 of `coords` so it has the same
# y-outer / x-inner order that TXY uses within each frame.
XY = coords[0, ..., 1:].reshape(-1, 2)

# Pull every grid point back to where it was at t = 0 along the rotation, using
# `rot_flow` as defined in the earlier cell. The first frame is T[0] == 0 and
# needs no integration. The comprehension keeps the loop variable local, so the
# global time axis `t` is not overwritten. A single torch.cat avoids re-copying
# `points` on every step.
pulled_back = [XY] + [rot_flow(t_k.unsqueeze(0), XY) for t_k in T[1:]]
points = torch.cat(pulled_back, dim=0)
assert points.shape == (TXY.shape[0], 2), "points must align row-for-row with TXY"


In [81]:
# Matérn-1/2 (exponential) covariance evaluated on the pulled-back positions
# `points`, so two observations are correlated according to where their points
# were at t = 0.
ker = gpytorch.kernels.MaternKernel(nu=0.5)
m = points.new_zeros(points.shape[0])  # zero prior mean, one entry per point
C = ker(points, points)

In [82]:
# Seed the global torch RNG so the simulated observations are reproducible under
# Restart & Run All. Later cells that draw from the global torch RNG are seeded
# by this call as well.
torch.manual_seed(0)
# One joint draw of the field at every row of `points` (i.e. every row of TXY).
Z = gpytorch.distributions.MultivariateNormal(m, C).sample()

In [83]:
# Back to grid form (n_t, n_y, n_x): one image per time point. This works because
# the rows of Z follow the same (t, y, x) order as TXY.
obs = Z.reshape(*coords.shape[:-1])

In [84]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

# Animate the simulated field frame by frame. A single image is updated in place,
# with the colour range fixed over all frames, so the colorbar is valid for every
# frame. (Re-creating the image per frame re-autoscaled each frame and left the
# colorbar attached to the first, removed image.)
obs_frames = obs.detach().cpu().numpy()
obs_vmin, obs_vmax = float(obs_frames.min()), float(obs_frames.max())

fig, ax = plt.subplots(figsize=(9, 6))
mesh = ax.imshow(obs_frames[0], cmap="viridis", vmin=obs_vmin, vmax=obs_vmax)
ax.set_title("observations at time point 0", fontsize=14)

cbar = fig.colorbar(mesh, ax=ax, orientation='vertical', pad=0.02, aspect=16, shrink=0.8)
cbar.set_label('Simulated GP field value', fontsize=12)


def update_contour(frame):
    """Show frame `frame` of `obs` by swapping the image data in place."""
    mesh.set_data(obs_frames[frame])
    ax.set_title(f"observations at time point {frame}", fontsize=14)
    return (mesh,)


ani = animation.FuncAnimation(
    fig,
    update_contour,
    frames=obs_frames.shape[0],
    interval=100,
    blit=False,
)

plt.close(fig)

HTML(ani.to_jshtml())

In [85]:
import net
# Build the flow network and wrap it in a GP model over the space-time inputs TXY
# and the simulated observations Z (whose rows are aligned with TXY).
# NOTE(review): the meaning of L=4 is defined in net.Flow — confirm there.
flow = net.Flow(L = 4)
gp_flow = model.GP_FLOW(TXY, Z, flow)

In [86]:
import optimize
# Fit gp_flow for 75 epochs; progress is logged by optimize.flow every 5 epochs.
# NOTE(review): the logged "Likelihood" decreases over training, so it is
# presumably a loss (e.g. a negative log-likelihood) — confirm in optimize.py.
optimize.flow(gp_flow, num_epochs=75)

Epoch: 5 - Likelihood: 0.753 - Learning Rates: [0.001]
Epoch: 10 - Likelihood: 0.574 - Learning Rates: [0.001]
Epoch: 15 - Likelihood: 0.416 - Learning Rates: [0.001]
Epoch: 20 - Likelihood: 0.282 - Learning Rates: [0.001]
Epoch: 25 - Likelihood: 0.128 - Learning Rates: [0.001]
Epoch: 30 - Likelihood: -0.100 - Learning Rates: [0.001]
Epoch: 35 - Likelihood: -0.268 - Learning Rates: [0.001]
Epoch: 40 - Likelihood: -0.424 - Learning Rates: [0.001]
Epoch: 45 - Likelihood: -0.591 - Learning Rates: [0.001]
Epoch: 50 - Likelihood: -0.744 - Learning Rates: [0.001]
Epoch: 55 - Likelihood: -0.907 - Learning Rates: [0.001]
Epoch: 60 - Likelihood: -1.047 - Learning Rates: [0.001]
Epoch: 65 - Likelihood: -1.248 - Learning Rates: [0.001]
Epoch: 70 - Likelihood: -1.439 - Learning Rates: [0.001]
Epoch: 75 - Likelihood: -1.580 - Learning Rates: [0.001]


In [87]:
import penalty
# Velocities implied by the trained flow, evaluated at every (t, x, y) row of TXY.
# NOTE(review): v_hat's output layout is defined in penalty.py; the next cell
# assumes one 2-vector per row of TXY.
vels = penalty.v_hat(TXY, gp_flow.flow)

In [88]:
# Grid form (n_t, n_y, n_x, 2), with the grid shape taken from `coords` rather than
# hard-coded. Detached so the tensor can be handed to matplotlib. Re-running is
# safe because reshaping to the same shape is a no-op.
vels = vels.reshape(*coords.shape[:-1], 2).detach()

In [89]:
# Animate the estimated velocity field, one quiver frame per time point.
fig, ax = plt.subplots(figsize=(9, 6))
Q = ax.quiver(
    coords[0, :, :, 1],  # x coordinates of the grid
    coords[0, :, :, 2],  # y coordinates of the grid
    vels[0, :, :, 0],
    vels[0, :, :, 1],
)
ax.set_title("Estimated velocities at time point 0")


def update_quiver(frame):
    """Replace the arrow components with the velocities of frame `frame`."""
    Q.set_UVC(vels[frame, :, :, 0], vels[frame, :, :, 1])
    ax.set_title(f"Estimated velocities at time point {frame}")
    return (Q,)


ani = animation.FuncAnimation(
    fig,
    update_quiver,
    frames=vels.shape[0],  # every time point (was hard-coded to the first 3)
)

plt.close(fig)

HTML(ani.to_jshtml())