# library

In [15]:
from pathlib import Path
import cdsapi
import xarray as xr
import torch
from aurora import Batch, Metadata, AuroraSmall, rollout
import numpy as np

## Download data

In [None]:
from pathlib import Path

import cdsapi

# Data will be downloaded here.
download_path = Path("data")

c = cdsapi.Client()

download_path = download_path.expanduser()
download_path.mkdir(parents=True, exist_ok=True)

# Download the static variables.
if not (download_path / "static.nc").exists():
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "geopotential",
                "land_sea_mask",
                "soil_type",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": "00:00",
            "format": "netcdf",
        },
        
        bbox_WSEN=(16.125441,-34.872482,33.769484,-21.703449),
        region_name='south_africa',
        dest_path=str(download_path / "static.nc"),
        time_steps=range(24)
    )
print("Static variables downloaded!")

# Download the surface-level variables.
if not (download_path / "2023-01-01-surface-level.nc").exists():
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "2m_temperature",
                "10m_u_component_of_wind",
                "10m_v_component_of_wind",
                "mean_sea_level_pressure",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": ["00:00", "06:00", "12:00", "18:00"],
            "format": "netcdf",
        },
        str(download_path / "2023-01-01-surface-level.nc"),
    )
print("Surface-level variables downloaded!")

# Download the atmospheric variables.
if not (download_path / "2023-01-01-atmospheric.nc").exists():
    c.retrieve(
        "reanalysis-era5-pressure-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "temperature",
                "u_component_of_wind",
                "v_component_of_wind",
                "specific_humidity",
                "geopotential",
            ],
            "pressure_level": [
                "50",
                "100",
                "150",
                "200",
                "250",
                "300",
                "400",
                "500",
                "600",
                "700",
                "850",
                "925",
                "1000",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": ["00:00", "06:00", "12:00", "18:00"],
            "format": "netcdf",
        },
        str(download_path / "2023-01-01-atmospheric.nc"),
    )
print("Atmospheric variables downloaded!")

In [14]:


# Define South Africa's geographical bounds
south_africa_bounds = {
    "north": -22.0,  # Northernmost latitude
    "south": -35.0,  # Southernmost latitude
    "west": 16.0,    # Westernmost longitude
    "east": 33.0     # Easternmost longitude
}
6.125441,-34.872482,33.769484,-21.703449
# Data will be downloaded here.
download_path = Path("data")
download_path.mkdir(parents=True, exist_ok=True)

c = cdsapi.Client()

# Download static variables
c.retrieve(
    "reanalysis-era5-single-levels",
    {
        "product_type": "reanalysis",
        "variable": [
            "geopotential",
            "land_sea_mask",
            "soil_type",
        ],
        "year": "2023",
        "month": "01",
        "day": "01",
        "time": "00:00",
        "format": "netcdf",
        "area": [
            south_africa_bounds["north"],
            south_africa_bounds["west"],
            south_africa_bounds["south"],
            south_africa_bounds["east"]
        ]
    },
    str(download_path / "static-south-africa.nc"),
)
print("Static variables downloaded for South Africa!")

# Download surface-level variables
c.retrieve(
    "reanalysis-era5-single-levels",
    {
        "product_type": "reanalysis",
        "variable": [
            "2m_temperature",
            "10m_u_component_of_wind",
            "10m_v_component_of_wind",
            "mean_sea_level_pressure",
        ],
        "year": "2023",
        "month": "01",
        "day": "01",
        "time": ["00:00", "06:00", "12:00", "18:00"],
        "format": "netcdf",
        "area": [
            south_africa_bounds["north"],
            south_africa_bounds["west"],
            south_africa_bounds["south"],
            south_africa_bounds["east"]
        ]
    },
    str(download_path / "surface-south-africa.nc"),
)
print("Surface-level variables downloaded for South Africa!")

# Download atmospheric variables
c.retrieve(
    "reanalysis-era5-pressure-levels",
    {
        "product_type": "reanalysis",
        "variable": [
            "temperature",
            "u_component_of_wind",
            "v_component_of_wind",
            "specific_humidity",
            "geopotential",
        ],
        "pressure_level": [
            "50", "100", "150", "200", "250", "300", "400", "500", "600", "700", "850", "925", "1000"
        ],
        "year": "2023",
        "month": "01",
        "day": "01",
        "time": ["00:00", "06:00", "12:00", "18:00"],
        "format": "netcdf",
        "area": [
            south_africa_bounds["north"],
            south_africa_bounds["west"],
            south_africa_bounds["south"],
            south_africa_bounds["east"]
        ]
    },
    str(download_path / "atmospheric-south-africa.nc"),
)
print("Atmospheric variables downloaded for South Africa!")

# Load and filter datasets
static_vars_ds = xr.open_dataset(download_path / "static-south-africa.nc", engine="netcdf4")
surf_vars_ds = xr.open_dataset(download_path / "surface-south-africa.nc", engine="netcdf4")
atmos_vars_ds = xr.open_dataset(download_path / "atmospheric-south-africa.nc", engine="netcdf4")

