In [1]:
%load_ext autoreload
%autoreload 2

import s3fs
import zarr
import xarray as xr

import numpy as np
import torch
import gpytorch
import cartopy.crs as ccrs



In [2]:
# Full HRRR grid: open the chunk-index store anonymously and keep only its
# coordinates (every data variable is dropped).
url = 's3://hrrrzarr/grid/HRRR_chunk_index.zarr'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})
locations = xr.open_dataset(store, engine='zarr')
locations = locations.drop_vars(list(locations.data_vars))

In [3]:
# Raw y / x coordinate values of the full grid as NumPy arrays.
y, x = locations["y"].values, locations["x"].values

In [4]:
import datetime

# Forecast date as YYYYMMDD, used to build the HRRR zarr URLs below.
# Building it from a datetime.date rejects impossible calendar dates.
date = datetime.date(2024, 8, 7).strftime("%Y%m%d")


In [5]:
# get DPT data: dew point on the 500 mb level of the 00z forecast for `date`,
# loaded fully into memory, with projection axes renamed to x / y and the
# variable renamed to `obs`.
url = f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/DPT/500mb'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})
obs = (
    xr.open_dataset(store, engine='zarr')
    .load()
    .rename({
        'projection_x_coordinate': 'x',
        'projection_y_coordinate': 'y'
    })
    .rename_vars({'DPT': 'obs'})
)

In [6]:
# UGRD (renamed to `U`) on the 500 mb level of the 00z forecast for `date`,
# loaded fully into memory with projection axes renamed to x / y.
url = f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/UGRD/500mb'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})

U = (
    xr.open_dataset(store, engine='zarr')
    .load()
    .rename({
        'projection_x_coordinate': 'x',
        'projection_y_coordinate': 'y'
    })
    .rename_vars({'UGRD': 'U'})
)

In [7]:
# VGRD (renamed to `V`) on the 500 mb level of the 00z forecast for `date`,
# loaded fully into memory with projection axes renamed to x / y.
url = f's3://hrrrzarr/sfc/{date}/{date}_00z_fcst.zarr/500mb/VGRD/500mb'
store = zarr.storage.FsspecStore.from_url(url, storage_options={'anon': True})

V = (
    xr.open_dataset(store, engine='zarr')
    .load()
    .rename({
        'projection_x_coordinate': 'x',
        'projection_y_coordinate': 'y'
    })
    .rename_vars({'VGRD': 'V'})
)

In [8]:
# merge: dew point, both wind components and the coordinates kept in
# `locations` into one dataset.
data = xr.merge(
    [obs, U, V, locations]
)

In [9]:
# Lambert conformal projection on a sphere (equal semi-major and semi-minor
# axes), used to map lon/lat to the grid's x / y.
HRRR_EARTH_RADIUS = 6371229
Lambert_proj = ccrs.LambertConformal(
    central_longitude=262.5,
    central_latitude=38.5,
    standard_parallels=[38.5, 38.5],
    globe=ccrs.Globe(
        semimajor_axis=HRRR_EARTH_RADIUS,
        semiminor_axis=HRRR_EARTH_RADIUS,
    ),
)

In [11]:
# First six time steps, keeping every GRID_STRIDE-th point along y and x.
GRID_STRIDE = 25
data_sub = data.isel(
    time=slice(0, 6),
    y=slice(None, None, GRID_STRIDE),
    x=slice(None, None, GRID_STRIDE),
)

In [12]:
# Clip to the projected box whose corners are the images of
# (lon -100, lat 20) and (lon -50, lat 50).
geodetic = ccrs.Geodetic()
x_lo, y_lo = Lambert_proj.transform_point(-100, 20, geodetic)
x_hi, y_hi = Lambert_proj.transform_point(-50, 50, geodetic)
data_sub = data_sub.sel(x=slice(x_lo, x_hi), y=slice(y_lo, y_hi))

In [13]:
# plot.obs(data_sub, Lambert_proj)  # disabled: no `plot` module is imported in this notebook

In [14]:
COORD_SCALE = 1e6        # divisor applied to the projection x / y values
KELVIN_OFFSET = 273.15   # subtracted from DPT so Z is in °C before scaling
DPT_SCALE = 50           # divisor applied to the dew point in °C

n = data_sub.time.shape[0]
t = torch.linspace(0, 1, n)   # time steps rescaled onto [0, 1]

