In [1]:
%load_ext autoreload
%autoreload 2

import s3fs

import numpy as np
import torch
import xarray as xr

import cartopy.crs as ccrs

import plot
import model
import net
import optimize
import generate
import gpytorch

In [2]:
# Anonymous (unsigned) access to the public HRRR Zarr bucket.
fs = s3fs.S3FileSystem(anon=True)
# Grid/chunk index store: only its x/y coordinates are used below.
url = 's3://hrrrzarr/grid/HRRR_chunk_index.zarr'
file = s3fs.S3Map(url, s3=fs)

In [3]:
# Open the chunk-index store and drop every data variable, so that only the
# coordinates remain for the merge further down.
locations = xr.open_dataset(file, engine='zarr')
locations = locations.drop_vars([name for name in locations.data_vars])

In [4]:
# Projection coordinates of the full grid as NumPy arrays.
# NOTE: `x` and `y` are rebound to scaled torch tensors further down.
y, x = locations.y.values, locations.x.values

In [60]:
# Forecast date (YYYYMMDD) interpolated into the S3 store paths below.
date = '20230801'


In [61]:
# get 500 mb temperature (TMP) data from the 00z forecast store for `date`;
# the variable is renamed to `obs` in the next cell
fs = s3fs.S3FileSystem(anon=True)
url = f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/TMP/500mb'
file = s3fs.S3Map(url, s3=fs)

In [62]:
# Load the 500 mb temperature into memory, shorten the projection dimension
# names to x / y and call the variable `obs`.
obs = (
    xr.open_dataset(file, engine='zarr')
    .load()
    .rename({'projection_x_coordinate': 'x', 'projection_y_coordinate': 'y'})
    .rename_vars({'TMP': 'obs'})
)

In [63]:
# get U velocity data (500 mb UGRD from the same 00z forecast store)
fs = s3fs.S3FileSystem(anon=True)
url = 's3://hrrrzarr/sfc/{0}/{0}_00z_fcst.zarr/500mb/UGRD/500mb'.format(date)
file = s3fs.S3Map(url, s3=fs)

In [64]:
# Load U into memory with the same x / y naming as `obs`; variable -> `U`.
U = (
    xr.open_dataset(file, engine='zarr')
    .load()
    .rename({'projection_x_coordinate': 'x', 'projection_y_coordinate': 'y'})
    .rename_vars({'UGRD': 'U'})
)

In [65]:
# get V velocity data (500 mb VGRD from the same 00z forecast store)
fs = s3fs.S3FileSystem(anon=True)
url = 's3://hrrrzarr/sfc/{0}/{0}_00z_fcst.zarr/500mb/VGRD/500mb'.format(date)
file = s3fs.S3Map(url, s3=fs)

In [66]:
# Load V into memory with the same x / y naming as `obs`; variable -> `V`.
V = (
    xr.open_dataset(file, engine='zarr')
    .load()
    .rename({'projection_x_coordinate': 'x', 'projection_y_coordinate': 'y'})
    .rename_vars({'VGRD': 'V'})
)

In [67]:
# merge
data = xr.merge([obs,U,V,locations])

In [68]:
# Lambert conformal projection with the parameters below on a spherical
# Earth (radius 6371229 m); passed to the plot helpers.
_sphere = ccrs.Globe(semimajor_axis=6371229, semiminor_axis=6371229)
Lambert_proj = ccrs.LambertConformal(
    central_longitude=262.5,
    central_latitude=38.5,
    standard_parallels=[38.5, 38.5],
    globe=_sphere,
)

In [69]:
#plot.obs(data, Lambert_proj)

In [70]:
#plot.winds(data, data.U, data.V, Lambert_proj)

In [71]:
# Thin the data: keep every 8th time step and every 30th grid point in y and x.
data_sub = data.isel(
    time=slice(None, None, 8),
    y=slice(None, None, 30),
    x=slice(None, None, 30),
)

In [72]:
#(l,d) = Lambert_proj.transform_point(-95,25,ccrs.Geodetic())
#(r,u) = Lambert_proj.transform_point(-50,35,ccrs.Geodetic())
#data_sub = data_sub.sel(x = slice(l,r), y = slice(d,u))