# Ensure dimensions match model's expected patch size
def pad_to_nearest_multiple(arr, multiple):
    shape = arr.shape
    pad_dims = [(0, 0)] * len(shape)
    for dim in [-2, -1]:  # Latitude and Longitude dimensions
        pad_size = (multiple - shape[dim] % multiple) % multiple
        pad_dims[dim] = (pad_size // 2, pad_size - pad_size // 2)  # Even padding
    return np.pad(arr, pad_dims, mode='constant', constant_values=np.nan)

surf_vars_ds["t2m"] = xr.DataArray(
    pad_to_nearest_multiple(surf_vars_ds["t2m"].values, 16),
    dims=surf_vars_ds["t2m"].dims,
    coords=surf_vars_ds["t2m"].coords
)
atmos_vars_ds["t"] = xr.DataArray(
    pad_to_nearest_multiple(atmos_vars_ds["t"].values, 16),
    dims=atmos_vars_ds["t"].dims,
    coords=atmos_vars_ds["t"].coords
)

# Prepare batch for Aurora model
i = 1  # Select time index
batch = Batch(
    surf_vars={
        "2t": torch.from_numpy(surf_vars_ds["t2m"].values[[i - 1, i]][None]),
        "10u": torch.from_numpy(surf_vars_ds["u10"].values[[i - 1, i]][None]),
        "10v": torch.from_numpy(surf_vars_ds["v10"].values[[i - 1, i]][None]),
        "msl": torch.from_numpy(surf_vars_ds["msl"].values[[i - 1, i]][None]),
    },
    static_vars={
        "z": torch.from_numpy(static_vars_ds["z"].values[0]),
        "slt": torch.from_numpy(static_vars_ds["slt"].values[0]),
        "lsm": torch.from_numpy(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": torch.from_numpy(atmos_vars_ds["t"].values[[i - 1, i]][None]),
        "u": torch.from_numpy(atmos_vars_ds["u"].values[[i - 1, i]][None]),
        "v": torch.from_numpy(atmos_vars_ds["v"].values[[i - 1, i]][None]),
        "q": torch.from_numpy(atmos_vars_ds["q"].values[[i - 1, i]][None]),
        "z": torch.from_numpy(atmos_vars_ds["z"].values[[i - 1, i]][None]),
    },
    metadata=Metadata(
        lat=torch.from_numpy(surf_vars_ds.latitude.values),
        lon=torch.from_numpy(surf_vars_ds.longitude.values),
        time=(surf_vars_ds.valid_time.values.astype("datetime64[s]").tolist()[i],),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
    ),
)


model = AuroraSmall(use_lora=False)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
model.eval()
with torch.inference_mode():
    preds = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]

model = model.to("cpu")


# In[17]:


import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2, figsize=(12, 6.5))

for i in range(ax.shape[0]):
    pred = preds[i]

    ax[i, 0].imshow(pred.surf_vars["2t"][0, 0].numpy() - 273.15, vmin=-50, vmax=50)
    ax[i, 0].set_ylabel(str(pred.metadata.time[0]))
    if i == 0:
        ax[i, 0].set_title("Aurora Prediction")
    ax[i, 0].set_xticks([])
    ax[i, 0].set_yticks([])

    ax[i, 1].imshow(surf_vars_ds["t2m"][2 + i].values - 273.15, vmin=-50, vmax=50)
    if i == 0:
        ax[i, 1].set_title("ERA5")
    ax[i, 1].set_xticks([])
    ax[i, 1].set_yticks([])

plt.tight_layout()







2025-02-05 09:56:41,703 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.


2025-02-05 09:56:41,952 INFO Request ID is 72220770-ec5d-4a8c-8cb5-ee09a151572d
2025-02-05 09:56:41,999 INFO status has been updated to accepted
2025-02-05 09:56:46,909 INFO status has been updated to running
2025-02-05 09:56:50,351 INFO status has been updated to successful
                                                                                        

Static variables downloaded for South Africa!


2025-02-05 09:56:51,371 INFO Request ID is 60fc300e-0257-4415-9662-d384cbeb874b
2025-02-05 09:56:51,408 INFO status has been updated to accepted
2025-02-05 09:56:56,295 INFO status has been updated to running
2025-02-05 09:57:04,806 INFO status has been updated to successful
2025-02-05 09:57:05,347 INFO Request ID is c45a6168-70f6-4854-b046-5a51e570ecf6        


Surface-level variables downloaded for South Africa!


2025-02-05 09:57:05,383 INFO status has been updated to accepted
2025-02-05 09:57:13,709 INFO status has been updated to successful
                                                                                         

Atmospheric variables downloaded for South Africa!


ValueError: conflicting sizes for dimension 'latitude': length 64 on the data but length 53 on coordinate 'latitude'