# y is flipped, and the observations are flipped along their axis 1 below,
# so coordinates and values stay aligned.
x = torch.tensor(data_sub.x.values, dtype=torch.float32) / COORD_SCALE
y = torch.flip(torch.tensor(data_sub.y.values, dtype=torch.float32), dims=[0]) / COORD_SCALE

# NOTE(review): assumes data_sub.obs is ordered (time, y, x) — confirm.
dpt_raw = torch.tensor(data_sub.obs.values, dtype=torch.float32)
Z = torch.flip(dpt_raw, dims=[1]).reshape(-1) - KELVIN_OFFSET
Z = Z / DPT_SCALE


In [15]:
# Full (time, y, x) grid; the trailing axis of `coords` stores (t, x, y),
# so coords has shape (len(t), len(y), len(x), 3).
grid_t, grid_y, grid_x = torch.meshgrid(t, y, x, indexing='ij')
coords = torch.stack((grid_t, grid_x, grid_y), dim=-1)
TXY = coords.reshape(-1, 3)   # one (t, x, y) row per grid point, time slowest, x fastest

# Sorted distinct time stamps, and the flattened spatial grid with x varying
# fastest (the same per-time ordering as Z).
T = TXY[:, 0].unique()
XY = torch.stack(torch.meshgrid(x, y, indexing='xy'), dim=-1).reshape(-1, 2)

# Observations back on the (time, y, x) grid, then one row per time step.
obs = Z.reshape(coords.shape[:3])
Z_tensor = obs.flatten(start_dim=1)

In [16]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

# `obs` is the dew point in °C divided by 50 (the scaling applied to Z before
# it was reshaped into `obs`); undo it so the colour-bar label is truthful.
DPT_PLOT_SCALE = 50.0
obs_celsius = (obs * DPT_PLOT_SCALE).numpy()

# One colour range shared by every frame, so the colour bar stays valid
# for the whole animation instead of only for frame 0.
vmin, vmax = float(obs_celsius.min()), float(obs_celsius.max())

fig, ax = plt.subplots(figsize=(9, 6))
mesh = ax.imshow(obs_celsius[0], cmap="viridis", vmin=vmin, vmax=vmax)

cbar = fig.colorbar(mesh, ax=ax, orientation='vertical', pad=0.02, aspect=16, shrink=0.8)
cbar.set_label('Dew Point Temperature (°C)', fontsize=12)


def update_contour(frame):
    """Show observation frame `frame` by swapping the image data in place.

    Reusing the same image keeps the fixed vmin/vmax normalisation (and so the
    colour bar) consistent; re-creating it with imshow would autoscale each frame.
    """
    mesh.set_data(obs_celsius[frame])
    ax.set_title(f"observations at time point {frame}", fontsize=14)
    return (mesh,)


ani = animation.FuncAnimation(
    fig,
    update_contour,
    frames=obs_celsius.shape[0],
    interval=100,
    blit=False,
)

plt.close(fig)

HTML(ani.to_jshtml())

In [30]:
import net
import model

# Seed torch's RNG before building the networks so the run is reproducible.
# NOTE(review): assumes net.Flow / net.Vel draw their initial parameters from
# torch's global RNG — confirm in net.py.
torch.manual_seed(0)

flow = net.Flow(L=10)
vel = net.Vel(L=8)
gp_flow = model.GP_FLOW(T, XY, Z_tensor, TXY, Z, flow, vel)


In [31]:
import gc

# Reclaim unreachable Python objects before training; the trailing ';' stops
# IPython from echoing the number of collected objects.
gc.collect();


5242

In [32]:
# Release cached, unused CUDA memory; nothing to do on machines without CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [33]:
import optimize

N_INITIAL_EPOCHS = 75

# Initial fit of gp_flow, with gpytorch's fast log-prob computation switched off.
with gpytorch.settings.fast_computations(log_prob=False):
    optimize.initial(gp_flow, num_epochs=N_INITIAL_EPOCHS)