In [73]:
#plot.obs(data_sub, Lambert_proj)

In [74]:
# Number of time steps left after thinning.
data_sub.sizes['time']

6

In [75]:
# Model inputs: time rescaled to [0, 1]; x / y coordinates scaled by 1e-6;
# the temperature field flattened and standardised (zero mean, unit std).
# The y axis is reversed, and so is axis 1 of the field (assumed to be y,
# i.e. dims (time, y, x) -- TODO confirm), so the two stay aligned.
n = data_sub.sizes['time']
t = torch.linspace(0, 1, n)
x = torch.tensor(data_sub.x.values, dtype=torch.float32) / 1e6
y = torch.flip(torch.tensor(data_sub.y.values, dtype=torch.float32), dims=[0]) / 1e6
z = torch.flip(torch.tensor(data_sub.obs.values, dtype=torch.float32), dims=[1]).flatten()
z = (z - z.mean()) / z.std()

In [76]:
# Space-time grid with axes (t, y, x); each grid point stores (x, y, t).
T, Y, X = torch.meshgrid(t, y, x, indexing='ij')
coords = torch.stack((X, Y, T), dim=-1)
# Flat (N, 3) list of all grid points, used as the model input.
spacetime = coords.reshape(-1, coords.shape[-1])
# NOTE: rebinds `obs` (previously the xarray Dataset) to z laid out on the grid.
obs = z.reshape(coords.shape[:-1])

In [77]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

fig, ax = plt.subplots(figsize=(9, 6))

# One colour scale for every frame: the colorbar below is built once, so the
# image must not be re-autoscaled per frame or the bar would be misleading.
vmin, vmax = float(obs.min()), float(obs.max())

mesh = ax.imshow(
    obs[0, :, :],
    cmap="viridis",
    vmin=vmin,
    vmax=vmax)
ax.set_title("observations at time point 0", fontsize=14)

cbar = plt.colorbar(mesh, ax=ax, orientation='vertical', pad=0.02, aspect=16, shrink=0.8)
# `obs` holds the standardised (z-scored) 500 mb temperature, not dew point.
cbar.set_label('Standardized 500 mb temperature', fontsize=12)


def update_contour(frame):
    """FuncAnimation callback: show time index ``frame`` in the existing image."""
    mesh.set_data(obs[frame, :, :])
    ax.set_title(f"observations at time point {frame}", fontsize=14)
    return mesh


ani = animation.FuncAnimation(
    fig,
    update_contour,
    frames=t.shape[0],
    interval=100,
    blit=False
)

plt.close(fig)

HTML(ani.to_jshtml())

In [78]:
# Build the net.Flow and net.Vel modules and combine them in model.GP_FLOW.
# NOTE(review): no torch seed is set, so initial weights differ between runs.
flow = net.Flow(L=4)
vel = net.Vel(L=3)
gp_flow = model.GP_FLOW(flow, vel)

In [79]:
# Fit gp_flow to the standardised observations z at the flattened space-time
# points for 100 epochs; the training log is printed below.
optimize.flow(spacetime, z, gp_flow, num_epochs=100)

Epoch: 5 - Likelihood: 0.686 - Learning Rates: [0.001, 0.001]
Epoch: 10 - Likelihood: 0.485 - Learning Rates: [0.001, 0.001]
Epoch: 15 - Likelihood: 0.286 - Learning Rates: [0.001, 0.001]
Epoch: 20 - Likelihood: 0.086 - Learning Rates: [0.001, 0.001]
Epoch: 25 - Likelihood: -0.097 - Learning Rates: [0.001, 0.001]
Epoch: 30 - Likelihood: -0.246 - Learning Rates: [0.001, 0.001]
Epoch: 35 - Likelihood: -0.344 - Learning Rates: [0.001, 0.001]
Epoch: 40 - Likelihood: -0.389 - Learning Rates: [0.001, 0.001]
Epoch: 45 - Likelihood: -0.405 - Learning Rates: [0.001, 0.001]
Epoch: 50 - Likelihood: -0.416 - Learning Rates: [0.001, 0.001]
Epoch: 55 - Likelihood: -0.434 - Learning Rates: [0.001, 0.001]
Epoch: 60 - Likelihood: -0.453 - Learning Rates: [0.001, 0.001]
Epoch: 65 - Likelihood: -0.464 - Learning Rates: [0.001, 0.001]
Epoch: 70 - Likelihood: -0.472 - Learning Rates: [0.001, 0.001]
Epoch: 75 - Likelihood: -0.480 - Learning Rates: [0.001, 0.001]
Epoch: 80 - Likelihood: -0.487 - Learning Rat

