# Libraries

In [1]:
import xarray as xr
from datetime import datetime

import torch

from aurora import AuroraSmall, Batch, Metadata, rollout
import matplotlib.pyplot as plt

from pathlib import Path

import cdsapi
import numpy as np
from sklearn.metrics import root_mean_squared_error

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
torch.cuda.empty_cache()
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"



# Load the model

In [11]:
model = AuroraSmall()

model.load_state_dict(torch.load('model/aurora.pth'))

<All keys matched successfully>

# Download South Africa Data

In [4]:
sa_area = [
    -21.0,  # North latitude (adjusted slightly north of South Africa's northernmost point)
    15.0,   # West longitude (adjusted slightly west of South Africa's westernmost point)
    -36.0,  # South latitude (adjusted slightly south of South Africa's southernmost point)
    34.0,   # East longitude (adjusted slightly east of South Africa's easternmost point)
]


In [5]:
sa_area = [
    -22.125,  # North latitude
    15.125,   # West longitude
    -37.875,  # South latitude
    34.875    # East longitude
]

In [12]:
# Data will be downloaded here.
download_path = Path("data_sa")

In [None]:



c = cdsapi.Client()

download_path = download_path.expanduser()
download_path.mkdir(parents=True, exist_ok=True)


# Download the static variables for South Africa.
if not (download_path / "static.nc").exists():
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "geopotential",
                "land_sea_mask",
                "soil_type",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": "00:00",
            "format": "netcdf",
            "area": sa_area
        },
        str(download_path / "static.nc"),
    )
print("Static variables for South Africa downloaded!")


In [47]:

# Download the surface-level variables.
if not (download_path / "2023-01-01-surface-level.nc").exists():
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "2m_temperature",
                "10m_u_component_of_wind",
                "10m_v_component_of_wind",
                "mean_sea_level_pressure",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": ["00:00", "06:00", "12:00", "18:00"],
            "format": "netcdf",
             "area": sa_area
        },
        str(download_path / "2023-01-01-surface-level.nc"),
    )
print("Surface-level variables downloaded!")

# Download the atmospheric variables.
if not (download_path / "2023-01-01-atmospheric.nc").exists():
    c.retrieve(
        "reanalysis-era5-pressure-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "temperature",
                "u_component_of_wind",
                "v_component_of_wind",
                "specific_humidity",
                "geopotential",
            ],
            "pressure_level": [
                "50",
                "100",
                "150",
                "200",
                "250",
                "300",
                "400",
                "500",
                "600",
                "700",
                "850",
                "925",
                "1000",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": ["00:00", "06:00", "12:00", "18:00"],
            "format": "netcdf",
             "area": sa_area
        },
        str(download_path / "2023-01-01-atmospheric.nc"),
    )
print("Atmospheric variables downloaded!")

2025-02-11 11:09:26,774 INFO Request ID is e9b503c7-63c5-46b8-9979-b2ba68c819b0


2025-02-11 11:09:28,618 INFO status has been updated to accepted
2025-02-11 11:09:37,375 INFO status has been updated to running
2025-02-11 11:09:50,530 INFO status has been updated to successful
                                                                                       

Surface-level variables downloaded!


2025-02-11 11:09:51,257 INFO Request ID is c33238d8-32ce-4616-86c8-0503f4edac47
2025-02-11 11:09:51,292 INFO status has been updated to accepted
2025-02-11 11:09:59,721 INFO status has been updated to running
2025-02-11 11:10:23,945 INFO status has been updated to successful
                                                                                         

Atmospheric variables downloaded!




# Preparing a batch

In [13]:
static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4")
surf_vars_ds = xr.open_dataset(download_path / "2023-01-01-surface-level.nc", engine="netcdf4")
atmos_vars_ds = xr.open_dataset(download_path / "2023-01-01-atmospheric.nc", engine="netcdf4")

In [49]:
# torch.from_numpy(surf_vars_ds["t2m"].values[[1 - 1, 1]]).shape
# torch.from_numpy(static_vars_ds["z"].values[0]).shape
# torch.from_numpy(atmos_vars_ds["t"].values[[1 - 1, 1]]).shape


torch.Size([2, 13, 64, 80])

In [7]:
# static_vars_ds.size()

In [6]:
# torch.from_numpy(surf_vars_ds["t2m"].values[[i - 1, i]][None]).size()

In [7]:
# static_vars_ds.size()