Epoch: 1 - Likelihood: 0.83
Epoch: 2 - Likelihood: 0.78
Epoch: 3 - Likelihood: 0.73
Epoch: 4 - Likelihood: 0.69
Epoch: 5 - Likelihood: 0.65
Epoch: 6 - Likelihood: 0.61
Epoch: 7 - Likelihood: 0.56
Epoch: 8 - Likelihood: 0.52
Epoch: 9 - Likelihood: 0.48
Epoch: 10 - Likelihood: 0.44
Epoch: 11 - Likelihood: 0.39
Epoch: 12 - Likelihood: 0.35
Epoch: 13 - Likelihood: 0.30
Epoch: 14 - Likelihood: 0.26
Epoch: 15 - Likelihood: 0.21
Epoch: 16 - Likelihood: 0.17
Epoch: 17 - Likelihood: 0.12
Epoch: 18 - Likelihood: 0.07
Epoch: 19 - Likelihood: 0.02
Epoch: 20 - Likelihood: -0.02
Epoch: 21 - Likelihood: -0.07
Epoch: 22 - Likelihood: -0.12
Epoch: 23 - Likelihood: -0.16
Epoch: 24 - Likelihood: -0.21
Epoch: 25 - Likelihood: -0.26
Epoch: 26 - Likelihood: -0.30
Epoch: 27 - Likelihood: -0.35
Epoch: 28 - Likelihood: -0.40
Epoch: 29 - Likelihood: -0.44
Epoch: 30 - Likelihood: -0.49
Epoch: 31 - Likelihood: -0.54
Epoch: 32 - Likelihood: -0.58
Epoch: 33 - Likelihood: -0.63
Epoch: 34 - Likelihood: -0.66
Epoch: 3

In [34]:
# Display the subset (time steps, strided and clipped x / y) the model was built from.
data_sub

In [35]:
import penalty

# Velocity at every (t, x, y) row of TXY. By default the velocity network is
# used; set USE_FLOW_VELOCITY = True to use penalty.v_hat (computed from
# gp_flow.flow) instead. Only the selected estimate is computed.
# NOTE(review): assumes v_hat returns one 2-vector per TXY row, like gp_flow.vel.
USE_FLOW_VELOCITY = False
if USE_FLOW_VELOCITY:
    vels = penalty.v_hat(TXY, gp_flow.flow)
else:
    vels = gp_flow.vel(TXY)

# Back onto the (time, y, x) grid with a trailing component axis; the shape
# comes from `coords` instead of hard-coded sizes so other subsets still work.
vels = vels.reshape(*coords.shape[:3], 2).detach()


In [36]:
fig, ax = plt.subplots(figsize=(9, 6))
frame = 0
s = 2  # draw every s-th grid point along y and x so arrows stay legible

# Arrows at the (x, y) grid points; components 0 / 1 of `vels` are the
# horizontal / vertical arrow components.
Q = ax.quiver(
    coords[0, ::s, ::s, 1],
    coords[0, ::s, ::s, 2],
    vels[frame, ::s, ::s, 0],
    vels[frame, ::s, ::s, 1],
)
ax.set_xlabel("x (scaled projection coordinate)")
ax.set_ylabel("y (scaled projection coordinate)")


def update_learned_quiver(frame):
    """Redraw the learned velocity arrows for time index `frame`."""
    Q.set_UVC(vels[frame, ::s, ::s, 0], vels[frame, ::s, ::s, 1])
    ax.set_title(f"Learned velocity field at time point {frame}")
    return (Q,)


ani = animation.FuncAnimation(
    fig,
    update_learned_quiver,
    frames=vels.shape[0],
)

plt.close(fig)

HTML(ani.to_jshtml())

In [37]:
# Wind components as tensors, flipped along axis 1 exactly like the observations.
U, V = (torch.tensor(data_sub[name].values).flip(dims=[1]) for name in ("U", "V"))


In [38]:
fig, ax = plt.subplots(figsize=(9, 6))
frame = 0
s = 2  # draw every s-th grid point along y and x so arrows stay legible

wind_quiver = ax.quiver(
    coords[0, ::s, ::s, 1],
    coords[0, ::s, ::s, 2],
    U[frame, ::s, ::s],
    V[frame, ::s, ::s],
    color="blue",
)
ax.set_xlabel("x (scaled projection coordinate)")
ax.set_ylabel("y (scaled projection coordinate)")


def update_wind_quiver(frame):
    """Redraw the U / V wind arrows for time index `frame`."""
    wind_quiver.set_UVC(U[frame, ::s, ::s], V[frame, ::s, ::s])
    ax.set_title(f"Wind Velocities at Hour {frame+1}")
    return (wind_quiver,)


# One frame per time step in U, so the final step is included.
ani = animation.FuncAnimation(
    fig,
    update_wind_quiver,
    frames=U.shape[0],
)

plt.close(fig)

HTML(ani.to_jshtml())