In [1]:
%load_ext autoreload
%autoreload 2

import s3fs
import zarr
import xarray as xr

import numpy as np
import torch

import cartopy.crs as ccrs



In [2]:
fs = s3fs.S3FileSystem(anon=True)
url = 's3://hrrrzarr/grid/HRRR_chunk_index.zarr'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})
locations = xr.open_dataset(store, engine='zarr')
locations = locations.drop_vars(list(locations.data_vars.keys()))

In [3]:
y = locations.y.values
x = locations.x.values

In [37]:
date = "20241225"


In [38]:
# get DPT data
fs = s3fs.S3FileSystem(anon=True)
url = f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/TMP/500mb'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})
obs = xr.open_dataset(store, engine='zarr').load()
obs = obs.rename({
    'projection_x_coordinate': 'x',
    'projection_y_coordinate': 'y'
})
obs = obs.rename_vars({
    'TMP': 'obs'})

In [39]:
'''fs = s3fs.S3FileSystem(anon=True)
url =  f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/UGRD/500mb'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})

U = xr.open_dataset(store, engine='zarr').load()
U = U.rename({
    'projection_x_coordinate': 'x',
    'projection_y_coordinate': 'y'
})
U = U.rename_vars({
    'UGRD': 'U'})'''

"fs = s3fs.S3FileSystem(anon=True)\nurl =  f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/UGRD/500mb'\nstore = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})\n\nU = xr.open_dataset(store, engine='zarr').load()\nU = U.rename({\n    'projection_x_coordinate': 'x',\n    'projection_y_coordinate': 'y'\n})\nU = U.rename_vars({\n    'UGRD': 'U'})"

In [40]:
'''fs = s3fs.S3FileSystem(anon=True)
url =  f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/VGRD/500mb'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})

V= xr.open_dataset(store, engine='zarr').load()
V = V.rename({
    'projection_x_coordinate': 'x',
    'projection_y_coordinate': 'y'
})
V = V.rename_vars({
    'VGRD': 'V'})'''

"fs = s3fs.S3FileSystem(anon=True)\nurl =  f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/VGRD/500mb'\nstore = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})\n\nV= xr.open_dataset(store, engine='zarr').load()\nV = V.rename({\n    'projection_x_coordinate': 'x',\n    'projection_y_coordinate': 'y'\n})\nV = V.rename_vars({\n    'VGRD': 'V'})"

In [41]:
# merge
data = xr.merge([obs,locations])

In [42]:
Lambert_proj = ccrs.LambertConformal(central_longitude=262.5, central_latitude=38.5,
                                standard_parallels=[38.5,38.5],
                                globe=ccrs.Globe(semimajor_axis=6371229, 
                                                 semiminor_axis=6371229))

In [43]:
data_sub = data.isel(time = slice(0, None, 8), y=slice(None, None, 20), x=slice(None, None, 20))

In [44]:
(l,d) = Lambert_proj.transform_point(-100,30,ccrs.Geodetic())
(r,u) = Lambert_proj.transform_point(-50,50,ccrs.Geodetic())
data_sub = data_sub.sel(x = slice(l,r), y = slice(d,u))

In [45]:
#plot.obs(data_sub, Lambert_proj)

In [46]:
n = data_sub.time.shape[0]
t = torch.linspace(0,1,n)
x = torch.tensor(data_sub.x.values, dtype=torch.float32)/1e6
y = torch.tensor(data_sub.y.values, dtype=torch.float32).flip(dims=[0])/1e6
Z = torch.tensor(data_sub.obs.values, dtype=torch.float32).flip(dims=[1]).reshape(-1)
Z = (Z - Z.mean())/Z.std()

In [47]:
T,Y,X = torch.meshgrid(t,y,x,indexing='ij')
coords = torch.stack([T,X,Y], dim = -1)
TXY = coords.reshape(-1,3)

obs = Z.reshape(coords.shape[0:3])

In [49]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

fig = plt.figure(figsize=(9, 6))
ax = fig.add_subplot(1, 1, 1)

mesh = ax.imshow(
    obs[0, :, :],
    cmap="viridis")

cbar = plt.colorbar(mesh, ax=ax, orientation='vertical', pad=0.02, aspect=16, shrink=0.8)
cbar.set_label('Dew Point Temperature (°C)', fontsize=12)

def update_contour(frame):
    global mesh 
    mesh.remove()
    mesh = ax.imshow(
        obs[frame,:, :],
        cmap="viridis")
        
    ax.set_title(f"observations at time point {frame}", fontsize=14)
    return mesh

ani = animation.FuncAnimation(
    fig,
    update_contour,
    frames=t.shape[0],  
    interval=100,         
    blit=False
)

plt.close(fig)

HTML(ani.to_jshtml())

In [50]:
import net, model
flow = net.Flow(L = 5)
gp_flow = model.GP_FLOW(TXY, Z, flow)


In [51]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [52]:
import optimize
optimize.flow(gp_flow, num_epochs=120)

Epoch: 5 - Likelihood: 0.790 - Learning Rates: [0.001]
Epoch: 10 - Likelihood: 0.561 - Learning Rates: [0.001]
Epoch: 15 - Likelihood: 0.360 - Learning Rates: [0.001]
Epoch: 20 - Likelihood: 0.198 - Learning Rates: [0.001]
Epoch: 25 - Likelihood: 0.105 - Learning Rates: [0.001]
Epoch: 30 - Likelihood: 0.061 - Learning Rates: [0.001]
Epoch: 35 - Likelihood: 0.049 - Learning Rates: [0.001]
Epoch: 40 - Likelihood: 0.031 - Learning Rates: [0.001]
Epoch: 45 - Likelihood: 0.000 - Learning Rates: [0.001]
Epoch: 50 - Likelihood: -0.024 - Learning Rates: [0.001]
Epoch: 55 - Likelihood: -0.054 - Learning Rates: [0.001]
Epoch: 60 - Likelihood: -0.093 - Learning Rates: [0.001]
Epoch: 65 - Likelihood: -0.143 - Learning Rates: [0.001]
Epoch: 70 - Likelihood: -0.193 - Learning Rates: [0.001]
Epoch: 75 - Likelihood: -0.216 - Learning Rates: [0.001]
Epoch: 80 - Likelihood: -0.240 - Learning Rates: [0.001]


In [53]:
import penalty
vels = penalty.v_hat(TXY, gp_flow.flow)
vels = vels.reshape(6, 42, 49, 2).detach()


In [54]:
fig = plt.figure(figsize=(9, 6))
ax = fig.add_subplot(1, 1, 1)
frame=0
s = 2
Q = ax.quiver(    
    coords[0,::s,::s,1],
    coords[0,::s,::s,2],
    vels[frame,::s,::s,0],
    vels[frame,::s,::s,1])


def update_quiver(frame):
    Q.set_UVC( vels[frame,::s,::s,0], vels[frame,::s,::s,1])
    ax.set_title(f"Wind Velocities at Hour {frame+1}")
    return Q,

ani = animation.FuncAnimation(
    fig,
    update_quiver,
    frames=n
)

plt.close(fig)

HTML(ani.to_jshtml())