In [11]:
# p2d = (1,2, 1, 2)
# # static_vars_ds = torch.nn.functional.pad(torch.from_numpy(surf_vars_ds["t2m"].values[[i - 1, i]][None]), p2d, "constant", 0)

# i = 1  # Select this time index in the downloaded data.

# batch = Batch(
#     surf_vars={
#         # First select time points `i` and `i - 1`. Afterwards, `[None]` inserts a
#         # batch dimension of size one.
#         "2t": torch.nn.functional.pad(torch.from_numpy(surf_vars_ds["t2m"].values[[i - 1, i]][None]), p2d, "constant", 0),
#         "10u": torch.nn.functional.pad(torch.from_numpy(surf_vars_ds["u10"].values[[i - 1, i]][None]), p2d, "constant", 0),
#         "10v": torch.nn.functional.pad(torch.from_numpy(surf_vars_ds["v10"].values[[i - 1, i]][None]), p2d, "constant", 0),
#         "msl": torch.nn.functional.pad(torch.from_numpy(surf_vars_ds["msl"].values[[i - 1, i]][None]), p2d, "constant", 0),
#     },
#     static_vars={
#         # The static variables are constant, so we just get them for the first time.
#         "z": torch.nn.functional.pad(torch.from_numpy(static_vars_ds["z"].values[0]), p2d, "constant", 0),
#         "slt": torch.nn.functional.pad(torch.from_numpy(static_vars_ds["slt"].values[0]), p2d, "constant", 0),
#         "lsm": torch.nn.functional.pad(torch.from_numpy(static_vars_ds["lsm"].values[0]), p2d, "constant", 0),
#     },
#     atmos_vars={
#         "t": torch.nn.functional.pad(torch.from_numpy(atmos_vars_ds["t"].values[[i - 1, i]][None]), p2d, "constant", 0),
#         "u": torch.nn.functional.pad(torch.from_numpy(atmos_vars_ds["u"].values[[i - 1, i]][None]), p2d, "constant", 0),
#         "v": torch.nn.functional.pad(torch.from_numpy(atmos_vars_ds["v"].values[[i - 1, i]][None]), p2d, "constant", 0),
#         "q": torch.nn.functional.pad(torch.from_numpy(atmos_vars_ds["q"].values[[i - 1, i]][None]), p2d, "constant", 0),
#         "z": torch.nn.functional.pad(torch.from_numpy(atmos_vars_ds["z"].values[[i - 1, i]][None]), p2d, "constant", 0),
#     },
#     metadata=Metadata(
#         lat=torch.from_numpy(np.append(surf_vars_ds.latitude.values, np.array([-36.25, -36.5, -36.75]))),
#         lon=torch.from_numpy(np.append(surf_vars_ds.longitude.values, np.array([34.25, 34.5, 34.75]))),
#         # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
#         # `datetime.datetime`s. Note that this needs to be a tuple of length one:
#         # one value for every batch element.
#         time=(surf_vars_ds.valid_time.values.astype("datetime64[s]").tolist()[i],),
#         atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
#     ),
# )

In [14]:

i = 1  # Select this time index in the downloaded data.

batch = Batch(
    surf_vars={
        # First select time points `i` and `i - 1`. Afterwards, `[None]` inserts a
        # batch dimension of size one.
        "2t": torch.from_numpy(surf_vars_ds["t2m"].values[[i - 1, i]][None]),
        "10u": torch.from_numpy(surf_vars_ds["u10"].values[[i - 1, i]][None]),
        "10v": torch.from_numpy(surf_vars_ds["v10"].values[[i - 1, i]][None]),
        "msl": torch.from_numpy(surf_vars_ds["msl"].values[[i - 1, i]][None]),
    },
    static_vars={
        # The static variables are constant, so we just get them for the first time.
        "z": torch.from_numpy(static_vars_ds["z"].values[0]),
        "slt": torch.from_numpy(static_vars_ds["slt"].values[0]),
        "lsm": torch.from_numpy(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": torch.from_numpy(atmos_vars_ds["t"].values[[i - 1, i]][None]),
        "u": torch.from_numpy(atmos_vars_ds["u"].values[[i - 1, i]][None]),
        "v": torch.from_numpy(atmos_vars_ds["v"].values[[i - 1, i]][None]),
        "q": torch.from_numpy(atmos_vars_ds["q"].values[[i - 1, i]][None]),
        "z": torch.from_numpy(atmos_vars_ds["z"].values[[i - 1, i]][None]),
    },
    metadata=Metadata(
        lat=torch.from_numpy(surf_vars_ds.latitude.values),
        lon=torch.from_numpy(surf_vars_ds.longitude.values),
        # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
        # `datetime.datetime`s. Note that this needs to be a tuple of length one:
        # one value for every batch element.
        time=(surf_vars_ds.valid_time.values.astype("datetime64[s]").tolist()[i],),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values)
    )
)