In [80]:
# Display the thinned dataset for inspection.
data_sub

In [81]:
def D_phi(xyt, flow):
    """Jacobian of ``flow`` at every point of a batch.

    Args:
        xyt: tensor of shape (N, 3) whose rows are (x, y, t) points.
        flow: callable mapping a single (3,) point to its image.

    Returns:
        Tensor of shape (N, out_dim, 3): d flow / d(x, y, t) for each point.
    """
    return torch.vmap(torch.func.jacrev(flow))(xyt)


def v_hat(xyt, flow=None):
    """Velocity field implied by a flow that is constant along trajectories.

    Differentiating phi(x(t), t) = const in time gives J_x v + J_t = 0,
    hence v = -J_x^{-1} J_t.

    Args:
        xyt: tensor of shape (N, 3) whose rows are (x, y, t) points.
        flow: callable (x, y, t) -> 2-D output; defaults to ``gp_flow.flow``.

    Returns:
        Tensor of shape (N, 2) with the (x, y) velocity at each point.
    """
    if flow is None:
        flow = gp_flow.flow
    jac = D_phi(xyt, flow)
    jac_t = jac[..., 2]      # (N, 2): d phi / d t
    jac_x = jac[..., 0:2]    # (N, 2, 2): d phi / d(x, y); must be square to solve
    return torch.linalg.solve(-jac_x, jac_t)
# Learned velocity on the (t, y, x) grid of `coords`, shape (t, y, x, 2).
# The grid shape is taken from `coords`, so it follows whatever subsampling
# was chosen for data_sub.
vels = v_hat(coords.reshape(-1, 3)).reshape(*coords.shape[:-1], 2).detach()

In [82]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

fig, ax = plt.subplots(figsize=(9, 6))

frame = 0
step = 2  # draw every `step`-th arrow along y and x

# Shared colour limits so the field is comparable across frames.
vmin, vmax = float(obs.min()), float(obs.max())


def draw_frame(frame):
    """Draw the observations and learned velocities for time index ``frame``.

    Returns the (mesh, quiver) artists so they can be removed on the next frame.
    """
    field = ax.pcolormesh(
        coords[frame, :, :, 0],
        coords[frame, :, :, 1],
        obs[frame, :, :],
        cmap="coolwarm",
        vmin=vmin,
        vmax=vmax)
    arrows = ax.quiver(
        coords[0, ::step, ::step, 0],
        coords[0, ::step, ::step, 1],
        vels[frame, ::step, ::step, 0],
        vels[frame, ::step, ::step, 1])
    ax.set_title(f"Observations and learned velocities at time point {frame}")
    return field, arrows


mesh, Q = draw_frame(frame)


def update(frame):
    """FuncAnimation callback: replace the previous frame's artists."""
    global mesh, Q
    mesh.remove()
    Q.remove()  # must be called, otherwise old arrows accumulate every frame
    mesh, Q = draw_frame(frame)
    return Q, mesh


ani = animation.FuncAnimation(
    fig,
    update,
    frames=t.shape[0],
    interval=100,
    blit=False
)

plt.close(fig)

HTML(ani.to_jshtml())

In [83]:
fig, ax = plt.subplots(figsize=(9, 6))

frame = 0
step = 2  # draw every `step`-th arrow along y and x

# `coords` was built from a reversed y axis, so reverse the wind arrays along
# y as well; otherwise each arrow is drawn at the mirror-image y position.
# Assumes U / V dims are (time, y, x), as the indexing below already did.
U_plot = data_sub.U.values[:, ::-1, :]
V_plot = data_sub.V.values[:, ::-1, :]

