In [598]:
%load_ext autoreload
%autoreload 2

import math

import gpytorch
import torch
from torchdiffeq import odeint

import model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [654]:
# Spatial axes on [-1, 1]. y runs from +1 down to -1 (presumably so that row 0
# of the image-style arrays built later corresponds to the top edge, y = 1).
x = torch.linspace(start=-1, end=1, steps=30)
y = torch.linspace(start=1, end=-1, steps=30)
# Five observation times evenly spaced on [0, 0.25].
t = torch.linspace(start=0, end=0.25, steps=5)

In [655]:
# Space-time lattice: T, Y and X all have shape (len(t), len(y), len(x)).
T, Y, X = torch.meshgrid(t, y, x, indexing='ij')
# The last axis of `coords` is ordered (t, x, y) -- note x comes before y.
coords = torch.stack((T, X, Y), dim=-1)
# One row per space-time node, time-major, following coords' element order.
TXY = torch.flatten(coords, start_dim=0, end_dim=-2)


In [656]:
def rot_vel(t, xy, A=1.0, omega=2 * math.pi, epsilon=0.25, time_dependent=False):
    """Velocity of a double-gyre-type cellular flow at positions ``xy``.

    With f(x, t) = a(t) x**2 + b(t) x, a(t) = epsilon * sin(omega t) and
    b(t) = 1 - 2 a(t), the velocity is

        u = -pi A sin(pi f) cos(pi y)
        v =  pi A cos(pi f) sin(pi y) df/dx

    Args:
        t: time. Only used when ``time_dependent`` is True; may be a float or a
            0-d / 1-element tensor (the notebook and odeint pass tensors).
        xy: tensor whose last axis holds (x, y).
        A: velocity amplitude.
        omega: angular frequency of a(t) (time-dependent mode only).
        epsilon: amplitude of a(t) (time-dependent mode only).
        time_dependent: if False (default) the field is frozen at t = 0
            (a = 0, b = 1, so f = x). This is a steady flow in which ``t``,
            ``omega`` and ``epsilon`` have no effect, and it is what this
            notebook has always used.

    Returns:
        Tensor of shape ``xy.shape[:-1] + (2,)`` whose last axis is (u, v).
    """
    x = xy[..., 0]
    y = xy[..., 1]

    if time_dependent:
        t_tensor = torch.as_tensor(t, dtype=xy.dtype, device=xy.device)
        a = epsilon * torch.sin(omega * t_tensor)
    else:
        # Frozen at t = 0: sin(0) = 0, so a = 0 and b = 1.
        a = 0.0
    b = 1 - 2 * a

    f = a * x**2 + b * x
    df_dx = 2 * a * x + b

    # Exact pi: with the previously hard-coded 3.14159, u at the walls x = +-1
    # was of order 1e-5 instead of ~0, i.e. the flow leaked through the boundary.
    u = -math.pi * A * torch.sin(math.pi * f) * torch.cos(math.pi * y)
    v = math.pi * A * torch.cos(math.pi * f) * torch.sin(math.pi * y) * df_dx

    return torch.stack([u, v], dim=-1)


def rot_flow(t, xy, t_end=1e-6):
    """Integrate points ``xy`` through ``rot_vel`` from time ``t`` to ``t_end``.

    For every t > t_end (all times this notebook uses) this traces particles
    backwards in time: the result is where the particles found at ``xy`` at
    time ``t`` were located at time ``t_end``.

    Args:
        t: 0-d or 1-element tensor holding the start time.
        xy: tensor of positions whose last axis is (x, y), as read by rot_vel.
        t_end: final integration time. Defaults to 1e-6 rather than 0,
            presumably so the time grid stays strictly monotone if t == 0 --
            TODO confirm.

    Returns:
        The last time slice of the odeint solution, i.e. positions at t_end.
    """
    t = t.reshape(-1)
    # Build the end time with t's dtype/device so torch.cat never mixes them.
    t_final = torch.full((1,), t_end, dtype=t.dtype, device=t.device)
    times = torch.cat([t, t_final])
    return odeint(rot_vel, xy, times, method='dopri5')[-1]

In [657]:
# Sorted unique observation times, and the flattened spatial grid whose row
# order (y-major, then x) matches each time slice of TXY.
T = torch.unique(TXY[:, 0])
XY = torch.stack(torch.meshgrid(x, y, indexing='xy'), dim=-1).reshape(-1, 2)