In [9]:
# torch.nn.functional.pad(torch.from_numpy(atmos_vars_ds["t"].values[[i - 1, i]][None]), p2d, "constant", 0).shape

In [10]:
# len(surf_vars_ds.latitude.values)

In [11]:
# surf_vars_ds.longitude.values

In [12]:
# np.append(surf_vars_ds.longitude.values, np.array([34.25, 34.5, 34.75]))

In [13]:
# np.append(surf_vars_ds.latitude.values, np.array([-36.25, -36.5, -36.75]))

In [14]:
# surf_vars_ds.latitude.values

In [15]:
# batch.atmos_vars['t'].shape

In [16]:
# 49%(49)

# Running the model

In [14]:
# batch.surf_vars["2t"][:, 2, :, :]

In [15]:
model.eval()
model = model.to("cuda")
# model.patch_size = 5
with torch.inference_mode():
    preds = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]

model = model.to("cpu")

In [8]:
batch.surf_vars["2t"].shape

torch.Size([1, 2, 64, 80])

In [9]:
preds[1].surf_vars["2t"].squeeze().shape

torch.Size([64, 80])

In [18]:
# batch.metadata.time

# Functions

In [20]:
def compute_rmse(predict_data, actual_data, var_type:str,
                     pred_var_name:str, actual_var_name:str
                     , atmos_level_idx=0):


    for i in range(2):
        pred = predict_data[i]
        if var_type=="surface":
            prediction = pred.surf_vars[pred_var_name][0, 0].numpy()
            actual = actual_data[actual_var_name][1 + i].values
            rmse = root_mean_squared_error(actual.squeeze(), prediction.squeeze() )
            print(f"RMSE: {rmse}")
        elif var_type=="atmosphere":
            prediction = pred.atmos_vars[actual_var_name].squeeze()[atmos_level_idx,:,:].numpy().squeeze()
            actual = actual_data[actual_var_name][i,atmos_level_idx,:,:].values
            # print(prediction.shape, actual.shape)
            rmse = root_mean_squared_error(actual.squeeze(), prediction.squeeze() )
            print(f"RMSE: {rmse}")
        

In [19]:
# model.patch_size

# Plot prediction

## Surface variables

### Two-meter temperature in K: 2t

In [11]:
fig, ax = plt.subplots(2, 2, figsize=(10, 6.5), dpi=500)

for i in range(ax.shape[0]):
    pred = preds[i]
    prediction = pred.surf_vars["2t"][0, 0].numpy()
    actual = surf_vars_ds["t2m"][1 + i].values
    rmse = root_mean_squared_error(actual.squeeze(), prediction.squeeze() )
    print(f"RMSE: {rmse}")

    ax[i, 0].imshow(pred.surf_vars["2t"][0, 0].numpy() - 273.15, vmin=-50, vmax=50)
    ax[i, 0].set_ylabel(str(pred.metadata.time[0]))
    if i == 0:
        ax[i, 0].set_title("Aurora Prediction")
    ax[i, 0].set_xticks([])
    ax[i, 0].set_yticks([])

    ax[i, 1].imshow(surf_vars_ds["t2m"][1 + i].values - 273.15, vmin=-50, vmax=50)
    if i == 0:
        ax[i, 1].set_title("ERA5")
    ax[i, 1].set_xticks([])
    ax[i, 1].set_yticks([])
plt.tight_layout()
plt.savefig("aurora-prediction-sa/aurora-prediction-t2m-sa.pdf")
plt.savefig("aurora-prediction-sa/aurora-prediction-t2m-sa.png")

### Ten-meter eastward wind speed in m/s :U10

