In [2]:
import warnings
import sys
import os
import pathlib
import torch 
import numpy as np
import random
import matplotlib.pyplot as plt
from IPython.display import HTML

# Silence every warning for the rest of the session. This keeps cell output
# readable, but it also hides deprecation/runtime warnings from all libraries.
warnings.filterwarnings('ignore')

# Make the project's ./src directory importable (needed by the `loaders` and
# `fields` imports below). Any earlier copy of the entry is dropped first so
# that re-running this cell keeps ./src at the front of sys.path without
# piling up duplicate entries.
_src_dir = str(pathlib.Path.cwd() / "src")
sys.path[:] = [entry for entry in sys.path if entry != _src_dir]
sys.path.insert(0, _src_dir)
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

# One seed shared by every random number generator the notebook touches:
# Python's `random`, NumPy's legacy global RNG, and PyTorch (CPU + all GPUs).
SEED = 40
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)

import loaders.hrrr
import fields.vector_field

In [3]:
import loaders.hrrr
import fields.vector_field

In [None]:
# --- Case configuration -----------------------------------------------------
date = "2024-08-16"  # alternative date kept from earlier runs: "2024-07-19"
level = 500  # used as "<level>mb" in the output folder name
hours = 4
# Bounding box handed to both loaders.
# NOTE(review): presumably (lon_min, lon_max, lat_min, lat_max) — confirm
# against loaders.hrrr.
extent = (-85.5, -75.1, 30.5, 36.5)

# Load the scalar and vector fields with identical settings so that their
# frames line up.
dsf = loaders.hrrr.discrete_scalar_field(date=date, level=level, hours=hours, extent=extent)
dvf = loaders.hrrr.discrete_vector_field(date=date, level=level, hours=hours, extent=extent)

# Create nested output folder: hrrr/<date>/<level>mb
folder_name = os.path.join("hrrr", date, f"{level}mb")
os.makedirs(folder_name, exist_ok=True)

# Plot and save the scalar field at the first and last frames. The set removes
# the duplicate when the field has a single frame (first == last).
n_scalar_frames = dsf.coord_field.times.shape[0]
for frame in sorted({0, n_scalar_frames - 1}):
    fig = dsf.plot(frame=frame)
    fig.savefig(os.path.join(folder_name, f"dsf_frame{frame}.png"))
    plt.close(fig)  # release the figure so repeated runs do not accumulate memory

# Plot and save the discrete vector field at the center frame. The index is
# derived from the number of frames rather than hard-coded, so it stays the
# center if `hours` changes (it is 2 for the five frames loaded here).
center_frame = dvf.coord_field.times.shape[0] // 2
fig = dvf.plot(factor=12, frame=center_frame)
fig.savefig(os.path.join(folder_name, f"dvf_frame{center_frame}.png"))
plt.close(fig)