# For each later time, trace the grid back towards t = 0 with `rot_flow`
# (defined in the previous cell; the identical redefinition that used to live
# here has been removed). Sampling one spatial GP at these back-traced
# locations gives a field transported by the flow.
# The loop variable is `t_k`, not `t`, so the time grid `t` from the first
# cell is not overwritten and the meshgrid cell can be re-run.
points = torch.cat(
    [XY] + [rot_flow(t_k.unsqueeze(0), XY) for t_k in T[1:]],
    dim=0,
)


In [658]:
# Zero-mean GP prior with a Matern-3/2 kernel (lengthscale fixed at 0.5),
# evaluated at every sample location in `points`.
ker = gpytorch.kernels.MaternKernel(nu=1.5)
ker.lengthscale = 0.5
m = torch.zeros(points.size(0))
C = ker(points, points)

In [659]:
# Seed the RNG so the synthetic observations (and everything fitted to them)
# are reproducible on Restart & Run All.
torch.manual_seed(0)
Z = gpytorch.distributions.MultivariateNormal(m, C).sample()

In [660]:
# Back onto the (time, y, x) lattice -- every axis of coords except the last.
obs = Z.reshape(coords.shape[:-1])

In [661]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

# Animate the sampled field obs[k] over the time index k.
# Colour limits are fixed over ALL frames. Re-drawing every frame with its own
# autoscaled limits, while the colourbar kept frame 0's scale, made the
# colourbar wrong for every later frame.
vmin, vmax = float(obs.min()), float(obs.max())

fig, ax = plt.subplots(figsize=(9, 6))
mesh = ax.imshow(
    obs[0, :, :],
    cmap="viridis",
    vmin=vmin,
    vmax=vmax,
)

cbar = fig.colorbar(mesh, ax=ax, orientation='vertical', pad=0.02, aspect=16, shrink=0.8)
cbar.set_label('Sampled GP field value', fontsize=12)


def update_contour(frame):
    """Show time slice ``frame`` of ``obs`` in the existing image artist."""
    mesh.set_data(obs[frame, :, :])
    ax.set_title(f"observations at time point {frame}", fontsize=14)
    return mesh


ani = animation.FuncAnimation(
    fig,
    update_contour,
    frames=obs.shape[0],
    interval=100,
    blit=False,
)

plt.close(fig)

HTML(ani.to_jshtml())

In [662]:
# One row per time slice: (n_t, n_y * n_x).
Z_tensor = torch.flatten(obs, start_dim=1)

In [679]:
import net

# Flow and velocity networks handed to the GP model. They are constructed in
# this order (flow first) so that any random initialisation happens in the
# same sequence. NOTE(review): the meaning of `L` is defined in `net` -- confirm.
flow = net.Flow(L=10)
vel = net.Vel(L=8)
gp_flow = model.GP_FLOW(T, XY, Z_tensor, TXY, Z, flow, vel)

In [680]:
import gc

# Full collection (generation 2) before training; the number of unreachable
# objects found is the cell's displayed output.
gc.collect(2)


6352

In [681]:
# Release memory cached by PyTorch's CUDA allocator that is not currently in use.
torch.cuda.empty_cache()

In [682]:
import optimize

# Initial fit of gp_flow for 400 epochs with GPyTorch's fast (approximate)
# log-prob computations switched off.
# NOTE(review): what `optimize.initial` minimises is defined in the project
# `optimize` module; the saved log below stops at epoch 35 of 400, so the
# recorded run appears to have been interrupted.
with gpytorch.settings.fast_computations(log_prob=False):
    optimize.initial(gp_flow, num_epochs=400)