In [33]:
def plot_predictions(predict_data, actual_data, var_type:str,
                     pred_var_name:str, actual_var_name:str, 
                     nrows:int=2, ncols:int=2,
                     figsize=(10, 6.5), atmos_level_idx=0, 
                     save_path="aurora-prediction-sa"):


    fig, ax = plt.subplots(nrows, ncols, figsize=figsize, dpi=500)

    for i in range(ax.shape[0]):
        pred = predict_data[i]
        if var_type=="surface":
            ax[i, 0].imshow(pred.surf_vars[pred_var_name][0, 0].numpy())
            actual = actual_data[actual_var_name][1 + i].values
            prediction = pred.surf_vars[pred_var_name][0, 0].numpy()
            print(f"RMSE: {root_mean_squared_error(actual.flatten(), prediction.flatten()) }")
            
        elif var_type=="atmosphere":
            ax[i, 0].imshow(pred.atmos_vars[actual_var_name].squeeze()[atmos_level_idx,:,:].numpy())
            actual = actual_data[actual_var_name][i,atmos_level_idx,:,:].values
            prediction = pred.atmos_vars[actual_var_name].squeeze()[atmos_level_idx,:,:]
            # print(actual.shape, prediction.shape)
            print(f"RMSE: {root_mean_squared_error(actual.flatten(), prediction.flatten()) }")
        elif var_type=="static":
            ax[i, 0].imshow(pred.static_vars[pred_var_name].numpy())
            
        ax[i, 0].set_ylabel(str(pred.metadata.time[0]))
        if i == 0:
            ax[i, 0].set_title("Aurora Prediction")
        ax[i, 0].set_xticks([])
        ax[i, 0].set_yticks([])
        if var_type=="atmosphere":
            ax[i, 1].imshow(actual_data[actual_var_name][i,atmos_level_idx,:,:].values)
        elif var_type=="static":
            ax[i, 1].imshow(actual_data[actual_var_name].squeeze().values)
        else:
            ax[i, 1].imshow(actual_data[actual_var_name][1 + i].values)
        if i == 0:
            ax[i, 1].set_title("ERA5")
        ax[i, 1].set_xticks([])
        ax[i, 1].set_yticks([])

    plt.tight_layout()#### 925 hPa
    plt.savefig(f"{save_path}/aurora-prediction-{actual_var_name}-sa.pdf")
    plt.savefig(f"{save_path}/aurora-prediction-{actual_var_name}-sa.png")
    plt.show()

In [None]:
plot_predictions(predict_data=preds, actual_data=surf_vars_ds,
                 var_type="surface",pred_var_name="10u",
                 actual_var_name="u10")

###  Ten-meter southward wind speed in m/s: V10

In [None]:
plot_predictions(predict_data=preds, actual_data=surf_vars_ds,
                 var_type="surface",pred_var_name="10v",
                 actual_var_name="v10")

### Mean sea-level pressure in Pa :msl

In [36]:
plot_predictions(predict_data=preds, actual_data=surf_vars_ds,
                 var_type="surface",pred_var_name="msl",
                 actual_var_name="msl")

## Atmosphere

### Temperature in K : t

In [37]:
# atmos_vars_ds["t"][0, 0,:,:].shape

#### 50 hPa

In [None]:
plot_predictions(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t", atmos_level_idx=0)

In [21]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=0)

RMSE: 4.578186988830566
RMSE: 3.858746290206909


#### 100 hPa

In [None]:
plot_predictions(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t", atmos_level_idx=1)

In [22]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=1)

RMSE: 4.224157810211182
RMSE: 3.837404251098633


#### 150 hPa

In [35]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=2)

RMSE: 3.513018846511841
RMSE: 3.897298812866211


#### 200 hPa

In [36]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=3)

RMSE: 0.8073528409004211
RMSE: 1.0384626388549805


#### 250 hPa

In [37]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=4)

RMSE: 0.6138967275619507
RMSE: 1.0199878215789795


#### 300 hPa

In [38]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=5)

RMSE: 0.970721423625946
RMSE: 1.0191086530685425


#### 400 hPa

In [39]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=6)

RMSE: 1.0418033599853516
RMSE: 0.6263337135314941


#### 500 hPa

In [40]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=7)

RMSE: 1.070984125137329
RMSE: 0.7618333101272583


#### 600hPa

In [41]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=8)

RMSE: 0.8289688229560852
RMSE: 0.844855785369873


#### 700 hPa

In [42]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=9)

RMSE: 0.5736856460571289
RMSE: 0.8279510736465454


#### 850 hPa

In [43]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=10)

RMSE: 1.3679838180541992
RMSE: 1.0744160413742065


#### 925 hPa

In [44]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=11)

RMSE: 1.2360966205596924
RMSE: 1.0463292598724365


#### 1000 hPa

In [45]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="t",
                 actual_var_name="t",  atmos_level_idx=12)