✅ Found ┊ model=hrrr ┊ [3mproduct=sfc[0m ┊ [38;2;41;130;13m2024-Aug-16 00:00 UTC[92m F00[0m ┊ [38;2;255;153;0m[3mGRIB2 @ aws[0m ┊ [38;2;255;153;0m[3mIDX @ aws[0m
👨🏻‍🏭 Created directory: [/home/yf297/data/hrrr/20240816]
✅ Found ┊ model=hrrr ┊ [3mproduct=sfc[0m ┊ [38;2;41;130;13m2024-Aug-16 00:00 UTC[92m F01[0m ┊ [38;2;255;153;0m[3mGRIB2 @ aws[0m ┊ [38;2;255;153;0m[3mIDX @ aws[0m
✅ Found ┊ model=hrrr ┊ [3mproduct=sfc[0m ┊ [38;2;41;130;13m2024-Aug-16 00:00 UTC[92m F02[0m ┊ [38;2;255;153;0m[3mGRIB2 @ aws[0m ┊ [38;2;255;153;0m[3mIDX @ aws[0m
✅ Found ┊ model=hrrr ┊ [3mproduct=sfc[0m ┊ [38;2;41;130;13m2024-Aug-16 00:00 UTC[92m F03[0m ┊ [38;2;255;153;0m[3mGRIB2 @ aws[0m ┊ [38;2;255;153;0m[3mIDX @ aws[0m
✅ Found ┊ model=hrrr ┊ [3mproduct=sfc[0m ┊ [38;2;41;130;13m2024-Aug-16 00:00 UTC[92m F04[0m ┊ [38;2;255;153;0m[3mGRIB2 @ aws[0m ┊ [38;2;255;153;0m[3mIDX @ aws[0m
✅ Found ┊ model=hrrr ┊ [3mproduct=sfc[0m ┊ [38;2;41;130;13m2024-Aug-16 00:0

KeyboardInterrupt: 

In [None]:
# Animate the scalar field over all frames and embed it in the notebook as an
# HTML5 video. With gif=True the returned object exposes `to_html5_video()`
# (presumably a matplotlib animation — confirm in the plotting code).
scalar_animation = dsf.plot(gif=True)
HTML(scalar_animation.to_html5_video())

In [None]:
# Train the continuous vector field on the discrete scalar field.
cvf = fields.vector_field.ContinuousVectorField()
cvf.train(dsf, epochs=100, nn=1, k=2, size=3000)

# Plot and save the continuous vector field at the center frame. The index is
# computed from the number of frames instead of being hard-coded, so the plot
# and its file name stay correct if the number of loaded hours changes
# (it is 2 for the five frames used here).
center_frame = dsf.coord_field.times.shape[0] // 2
fig = cvf.plot(dsf.coord_field, factor=12, frame=center_frame)
fig.savefig(os.path.join(folder_name, f"cvf_frame{center_frame}.png"))
plt.close(fig)

# Extract learned parameters for inspection in later cells.
# NOTE(review): l0/l1/l2 look like the three lengthscales reported in the
# training log and sigma2 like a variance term — confirm in
# fields.vector_field.
sigma2 = cvf.sigma2
l0 = cvf.l0
l1 = cvf.l1
l2 = cvf.l2



[0.7453863620758057, 0.8482828140258789, 0.7036393284797668]
Epoch 10/100 — Avg NLL: 0.2044 — lengthscales: 5.90, 0.24, 0.20
[0.847797691822052, 1.0838768482208252, 0.733717143535614]
Epoch 20/100 — Avg NLL: -0.0129 — lengthscales: 5.80, 0.26, 0.19
[0.8972573280334473, 1.2235316038131714, 0.7863544821739197]
Epoch 30/100 — Avg NLL: -0.1878 — lengthscales: 5.68, 0.28, 0.18
[0.9997557401657104, 1.337265968322754, 0.938344419002533]
Epoch 40/100 — Avg NLL: -0.2951 — lengthscales: 5.54, 0.30, 0.20
[1.2143735885620117, 1.481523871421814, 1.2226258516311646]
Epoch 50/100 — Avg NLL: -0.3814 — lengthscales: 5.39, 0.32, 0.22
[1.5760531425476074, 1.6688588857650757, 1.7367210388183594]
Epoch 60/100 — Avg NLL: -0.8132 — lengthscales: 5.22, 0.31, 0.23
[1.8891915082931519, 1.8407917022705078, 2.2663886547088623]
Epoch 70/100 — Avg NLL: -0.2111 — lengthscales: 5.05, 0.31, 0.25
[2.2439215183258057, 2.0607659816741943, 2.8594000339508057]
Epoch 80/100 — Avg NLL: -0.1423 — lengthscales: 4.88, 0.30, 0.2

In [None]:
# Mean over all frames of the discrete vector field's per-frame RMS.
# Each term is divided by the frame count before being added, in frame order.
n_frames = dvf.coord_field.times.size(0)
RMS = sum(dvf.RMS(frame=idx) / n_frames for idx in range(n_frames))

RMS

11.984000000000002

In [57]:
# Mean over all frames of the continuous field's per-frame RMSE, evaluated
# against the discrete vector field. Each term is divided by the frame count
# before being added, in frame order.
n_frames = dvf.coord_field.times.size(0)
RMSE = sum(cvf.RMSE(dvf, frame=idx) / n_frames for idx in range(n_frames))

In [58]:
# Display the frame-averaged RMSE accumulated in the previous cell.
RMSE

2.094

In [59]:
# Per-frame breakdown of the RMSE between the continuous field and the
# discrete vector field, one value per line.
n_frames = dvf.coord_field.times.size(0)
for frame_idx in range(n_frames):
    print(cvf.RMSE(dvf, frame=frame_idx))

2.5
2.08
1.96
1.91
2.02


In [61]:
# Show the learned parameter l0 divided by 3600.
# NOTE(review): presumably converts a temporal lengthscale from seconds to
# hours — confirm the units of cvf.l0.
l0 /3600

18.41327777777778