Q = ax.quiver(
    coords[0, ::step, ::step, 0],
    coords[0, ::step, ::step, 1],
    U_plot[frame, ::step, ::step],
    V_plot[frame, ::step, ::step])


def update_quiver(frame):
    """FuncAnimation callback: swap in the winds for time index ``frame``."""
    Q.set_UVC(U_plot[frame, ::step, ::step], V_plot[frame, ::step, ::step])
    # data_sub keeps only every 8th time step, so label with the actual time
    # coordinate rather than treating the frame index as an hour count.
    ax.set_title(f"Wind Velocities at {data_sub.time.values[frame]}")
    return Q,


ani = animation.FuncAnimation(
    fig,
    update_quiver,
    frames=data_sub.sizes['time']
)

plt.close(fig)

HTML(ani.to_jshtml())

In [192]:
# Alternative space-time grid with axes (y, x, t); each point stores (y, x, t).
# NOTE: rebinds n (now the number of time steps minus one), t, x and y from
# earlier cells; unlike the earlier grid, y is not reversed here.
n = data_sub.sizes['time'] - 1
t = torch.linspace(0, n, steps=n + 1) / n

y = torch.tensor(data_sub.y.values, dtype=torch.float32) / 1e6
x = torch.tensor(data_sub.x.values, dtype=torch.float32) / 1e6

Y, X, T = torch.meshgrid(y, x, t, indexing='ij')
SpaceTime = torch.stack((Y, X, T), dim=-1)

In [None]:
# y component (last-axis index 0) of the first 5 x 5 grid points at time index 0.
SpaceTime[:5, :5, 0, 0]

In [193]:
import gpytorch

In [245]:
# Matern-5/2 kernel with gpytorch's default hyperparameters.
ker = gpytorch.kernels.MaternKernel(nu=2.5)

In [None]:
# Kernel value between two single grid points, each passed as a (1, 3) row.
ker(SpaceTime[1, 2, 1].unsqueeze(0), SpaceTime[0, 0, 3].unsqueeze(0)).evaluate()

In [253]:
# Dense kernel (Gram) matrix between every pair of grid points.
# NOTE(review): this is N x N with N = number of grid points, so memory grows
# quadratically with the grid size.
S = ker(
    SpaceTime.flatten(end_dim=-2),
    SpaceTime.flatten(end_dim=-2),
).evaluate()

In [254]:
# View the Gram matrix with one axis per grid dimension on each side:
# SS[i, j, k, i2, j2, k2] = k(SpaceTime[i, j, k], SpaceTime[i2, j2, k2]).
# The grid shape is taken from SpaceTime so it matches the current
# subsampling of data_sub instead of a fixed size.
grid_shape = tuple(SpaceTime.shape[:-1])
SS = S.reshape(*grid_shape, *grid_shape)

In [None]:
# Cross-check: same point pair as the direct `ker(...)` evaluation above.
SS[(1, 2, 1, 0, 0, 3)]

In [None]:
# Kernel between grid points (1, 2) and (0, 0), both at time index 4.
SS[(1, 2, 4, 0, 0, 4)]

In [61]:
# Temperature converted to degC via the metpy accessor, flattened, divided by 100.
# NOTE(review): `.metpy` only exists once metpy's xarray accessor has been
# imported; nothing in this notebook imports metpy directly -- confirm a project
# module pulls it in.
# NOTE(review): reshape(-1) flattens in the DataArray's own dim order
# (presumably time, y, x), whereas SpaceTime above is laid out (y, x, t);
# check which layout the consumer of z expects. This also overwrites the
# standardised z used to train gp_flow.
z = torch.tensor(data_sub.obs.metpy.convert_units('degC').values,dtype=torch.float32).reshape(-1)/100


In [62]:
# Appears to be an earlier experiment (execution counts restart here): a
# GP_FLOW model built from a net.Flow module alone.
# NOTE(review): GP_FLOW is constructed with (flow, vel) earlier in this
# notebook; confirm the current signature still accepts a single argument.
Flow = net.Flow()
Model = model.GP_FLOW(Flow)

In [None]:
# NOTE(review): `time_space_tensor` is not defined anywhere in this notebook
# (hidden kernel state), and the argument order here (values first) differs
# from the earlier optimize.flow(spacetime, z, gp_flow, ...) call. Confirm
# against the current optimize.flow signature before running.
optimize.flow(z, time_space_tensor, Model, num_epochs=30)