RMSE: 0.7319607734680176
RMSE: 0.9294496774673462


### Eastward wind speed in m/s u

In [None]:
plot_predictions(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u")

#### level 50 hPa

In [47]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=0)

RMSE: 2.026820421218872
RMSE: 2.3301734924316406


#### level 100 hPa

In [48]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=1)

RMSE: 2.6880998611450195
RMSE: 2.7293221950531006


#### level 150 hPa

In [49]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=2)

RMSE: 4.278347969055176
RMSE: 3.268280506134033


#### level 200 hPa

In [50]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=3)

RMSE: 2.7386679649353027
RMSE: 3.1816513538360596


#### level 250 hPa

In [51]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=4)

RMSE: 2.2701919078826904
RMSE: 3.1286909580230713


#### level 300 hPa

In [52]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=5)

RMSE: 3.122309923171997
RMSE: 2.145679473876953


#### level 400 hPa

In [53]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=6)

RMSE: 3.3605270385742188
RMSE: 2.5124969482421875


#### level 500 hPa

In [54]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=7)

RMSE: 6.871110439300537
RMSE: 5.696530342102051


#### level 600 hPa

In [55]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=8)

RMSE: 9.259191513061523
RMSE: 7.717942714691162


#### level 700 hPa

In [56]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=9)

RMSE: 10.849495887756348
RMSE: 7.0191168785095215


#### level 850 hPa

In [57]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=10)

RMSE: 7.683502197265625
RMSE: 7.577439785003662


#### level 925 hPa

In [58]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=11)

RMSE: 4.688336372375488
RMSE: 3.153499126434326


#### level 1000 hPa

In [59]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="u",
                 actual_var_name="u",  atmos_level_idx=12)

RMSE: 2.4588820934295654
RMSE: 3.00223970413208


### Southward wind speed in m/s :v

In [None]:
plot_predictions(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v")

#### level 50 hPa

In [61]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=0)

RMSE: 2.2030136585235596
RMSE: 3.0530760288238525


#### level 100 hPa

In [62]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=1)

RMSE: 2.530858278274536
RMSE: 3.167807102203369


#### level 150 hPa

In [63]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=2)

RMSE: 3.695732831954956
RMSE: 4.1963629722595215


#### level 200 hPa

In [64]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=3)

RMSE: 2.9125165939331055
RMSE: 1.9408111572265625


#### level 250 hPa

In [65]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=4)

RMSE: 3.3813419342041016
RMSE: 2.9111289978027344


#### level 300 hPa

In [66]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=5)

RMSE: 2.3951327800750732
RMSE: 3.9143214225769043


#### level 400 hPa

In [67]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=6)

RMSE: 3.322324275970459
RMSE: 4.003962516784668


#### level 500 hPa

In [68]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=7)

RMSE: 5.899149417877197
RMSE: 5.367013454437256


#### level 600 hPa

In [69]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=8)

RMSE: 6.873134613037109
RMSE: 6.3164191246032715


#### level 700 hPa

In [70]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=9)

RMSE: 5.314736366271973
RMSE: 5.466500282287598


#### level 850 hPa

In [71]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=10)

RMSE: 6.778872013092041
RMSE: 6.178730010986328


#### level 925 hPa

In [72]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=11)

RMSE: 3.4188194274902344
RMSE: 6.639578342437744


#### level 1000 hPa

In [73]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="v",
                 actual_var_name="v",  atmos_level_idx=12)

RMSE: 4.417582035064697
RMSE: 3.7341277599334717


### Specific humidity in kg / kg: q

In [None]:
plot_predictions(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q")

#### level 50 hPa

In [75]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=0)

RMSE: 0.00132470834068954
RMSE: 0.0018503069877624512


#### level 100 hPa

In [76]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=1)

RMSE: 0.001391159021295607
RMSE: 0.0018242687219753861


#### level 150 hPa

In [77]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=2)

RMSE: 0.0018014039378613234
RMSE: 0.0018360201502218843


#### level 200 hPa

In [78]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=3)

RMSE: 0.0012217375915497541
RMSE: 0.0015226168325170875


#### level 250 hPa

In [79]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=4)

RMSE: 0.0008543329313397408
RMSE: 0.0010457830503582954


#### level 300 hPa

In [80]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=5)

RMSE: 0.0007092472515068948
RMSE: 0.0007938714697957039


#### level 400 hPa

In [81]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=6)

