In [2]:
import gc
import os
import pathlib
import random
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import HTML

warnings.filterwarnings('ignore')
sys.path.insert(0, str(pathlib.Path.cwd()/"src"))
%load_ext autoreload
%autoreload 2

SEED = 40
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [3]:
import loaders.hrrr
import fields.vector_field

In [None]:
date = "2024-09-18"
level = 500
# Number of forecast hours to load; hours=4 yields F00..F04 (5 frames) per the loader output.
hours = 4
# Spatial crop shared by both loaders so the scalar and vector fields cover the same region.
# NOTE(review): looks like (lon_min, lon_max, lat_min, lat_max) — confirm against the loader.
extent = (-85.5, -75.1, 30.5, 36.5)

# Load data (same date/level/window/region for both fields)
dsf = loaders.hrrr.discrete_scalar_field(date=date, level=level, hours=hours, extent=extent)
dvf = loaders.hrrr.discrete_vector_field(date=date, level=level, hours=hours, extent=extent)

# Create nested output folder: date/levelmb
folder_name = os.path.join("hrrr", date, f"{level}mb")
os.makedirs(folder_name, exist_ok=True)

# Plot and save scalar field at start and end frames.
# The set deduplicates the single-frame case so the same frame is not rendered twice.
n_scalar_frames = dsf.coord_field.times.shape[0]
for frame in sorted({0, n_scalar_frames - 1}):
    fig = dsf.plot(frame=frame)
    fig.savefig(os.path.join(folder_name, f"dsf_frame{frame}.png"))
    plt.close(fig)

# Plot and save discrete vector field at its center frame
# (derived from the frame count instead of hard-coding 2, which only holds for 5 frames).
dvf_center_frame = dvf.coord_field.times.shape[0] // 2
fig = dvf.plot(factor=12, frame=dvf_center_frame)
fig.savefig(os.path.join(folder_name, f"dvf_frame{dvf_center_frame}.png"))
plt.close(fig)

✅ Found ┊ model=hrrr ┊ product=sfc ┊ 2024-Sep-18 00:00 UTC F00 ┊ GRIB2 @ aws ┊ IDX @ aws
✅ Found ┊ model=hrrr ┊ product=sfc ┊ 2024-Sep-18 00:00 UTC F01 ┊ GRIB2 @ aws ┊ IDX @ aws
✅ Found ┊ model=hrrr ┊ product=sfc ┊ 2024-Sep-18 00:00 UTC F02 ┊ GRIB2 @ aws ┊ IDX @ aws
✅ Found ┊ model=hrrr ┊ product=sfc ┊ 2024-Sep-18 00:00 UTC F03 ┊ GRIB2 @ aws ┊ IDX @ aws
✅ Found ┊ model=hrrr ┊ product=sfc ┊ 2024-Sep-18 00:00 UTC F04 ┊ GRIB2 @ aws ┊ IDX @ aws
✅ Found ┊ model=hrrr ┊ product=sfc ┊ 2024-Sep-18 00:00 UTC F00 ┊ GRIB2 @ aws

In [45]:
# Animate every frame of the scalar field and embed it inline as an HTML5 video.
# `plot(gif=True)` returns an animation object (it provides `to_html5_video`), not a Figure.
scalar_anim = dsf.plot(gif=True)
HTML(scalar_anim.to_html5_video())

In [52]:
# Train continuous vector field
# NOTE(review): the model is trained on the discrete *scalar* field `dsf` but is scored
# against the discrete *vector* field `dvf` below — confirm this pairing is intended.
cvf = fields.vector_field.ContinuousVectorField()
cvf.train(dsf, epochs=100, nn=1, k=2, size=4000)

# Plot and save continuous vector field at the center frame
# (derived from the frame count instead of hard-coding 2, which only holds for 5 frames).
cvf_center_frame = dsf.coord_field.times.shape[0] // 2
fig = cvf.plot(dsf.coord_field, factor=12, frame=cvf_center_frame)
fig.savefig(os.path.join(folder_name, f"cvf_frame{cvf_center_frame}.png"))
plt.close(fig)

# Extract learned parameters (not used further in this cell; kept for inspection)
sigma2 = cvf.sigma2
l0 = cvf.l0
l1 = cvf.l1
l2 = cvf.l2

# Per-frame metrics against the discrete vector field. The frame count is computed once,
# and the per-frame values are kept so later cells can inspect them without re-evaluating.
n_eval_frames = dvf.coord_field.times.size(0)
rms_per_frame = [dvf.RMS(frame=frame) for frame in range(n_eval_frames)]
rmse_per_frame = [cvf.RMSE(dvf, frame=frame) for frame in range(n_eval_frames)]

# Frame-averaged RMS / RMSE. Dividing each term before summing (starting from 0) reproduces
# the original running accumulation exactly, so the reported numbers are unchanged.
RMS = sum(value / n_eval_frames for value in rms_per_frame)
RMSE = sum(value / n_eval_frames for value in rmse_per_frame)


[0.06931141763925552, 0.07256672531366348]
Epoch 1/100 — Avg NLL: 0.2881 — lengthscales: 5.99, 0.22, 0.22
[0.1364138275384903, 0.1424330323934555]
Epoch 2/100 — Avg NLL: 0.2772 — lengthscales: 5.98, 0.22, 0.22
[0.20111142098903656, 0.21353121101856232]
Epoch 3/100 — Avg NLL: 0.2579 — lengthscales: 5.97, 0.22, 0.22
[0.2668655514717102, 0.28283756971359253]
Epoch 4/100 — Avg NLL: 0.2541 — lengthscales: 5.96, 0.22, 0.22
[0.3273234963417053, 0.34811314940452576]
Epoch 5/100 — Avg NLL: 0.2398 — lengthscales: 5.95, 0.22, 0.22
[0.36972734332084656, 0.3974563479423523]
Epoch 6/100 — Avg NLL: 0.2296 — lengthscales: 5.95, 0.22, 0.21
[0.37146228551864624, 0.4014575779438019]
Epoch 7/100 — Avg NLL: 0.2095 — lengthscales: 5.94, 0.22, 0.21
[0.35463422536849976, 0.3816184401512146]
Epoch 8/100 — Avg NLL: 0.2014 — lengthscales: 5.93, 0.21, 0.21
[0.33349373936653137, 0.3534776568412781]
Epoch 9/100 — Avg NLL: 0.1938 — lengthscales: 5.92, 0.21, 0.21
[0.3193338215351105, 0.32850709557533264]
Epoch 10/100

In [53]:
# Frame-averaged RMS of the discrete vector field (computed in the training cell)
RMS

9.102

In [54]:
# Print the RMSE of the continuous field against the discrete vector field, one line per frame.
frame_count = dvf.coord_field.times.size(0)
for frame in range(frame_count):
    frame_rmse = cvf.RMSE(dvf, frame=frame)
    print(frame_rmse)


2.57
2.38
2.3
2.39
2.58


In [33]:
# Free memory between experiments: collect unreachable Python objects first, then
# release PyTorch's cached-but-unoccupied CUDA memory (a no-op if CUDA is not initialized).
# `gc` is imported in the top import cell rather than mid-notebook.
gc.collect()
torch.cuda.empty_cache()

In [19]:
# NOTE(review): same call as the earlier animation cell; consider deleting one of the two.
# Build the scalar-field animation and embed it inline as an HTML5 video.
animation_html = dsf.plot(gif=True).to_html5_video()
HTML(animation_html)