Epoch: 1 - Likelihood: 0.94
Epoch: 2 - Likelihood: 0.87
Epoch: 3 - Likelihood: 0.82
Epoch: 4 - Likelihood: 0.79
Epoch: 5 - Likelihood: 0.76
Epoch: 6 - Likelihood: 0.73
Epoch: 7 - Likelihood: 0.69
Epoch: 8 - Likelihood: 0.65
Epoch: 9 - Likelihood: 0.62
Epoch: 10 - Likelihood: 0.59
Epoch: 11 - Likelihood: 0.56
Epoch: 12 - Likelihood: 0.52
Epoch: 13 - Likelihood: 0.49
Epoch: 14 - Likelihood: 0.46
Epoch: 15 - Likelihood: 0.43
Epoch: 16 - Likelihood: 0.40
Epoch: 17 - Likelihood: 0.37
Epoch: 18 - Likelihood: 0.35
Epoch: 19 - Likelihood: 0.33
Epoch: 20 - Likelihood: 0.31
Epoch: 21 - Likelihood: 0.28
Epoch: 22 - Likelihood: 0.27
Epoch: 23 - Likelihood: 0.25
Epoch: 24 - Likelihood: 0.23
Epoch: 25 - Likelihood: 0.22
Epoch: 26 - Likelihood: 0.21
Epoch: 27 - Likelihood: 0.20
Epoch: 28 - Likelihood: 0.18
Epoch: 29 - Likelihood: 0.17
Epoch: 30 - Likelihood: 0.16
Epoch: 31 - Likelihood: 0.14
Epoch: 32 - Likelihood: 0.14
Epoch: 33 - Likelihood: 0.12
Epoch: 34 - Likelihood: 0.10
Epoch: 35 - Likelihood:

In [683]:
import penalty

# Learned velocity at every space-time node, reshaped onto the
# (n_t, n_y, n_x) lattice of `coords` instead of a hard-coded (5, 30, 30).
# NOTE(review): assumes gp_flow.vel returns one 2-vector (u, v) per row of TXY.
vels = gp_flow.vel(TXY).reshape(*coords.shape[:3], 2).detach()


In [687]:
# True velocity on the spatial grid at every time in T, stacked time-major to
# match TXY. The comprehension variable `t_k` keeps the global time grid `t`
# from being overwritten, and a single cat avoids re-copying in a loop.
true_vels = torch.cat(
    [rot_vel(T[0], XY)] + [rot_vel(t_k.unsqueeze(0), XY) for t_k in T[1:]],
    dim=0,
)

In [688]:
# Onto the same (n_t, n_y, n_x, 2) lattice as `vels`, derived from `coords`
# rather than hard-coding (5, 30, 30).
true_vels = true_vels.reshape(*coords.shape[:3], 2)

In [691]:
fig, ax = plt.subplots(figsize=(9, 6))
frame = 0  # time index drawn before the animation starts
s = 1      # plot every s-th grid node along each axis

# Spatial node positions: coords[..., 1] is x and coords[..., 2] is y.
grid_x = coords[0, ::s, ::s, 1]
grid_y = coords[0, ::s, ::s, 2]

# Both fields share ONE arrow scale in data units, so red and blue arrow
# lengths are directly comparable (two independently autoscaled quivers are
# not). The longest arrow spans roughly one node spacing.
# NOTE(review): assumes the last axis of vels / true_vels is (u, v).
max_speed = max(float(vels.norm(dim=-1).max()), float(true_vels.norm(dim=-1).max()))
node_spacing = float((grid_x[0, 1] - grid_x[0, 0]).abs())
arrow_scale = (max_speed / node_spacing) or 1.0  # fall back if every speed is 0
quiver_kw = dict(angles="xy", scale_units="xy", scale=arrow_scale)

# Learned velocities (red) and true velocities (blue). Q2 used to exist only
# inside a string literal, so update_quiver raised NameError on a fresh kernel.
Q1 = ax.quiver(
    grid_x, grid_y,
    vels[frame, ::s, ::s, 0], vels[frame, ::s, ::s, 1],
    color="red", **quiver_kw)
Q2 = ax.quiver(
    grid_x, grid_y,
    true_vels[frame, ::s, ::s, 0], true_vels[frame, ::s, ::s, 1],
    color="blue", **quiver_kw)
ax.set_xlabel("x")
ax.set_ylabel("y")


def update_quiver(frame):
    """Show time slice ``frame`` of both the learned (Q1) and true (Q2) fields."""
    Q1.set_UVC(vels[frame, ::s, ::s, 0], vels[frame, ::s, ::s, 1])
    Q2.set_UVC(true_vels[frame, ::s, ::s, 0], true_vels[frame, ::s, ::s, 1])
    ax.set_title(f"Velocities at time point {frame} (red: learned, blue: true)")
    return Q1, Q2


ani = animation.FuncAnimation(
    fig,
    update_quiver,
    frames=vels.shape[0],
)

plt.close(fig)

HTML(ani.to_jshtml())