RMSE: 0.0003647203557193279
RMSE: 0.00033323420211672783


#### level 500 hPa

In [82]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=7)

RMSE: 7.754356192890555e-05
RMSE: 8.30535645945929e-05


#### level 600 hPa

In [83]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=8)

RMSE: 2.7524869437911548e-05
RMSE: 3.936170833185315e-05


#### level 700 hPa

In [84]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=9)

RMSE: 6.695508091070224e-06
RMSE: 9.438476809009444e-06


#### level 850 hPa

In [85]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=10)

RMSE: 1.4007107438374078e-06
RMSE: 9.343611964141019e-07


#### level 925 hPa

In [86]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=11)

RMSE: 8.862902234341163e-08
RMSE: 7.743758345668539e-08


#### level 1000 hPa

In [87]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="q",
                 actual_var_name="q",  atmos_level_idx=12)

RMSE: 4.560165578482156e-08
RMSE: 2.137005949975901e-08


### Geopotential in m^2 / s^2 : z

In [42]:
plot_predictions(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z")

#### level 50 hPa

In [88]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=0)

RMSE: 181.72238159179688
RMSE: 227.6491241455078


#### level 100 hPa

In [89]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=1)

RMSE: 136.6709747314453
RMSE: 187.3612060546875


#### level 150 hPa

In [90]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=2)

RMSE: 74.1631851196289
RMSE: 112.76835632324219


#### level 200 hPa

In [91]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=3)

RMSE: 54.5391960144043
RMSE: 49.204734802246094


#### level 250 hPa

In [92]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=4)

RMSE: 61.86578369140625
RMSE: 50.085777282714844


#### level 300 hPa

In [93]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=5)

RMSE: 83.88148498535156
RMSE: 74.79078674316406


#### level 400 hPa

In [94]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=6)

RMSE: 121.8720932006836
RMSE: 105.0750503540039


#### level 500 hPa

In [95]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=7)

RMSE: 180.86172485351562
RMSE: 128.46400451660156


#### level 600 hPa

In [96]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=8)

RMSE: 216.04013061523438
RMSE: 143.4396514892578


#### level 700 hPa

In [97]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=9)

RMSE: 236.67269897460938
RMSE: 165.9544219970703


#### level 850 hPa

In [98]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=10)

RMSE: 182.22210693359375
RMSE: 136.3874053955078


#### level 925 hPa

In [99]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=11)

RMSE: 77.35247802734375
RMSE: 65.18022155761719


#### level 1000 hPa

In [100]:
compute_rmse(predict_data=preds, actual_data=atmos_vars_ds,
                 var_type="atmosphere",pred_var_name="z",
                 actual_var_name="z",  atmos_level_idx=12)

RMSE: 96.03746795654297
RMSE: 73.2583999633789


: 

In [43]:
# preds[0].atmos_vars["t"].shape

In [None]:
# preds[0].atmos_vars["t"].squeeze()[0,:,:].shape

In [47]:
# preds[0].static_vars["z"].shape

In [None]:
# static_vars_ds["z"].shape

###  Land-sea mask: lsm

In [None]:
plot_predictions(predict_data=preds, actual_data=static_vars_ds,
                 var_type="static",pred_var_name="lsm",
                 actual_var_name="lsm")

### Surface-level geopotential in m^2 / s^2:  z

In [50]:
plot_predictions(predict_data=preds, actual_data=static_vars_ds,
                 var_type="static",pred_var_name="z",
                 actual_var_name="z")

### Soil type: slt

In [51]:
plot_predictions(predict_data=preds, actual_data=static_vars_ds,
                 var_type="static",pred_var_name="slt",
                 actual_var_name="slt")

In [52]:
model

Aurora(
  (encoder): Perceiver3DEncoder(
    (surf_mlp): MLP(
      (net): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=256, bias=True)
        (3): Dropout(p=0.0, inplace=False)
      )
    )
    (surf_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (pos_embed): Linear(in_features=256, out_features=256, bias=True)
    (scale_embed): Linear(in_features=256, out_features=256, bias=True)
    (lead_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (absolute_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (atmos_levels_embed): Linear(in_features=256, out_features=256, bias=True)
    (surf_token_embeds): LevelPatchEmbed(
      (weights): ParameterDict(
          (10u): Parameter containing: [torch.FloatTensor of size 256x1x2x4x4]
          (10v): Parameter containing: [torch.FloatTensor of size 256x1x2x4x4]
    