In [70]:
# Put the flow in eval mode and evaluate it on the space-time tensor shifted by 0.1.
# NOTE(review): relies on the undefined `time_space_tensor`; uses the attribute
# `Model.Flow` whereas earlier code uses `gp_flow.flow` -- confirm the name;
# `input` shadows the Python builtin.
Model.Flow.eval()
input = Model.Flow(time_space_tensor + 0.1).detach()

In [71]:
# Training pairs for the `ff` network below: inputs are the (t, x, y) values of
# the meshgrid, flattened to (N, 3); targets are channels 1 and 0 (in that
# order) of the first batch element of the flow output, flattened to (N, 2).
# NOTE(review): assumes `input` is laid out (batch, channel, y, x, t) so its
# flattening order matches txy -- confirm. This also rebinds `data`, which
# was the merged xarray Dataset.
txy = torch.stack([T, X, Y], dim=-1).reshape(-1,3)
data = torch.stack([input[0,1,:,:,:],  input[0,0,:,:,:]], dim=-1).reshape(-1,2).detach()

In [None]:
# Display the flow-derived regression targets.
data

In [67]:
import torch.nn as nn
class ff(nn.Module):
    """MLP displacement map (t, x) -> x + t * g(t, x).

    Inputs carry time first, followed by the d spatial coordinates. The output
    is the spatial part displaced by t times an MLP of the full input, so the
    map is the identity at t = 0.

    Args:
        d: spatial dimension.
        L: number of Linear+GELU blocks before the output layer.
        h: hidden width.
    """

    def __init__(self, d=2, L=6, h=32):
        super().__init__()
        layers = [nn.Linear(d + 1, h), nn.GELU()]
        for _ in range(L - 1):
            layers += [nn.Linear(h, h), nn.GELU()]
        layers.append(nn.Linear(h, d))
        self.network = nn.Sequential(*layers)

    def forward(self, tx):
        """Map (..., d+1) inputs (time first) to (..., d) displaced positions.

        Works on a single point (e.g. under torch.vmap) or on a batch of rows.
        """
        t = tx[..., :1]  # keep a trailing axis so it broadcasts over d
        x = tx[..., 1:]
        return x + t * self.network(tx)

In [45]:
# Displacement network with its default size spelled out explicitly.
flowy = ff(d=2, L=6, h=32)

In [None]:
# Fit `flowy` to the targets `data` with full-batch Adam on the MSE loss,
# printing the loss every LOG_EVERY epochs.
N_EPOCHS, LOG_EVERY = 1000, 100
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(flowy.parameters())
for epoch in range(N_EPOCHS):
    optimizer.zero_grad()
    out = torch.vmap(flowy)(txy)
    loss = criterion(out, data)
    loss.backward()
    optimizer.step()
    if not epoch % LOG_EVERY:
        print(f"Epoch: {epoch} Loss: {loss.item():.4f}")

In [None]:
# Re-display the thinned dataset (e.g. to check its dims before plotting).
data_sub

In [48]:
import penalty
def vel(tx):
    """Velocity of `flowy` at one (t, x, y) point, via penalty.v_hat.

    NOTE: rebinds `vel`, which earlier held the net.Vel module.
    """
    return penalty.v_hat(tx, flowy)


# txy was flattened from the (y, x, t) meshgrid (T, X, Y), so restore that
# layout first and then move time to the front -> (t, y, x, 2), presumably the
# (time, y, x) layout plot.winds expects (as for data_sub.U). Grid sizes come
# from T rather than being hard-coded.
est = (
    torch.vmap(vel)(txy)
    .reshape(*T.shape, 2)
    .permute(2, 0, 1, 3)
    .detach()
    .numpy()
)
V_hat, U_hat = est[:, :, :, 1], est[:, :, :, 0]

In [None]:
# Plot the learned wind components (U_hat, V_hat) on the thinned grid.
# NOTE(review): confirm plot.winds expects arrays laid out like data_sub.U.
plot.winds(data_sub, U_hat, V_hat, Lambert_proj)