In [8]:
import zipfile
from pathlib import Path

import cdsapi
from huggingface_hub import hf_hub_download

# Data will be downloaded here.
download_path = Path("~/downloads/cams")

download_path = download_path.expanduser()
download_path.mkdir(parents=True, exist_ok=True)

# Download the static variables from HuggingFace.
static_path = hf_hub_download(
    repo_id="microsoft/aurora",
    filename="aurora-0.4-air-pollution-static.pickle",
    local_dir="~/downloads/cams"
)
print("Static variables downloaded!")

# Download the surface-level variables.
if not (download_path / "2025-10-20_cams_24h_forecast_utah.nc.zip").exists():
    c = cdsapi.Client()
    c.retrieve(
        "cams-global-atmospheric-composition-forecasts",
        {
            "type": "forecast",
            "format": "netcdf_zip",
            "date": "2025-10-20",
            "time": ["00:00"],  # ONE forecast initialization
            "leadtime_hour": [str(h) for h in range(0, 25)],  # 0–24h forecast
            "area": [42.1, -112.0, 37.0, -107.0],  # North, West, South, East (Utah)
            "variable": [
                "10m_u_component_of_wind", "10m_v_component_of_wind",
                "2m_temperature", "mean_sea_level_pressure",
                "particulate_matter_1um", "particulate_matter_2.5um",
                "particulate_matter_10um",
                "total_column_carbon_monoxide", "total_column_nitrogen_monoxide",
                "total_column_nitrogen_dioxide", "total_column_ozone",
                "total_column_sulphur_dioxide",
                "u_component_of_wind", "v_component_of_wind",
                "temperature", "geopotential", "specific_humidity",
                "carbon_monoxide", "nitrogen_dioxide", "nitrogen_monoxide",
                "ozone", "sulphur_dioxide",
            ],
            "pressure_level": [
                "50","100","150","200","250","300","400",
                "500","600","700","850","925","1000",
            ],
        },
        str(download_path / "2025-10-20_cams_24h_forecast_utah.nc.zip"),
    )
# Unpack the ZIP. It should contain the surface-level and atmospheric data in separate
# files.
if not (download_path / "2025-10-20-cams_24h_forecast_utah-surface-level.nc").exists():
    with zipfile.ZipFile(download_path / "2025-10-20_cams_24h_forecast_utah.nc.zip", "r") as zf, open(
        download_path / "2025-10-20-cams_24h_forecast_utah-surface-level.nc", "wb"
    ) as f:
        f.write(zf.read("data_sfc.nc"))
if not (download_path / "2025-10-20-cams_24h_forecast_utah-atmospheric.nc").exists():
    with zipfile.ZipFile(download_path / "2025-10-20_cams_24h_forecast_utah.nc.zip", "r") as zf, open(
        download_path / "2025-10-20-cams_24h_forecast_utah-atmospheric.nc", "wb"
    ) as f:
        f.write(zf.read("data_plev.nc"))
print("Surface-level and atmospheric variables downloaded!")

Static variables downloaded!
Surface-level and atmospheric variables downloaded!


In [17]:
import gc
import pickle
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import torch
import xarray as xr

from aurora import Batch, Metadata, AuroraAirPollution

warnings.filterwarnings("ignore", category=DeprecationWarning)

# ------------------ paths ------------------
download_path = Path("~/downloads/cams").expanduser()
surf_path = download_path / "2025-10-20-cams_24h_forecast_utah-surface-level.nc"
atmo_path = download_path / "2025-10-20-cams_24h_forecast_utah-atmospheric.nc"
static_path = download_path / "aurora-0.4-air-pollution-static.pickle"

out_dir = Path("./predictions").expanduser()
out_dir.mkdir(parents=True, exist_ok=True)
csv_path = out_dir / "aurora_hourly_utah_2025-10-20.csv"

# ------------------ helpers ------------------
def normalize_lat_lon(ds, lat_name="latitude", lon_name="longitude"):
    # lon -> [0,360), lon increasing; lat decreasing (north->south)
    if lon_name not in ds.coords or lat_name not in ds.coords:
        raise RuntimeError(f"Dataset missing coords '{lat_name}' and/or '{lon_name}'.")
    lon_wrapped = (ds[lon_name] % 360)
    ds = ds.assign_coords({lon_name: lon_wrapped})
    ds = ds.sortby(lon_name, ascending=True)
    lat = ds[lat_name]
    if lat.size >= 2 and float(lat[0]) < float(lat[-1]):
        ds = ds.sortby(lat_name, ascending=False)
    return ds

def grid_spacing_1d(arr_1d, default=0.5):
    if arr_1d.size >= 2:
        return abs(float(arr_1d[-1]) - float(arr_1d[-2]))
    return default

def pad_to_patch_multiple_1dcoords(ds, patch_size, lat_name="latitude", lon_name="longitude"):
    """
    Pad by edge-replication so H and W are multiples of patch_size and >= patch_size.
    Adjust 1-D lat/lon coordinates consistently.
    """
    H = ds.sizes[lat_name]
    W = ds.sizes[lon_name]
    target_H = max(patch_size, int(np.ceil(H / patch_size)) * patch_size)
    target_W = max(patch_size, int(np.ceil(W / patch_size)) * patch_size)
    pad_h = max(0, target_H - H)
    pad_w = max(0, target_W - W)
    if pad_h == 0 and pad_w == 0:
        return ds

    lat0 = ds[lat_name].values
    lon0 = ds[lon_name].values
    dlat = grid_spacing_1d(lat0)
    dlon = grid_spacing_1d(lon0)

    # pad data
    ds = ds.pad({lat_name: (0, pad_h), lon_name: (0, pad_w)}, mode="edge")

    # extend coords
    if pad_h > 0:
        lat_tail = np.array([lat0[-1] - dlat * (i + 1) for i in range(pad_h)], dtype=np.float64)
        new_lat = np.concatenate([lat0, lat_tail])
    else:
        new_lat = lat0

    if pad_w > 0:
        lon_tail = np.array([lon0[-1] + dlon * (i + 1) for i in range(pad_w)], dtype=np.float64)
        lon_tail = np.mod(lon_tail, 360.0)
        new_lon = np.concatenate([lon0, lon_tail])
    else:
        new_lon = lon0

    ds = ds.assign_coords({lat_name: (lat_name, new_lat), lon_name: (lon_name, new_lon)})
    ds = ds.sortby(lon_name, ascending=True)
    ds = ds.sortby(lat_name, ascending=False)
    return ds

def resize_nn(arr2d: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2D array to (out_h, out_w) without external deps."""
    in_h, in_w = arr2d.shape
    y_idx = np.floor(np.linspace(0, in_h - 1, out_h)).astype(int)
    x_idx = np.floor(np.linspace(0, in_w - 1, out_w)).astype(int)
    return arr2d[y_idx][:, x_idx]

def build_static_tile_resized(static_vars_full: dict, target_shape: tuple[int, int]) -> dict:
    """
    Resize each 2D static layer to target_shape via nearest neighbour.
    Ignores non-2D entries. Returns dict of np.float32 arrays (H, W).
    """
    Ht, Wt = target_shape
    out = {}
    for k, v in static_vars_full.items():
        arr = np.array(v)
        if arr.ndim != 2:
            continue
        out[k] = resize_nn(arr, Ht, Wt).astype(np.float32)
    if not out:
        raise RuntimeError("No 2D static layers found in the pickle.")
    return out

def get_pred_surf(pred, keys):
    if isinstance(keys, (list, tuple)):
        for k in keys:
            if k in pred.surf_vars:
                return pred.surf_vars[k][0, 0].cpu().numpy()
        return None
    else:
        k = keys
        if k in pred.surf_vars:
            return pred.surf_vars[k][0, 0].cpu().numpy()
        return None

# ------------------ load static ------------------
with open(static_path, "rb") as f:
    static_vars_full = pickle.load(f)

# ------------------ open CAMS ------------------
surf_all = xr.open_dataset(surf_path, engine="netcdf4", decode_timedelta=True)
atmo_all = xr.open_dataset(atmo_path, engine="netcdf4", decode_timedelta=True)

# Normalize coords
surf_all = normalize_lat_lon(surf_all, "latitude", "longitude")
atmo_all = normalize_lat_lon(atmo_all, "latitude", "longitude")

# ------------------ model ------------------
model = AuroraAirPollution()
model.load_checkpoint("microsoft/aurora", "aurora-0.4-air-pollution.ckpt")
model.eval()

ps = getattr(model, "patch_size", 16)
print("patch_size:", ps)

# Pad to multiples of patch size (and ≥ patch size)
surf_all = pad_to_patch_multiple_1dcoords(surf_all, ps, "latitude", "longitude")
atmo_all = pad_to_patch_multiple_1dcoords(atmo_all, ps, "latitude", "longitude")

# Lead-time coordinate
lead_name = "forecast_period" if "forecast_period" in surf_all.coords else ("step" if "step" in surf_all.coords else None)
if lead_name is None:
    raise RuntimeError("Could not find forecast lead coordinate ('forecast_period' or 'step').")
num_periods = int(surf_all.sizes[lead_name])
print(f"Lead coord: {lead_name} | periods: {num_periods}")

# ----- FINAL GRID SHAPE (1-D coords) -----
lat_1d = atmo_all.latitude.values  # shape (H,)
lon_1d = atmo_all.longitude.values # shape (W,)
assert lat_1d.ndim == 1 and lon_1d.ndim == 1
assert np.all((lon_1d >= 0) & (lon_1d < 360)), "Aurora expects lon in [0, 360)."

H, W = lat_1d.shape[0], lon_1d.shape[0]
print(f"Final grid size (H x W): {H} x {W}")

# Build 2-D lat/lon grids for Metadata + CSV
lat_2d, lon_2d = np.meshgrid(lat_1d, lon_1d, indexing="ij")  # (H,W) each

# Resize static layers ONCE to (H,W)
static_tile = build_static_tile_resized(static_vars_full, (H, W))
static_tile_torch = {k: torch.from_numpy(v) for k, v in static_tile.items()}

# ------------------ inference loop (T=2 history) ------------------
rows = []

# We need a previous slice and a current slice -> start from period=1
for period in range(1, num_periods):
    surf_prev  = surf_all.isel({lead_name: period - 1})
    surf_curr  = surf_all.isel({lead_name: period})
    atmos_prev = atmo_all.isel({lead_name: period - 1})
    atmos_curr = atmo_all.isel({lead_name: period})

    # times (prev, curr)
    t_prev = np.asarray(atmos_prev.valid_time.values, dtype="datetime64[s]").item()
    t_curr = np.asarray(atmos_curr.valid_time.values, dtype="datetime64[s]").item()

    # Stack T=2 along the time axis -> final shape (B=1, T=2, H, W)
    def T2(x_prev, x_curr, var):
        a = x_prev[var].values
        b = x_curr[var].values
        return torch.from_numpy(np.stack([a, b], axis=0)[None])

    surf_vars = {
        "2t":    T2(surf_prev,  surf_curr,  "t2m"),
        "10u":   T2(surf_prev,  surf_curr,  "u10"),
        "10v":   T2(surf_prev,  surf_curr,  "v10"),
        "msl":   T2(surf_prev,  surf_curr,  "msl"),
        "pm1":   T2(surf_prev,  surf_curr,  "pm1"),
        "pm2p5": T2(surf_prev,  surf_curr,  "pm2p5"),
        "pm10":  T2(surf_prev,  surf_curr,  "pm10"),
        "tcco":  T2(surf_prev,  surf_curr,  "tcco"),
        "tc_no": T2(surf_prev,  surf_curr,  "tc_no"),
        "tcno2": T2(surf_prev,  surf_curr,  "tcno2"),
        "gtco3": T2(surf_prev,  surf_curr,  "gtco3"),
        "tcso2": T2(surf_prev,  surf_curr,  "tcso2"),
    }

    atmos_vars = {
        "t":   T2(atmos_prev, atmos_curr, "t"),
        "u":   T2(atmos_prev, atmos_curr, "u"),
        "v":   T2(atmos_prev, atmos_curr, "v"),
        "q":   T2(atmos_prev, atmos_curr, "q"),
        "z":   T2(atmos_prev, atmos_curr, "z"),
        "co":  T2(atmos_prev, atmos_curr, "co"),
        "no":  T2(atmos_prev, atmos_curr, "no"),
        "no2": T2(atmos_prev, atmos_curr, "no2"),
        "go3": T2(atmos_prev, atmos_curr, "go3"),
        "so2": T2(atmos_prev, atmos_curr, "so2"),
    }

    batch = Batch(
        surf_vars=surf_vars,
        static_vars=static_tile_torch,  # (H,W) each; Batch will broadcast as needed
        atmos_vars=atmos_vars,
        metadata=Metadata(
            lat=torch.from_numpy(lat_2d),   # (H,W)
            lon=torch.from_numpy(lon_2d),   # (H,W)
            time=(t_prev, t_curr),          # T=2 history
            atmos_levels=tuple(int(level) for level in atmos_curr.pressure_level.values),
        ),
    )

    with torch.inference_mode():
        pred = model(batch)

    # Take the **current** step surface predictions (index 0, current history index is 1 internally)
    def take_curr(arr):
        # pred.surf_vars[...] has shape (B, T_out, H, W); for single-step it's (1,1,H,W)
        return arr[0, 0].cpu().numpy()

    pm1  = take_curr(pred.surf_vars["pm1"])
    pm25 = take_curr(pred.surf_vars["pm2p5"])
    pm10 = take_curr(pred.surf_vars["pm10"])
    co   = take_curr(pred.surf_vars["co"])
    no   = take_curr(pred.surf_vars["no"])
    no2  = take_curr(pred.surf_vars["no2"])
    o3   = take_curr(pred.surf_vars.get("o3", pred.surf_vars.get("go3")))
    so2  = take_curr(pred.surf_vars["so2"])

    # Convert PM kg/m^3 -> µg/m^3
    pm1  = pm1  * 1e9
    pm25 = pm25 * 1e9
    pm10 = pm10 * 1e9

    # Flatten -> rows (use our 2-D grids)
    size = H * W
    row_df = pd.DataFrame({
        "timestamp": np.repeat(np.datetime64(t_curr), size),
        "lat": lat_2d.ravel(),
        "lon": lon_2d.ravel(),
        "pm1_ugm3":   pm1.ravel(),
        "pm2p5_ugm3": pm25.ravel(),
        "pm10_ugm3":  pm10.ravel(),
        "co":         co.ravel(),
        "no":         no.ravel(),
        "no2":        no2.ravel(),
        "o3":         o3.ravel(),
        "so2":        so2.ravel(),
    })
    rows.append(row_df)

    # cleanup
    del batch, pred, row_df, pm1, pm25, pm10, co, no, no2, o3, so2
    gc.collect()

# ------------------ write CSV ------------------
df = pd.concat(rows, ignore_index=True)
df.to_csv(csv_path, index=False)
print(f"Saved: {csv_path}  (rows={len(df)})")


patch_size: 3
Lead coord: forecast_period | periods: 25
Final grid size (H x W): 15 x 15


ValueError: too many values to unpack (expected 6)

In [6]:
import zipfile
from pathlib import Path

import cdsapi
from huggingface_hub import hf_hub_download

# Data will be downloaded here.
download_path = Path("~/downloads/cams")

download_path = download_path.expanduser()
download_path.mkdir(parents=True, exist_ok=True)

# Download the static variables from HuggingFace.
static_path = hf_hub_download(
    repo_id="microsoft/aurora",
    filename="aurora-0.4-air-pollution-static.pickle",
    local_dir="~/downloads/cams"
)
print("Static variables downloaded!")

# Download the surface-level variables.
if not (download_path / "2025-10-20-cams.nc.zip").exists():
    c = cdsapi.Client()
    c.retrieve(
        "cams-global-atmospheric-composition-forecasts",
        {
            "type": "forecast",
            "leadtime_hour": "0",
            "variable": [
                # Meteorological surface-level variables:
                "10m_u_component_of_wind",
                "10m_v_component_of_wind",
                "2m_temperature",
                "mean_sea_level_pressure",
                # Pollution surface-level variables:
                "particulate_matter_1um",
                "particulate_matter_2.5um",
                "particulate_matter_10um",
                "total_column_carbon_monoxide",
                "total_column_nitrogen_monoxide",
                "total_column_nitrogen_dioxide",
                "total_column_ozone",
                "total_column_sulphur_dioxide",
                # Meteorological atmospheric variables:
                "u_component_of_wind",
                "v_component_of_wind",
                "temperature",
                "geopotential",
                "specific_humidity",
                # Pollution atmospheric variables:
                "carbon_monoxide",
                "nitrogen_dioxide",
                "nitrogen_monoxide",
                "ozone",
                "sulphur_dioxide",
            ],
            "pressure_level": [
                "50",
                "100",
                "150",
                "200",
                "250",
                "300",
                "400",
                "500",
                "600",
                "700",
                "850",
                "925",
                "1000",
            ],
            "date": "2025-10-20",
            "time": ["00:00", "12:00"],
            "format": "netcdf_zip",
        },
        str(download_path / "2025-10-20-cams.nc.zip"),
    )
# Unpack the ZIP. It should contain the surface-level and atmospheric data in separate
# files.
if not (download_path / "2025-10-20-cams-surface-level.nc").exists():
    with zipfile.ZipFile(download_path / "2025-10-20-cams.nc.zip", "r") as zf, open(
        download_path / "2025-10-20-cams-surface-level.nc", "wb"
    ) as f:
        f.write(zf.read("data_sfc.nc"))
if not (download_path / "2025-10-20-cams-atmospheric.nc").exists():
    with zipfile.ZipFile(download_path / "2025-10-20-cams.nc.zip", "r") as zf, open(
        download_path / "2025-10-20-cams-atmospheric.nc", "wb"
    ) as f:
        f.write(zf.read("data_plev.nc"))
print("Surface-level and atmospheric variables downloaded!")

Static variables downloaded!


2025-10-28 11:19:09,439 INFO Request ID is 3c5c3810-3da7-435d-b9ad-3fdef37fa322
2025-10-28 11:19:09,615 INFO status has been updated to accepted
2025-10-28 11:19:23,677 INFO status has been updated to running
2025-10-28 11:20:26,068 INFO status has been updated to successful
                                                                                         

Surface-level and atmospheric variables downloaded!


In [None]:
#####################################

# ALL DATA Create UTAH SLICE 
# PM2.5 from CAMS via Microsoft Aurora — Utah/SLC subset
# Robust single full-window inference with >= 8 * patch_size per dim
# - Ensures sizes are multiples of patch_size and lat↓ / lon↑
# - Returns index slices to cut static grids consistently
# - NO PLOT; CSV includes pm1, pm10, ozone, NO2, NO, CO, SO2 columns

#####################################

import gc, math, pickle
from pathlib import Path

import numpy as np
import torch
import xarray as xr
from huggingface_hub import hf_hub_download
import pandas as pd

from aurora import Batch, Metadata, AuroraAirPollution

# ------------------ config / paths ------------------
download_path = Path("~/downloads/cams").expanduser()
surf_path = download_path / "2025-10-20-cams-surface-level.nc"
atmo_path = download_path / "2025-10-20-cams-atmospheric.nc"
static_path = hf_hub_download("microsoft/aurora", "aurora-0.4-air-pollution-static.pickle")

out_dir = Path("./predictions")
out_dir.mkdir(parents=True, exist_ok=True)
csv_path = out_dir / "pollutant_prediction_utah.csv"

# ------------------ REGION BOUNDS (degrees) ------------------
# Use 0..360 longitudes 
UTAH_BBOX = dict(lat_min=37.0, lat_max=42.1, lon_min=245.9, lon_max=251.0)
SLC_BBOX  = dict(lat_min=40.4, lat_max=41.1, lon_min=247.7, lon_max=248.4)

# Choose region here:
BBOX = SLC_BBOX   # or UTAH_BBOX

# ------------------ helpers ------------------
def detect_lon_domain(ds_lon: np.ndarray) -> str:
    lon_min = float(np.nanmin(ds_lon)); lon_max = float(np.nanmax(ds_lon))
    return "0_360" if lon_min >= 0 and lon_max > 180 else "-180_180"

def to_dataset_lon(lon_vals, target_domain: str):
    """Map longitude(s) into the dataset's domain. Accepts scalar/array."""
    arr = np.asarray(lon_vals, dtype=float)
    if target_domain == "0_360":
        arr = arr % 360.0
        arr = np.where(arr < 0, arr + 360.0, arr)
    else:
        arr = ((arr + 180.0) % 360.0) - 180.0
    return float(arr) if np.ndim(lon_vals) == 0 else arr

def ensure_model_orientation(ds: xr.Dataset) -> xr.Dataset:
    """Ensure lon strictly increasing, lat strictly decreasing (Aurora requirement)."""
    out = ds
    if out.longitude.size >= 2 and not np.all(np.diff(out.longitude.values) > 0):
        out = out.isel(longitude=np.argsort(out.longitude.values))
    if out.latitude.size >= 2 and not np.all(np.diff(out.latitude.values) < 0):
        out = out.isel(latitude=slice(None, None, -1))
    if out.latitude.size >= 2:
        assert np.all(np.diff(out.latitude.values) < 0), "lat must be strictly decreasing"
    if out.longitude.size >= 2:
        assert np.all(np.diff(out.longitude.values) > 0), "lon must be strictly increasing"
    return out

def slice_bbox_value(ds: xr.Dataset, bbox: dict) -> xr.Dataset:
    """Value-based bbox slice (handles lon domain + any lat order)."""
    lat = ds.latitude; lon = ds.longitude
    lon_domain = detect_lon_domain(lon.values)
    lon_min_ds, lon_max_ds = to_dataset_lon([bbox["lon_min"], bbox["lon_max"]], lon_domain)
    lat_min, lat_max = bbox["lat_min"], bbox["lat_max"]

    if lat[0] > lat[-1]:
        lat_slice = slice(lat_max, lat_min)  # descending
    else:
        lat_slice = slice(lat_min, lat_max)  # ascending

    if lon_min_ds <= lon_max_ds:
        lon_slice = slice(lon_min_ds, lon_max_ds)
        out = ds.sel(latitude=lat_slice, longitude=lon_slice)
    else:
        left  = ds.sel(latitude=lat_slice, longitude=slice(lon_min_ds, float(lon.values.max())))
        right = ds.sel(latitude=lat_slice, longitude=slice(float(lon.values.min()), lon_max_ds))
        out = xr.concat([left, right], dim="longitude")

    if out.sizes.get("latitude", 0) == 0 or out.sizes.get("longitude", 0) == 0:
        raise RuntimeError("BBox slice returned empty selection. Check bounds & lon domain.")
    return out

def nearest_index(vec, value):
    return int(np.abs(vec - value).argmin())

def roundup_to_multiple(n, m):
    """Smallest multiple of m that is >= n."""
    return ((int(n) + m - 1) // m) * m

def expand_region_indices(base_lat, base_lon, center_lat_raw, center_lon_raw, target_h, target_w):
    """
    Compute (i_start, i_end, j_start, j_end) on the ORIGINAL grid
    around (center_lat_raw, center_lon_raw) for a target_h x target_w window.
    """
    latv = base_lat; lonv = base_lon
    dom = detect_lon_domain(lonv)
    center_lon = float(to_dataset_lon(center_lon_raw, dom))
    center_lat = float(center_lat_raw)

    i0 = nearest_index(latv, center_lat)
    j0 = nearest_index(lonv, center_lon)

    half_h = target_h // 2
    half_w = target_w // 2

    i_start = max(0, i0 - half_h)
    i_end   = min(latv.size, i0 + half_h + (target_h % 2 != 0))
    if i_end - i_start < target_h:
        deficit = target_h - (i_end - i_start)
        i_start = max(0, i_start - deficit//2)
        i_end   = min(latv.size, i_end + math.ceil(deficit/2))

    j_start = max(0, j0 - half_w)
    j_end   = min(lonv.size, j0 + half_w + (target_w % 2 != 0))
    if j_end - j_start < target_w:
        deficit = target_w - (j_end - j_start)
        j_start = max(0, j_start - deficit//2)
        j_end   = min(lonv.size, j_end + math.ceil(deficit/2))

    # clamp + ensure non-empty
    i_start = max(0, min(i_start, latv.size-1))
    i_end   = max(i_start+1, min(i_end,   latv.size))
    j_start = max(0, min(j_start, lonv.size-1))
    j_end   = max(j_start+1, min(j_end,   lonv.size))
    return i_start, i_end, j_start, j_end

def sizes_multiple_of_ps(ds: xr.Dataset, ps: int) -> bool:
    H = ds.sizes["latitude"]; W = ds.sizes["longitude"]
    return (H % ps == 0) and (W % ps == 0)

def crop_to_multiple_of_ps(ds: xr.Dataset, ps: int) -> xr.Dataset:
    """Crop dataset so H and W are multiples of patch size (keeps top-left corner)."""
    H = ds.sizes["latitude"]; W = ds.sizes["longitude"]
    Hc = (H // ps) * ps
    Wc = (W // ps) * ps
    return ds.isel(latitude=slice(0, Hc), longitude=slice(0, Wc))

def safe_time_tuple(ds):
    if "valid_time" in ds:
        vt = np.asarray(ds.valid_time.values)
        if vt.ndim == 0: vt = vt[None]
        return tuple(pd.to_datetime(vt).to_pydatetime())
    return tuple()

# ------------------ load static vars ------------------
with open(static_path, "rb") as f:
    static_vars_full = pickle.load(f)

# ------------------ open datasets ------------------
surf_base = xr.open_dataset(surf_path, engine="netcdf4", decode_timedelta=True)
atmo_base = xr.open_dataset(atmo_path, engine="netcdf4", decode_timedelta=True)
if "forecast_period" in surf_base.dims: surf_base = surf_base.isel(forecast_period=0)
if "forecast_period" in atmo_base.dims: atmo_base = atmo_base.isel(forecast_period=0)

# ------------------ model & patch size ------------------
model = AuroraAirPollution()
model.load_checkpoint()
model.eval()
ps = int(model.patch_size)
print("patch_size:", ps)

# Swin3D uses multiple spatial downsamplings; require >= 8 patches per dim
MIN_PATCHES_PER_DIM = 8
min_cells = MIN_PATCHES_PER_DIM * ps

# ------------------ prelim slice (diag) ------------------
surf_reg0 = slice_bbox_value(surf_base, BBOX)
atmo_reg0 = slice_bbox_value(atmo_base, BBOX)
H0, W0 = surf_reg0.sizes["latitude"], surf_reg0.sizes["longitude"]
print(f"Prelim slice size (HxW): {H0} x {W0}")

# ------------------ compute target sizes ------------------
target_H = max(min_cells, roundup_to_multiple(max(H0, min_cells), ps))
target_W = max(min_cells, roundup_to_multiple(max(W0, min_cells), ps))

# ------------------ build single window by indices (and keep slices) ----------
lat_center_raw = 0.5 * (BBOX["lat_min"] + BBOX["lat_max"])
lon_center_raw = 0.5 * (BBOX["lon_min"] + BBOX["lon_max"])

# index expansion on ORIGINAL grids (use surf_base as reference for both)
i_start, i_end, j_start, j_end = expand_region_indices(
    surf_base.latitude.values, surf_base.longitude.values,
    lat_center_raw, lon_center_raw,
    target_H, target_W
)

surf_win = surf_base.isel(latitude=slice(i_start, i_end), longitude=slice(j_start, j_end))
atmo_win = atmo_base.isel(latitude=slice(i_start, i_end), longitude=slice(j_start, j_end))

# enforce orientation
surf_win = ensure_model_orientation(surf_win)
atmo_win = ensure_model_orientation(atmo_win)

# crop to multiples of ps
surf_ds = crop_to_multiple_of_ps(surf_win, ps)
atmo_ds = crop_to_multiple_of_ps(atmo_win, ps)

H, W = surf_ds.sizes["latitude"], surf_ds.sizes["longitude"]
print(f"Window after crop (HxW): {H} x {W}")
assert H >= min_cells and W >= min_cells, f"Need >= {min_cells} cells per dim (got {H}x{W}, ps={ps})"
assert sizes_multiple_of_ps(surf_ds, ps) and sizes_multiple_of_ps(atmo_ds, ps), "Not multiples of patch size."

# ------------------ cut static grids to the same indices ------------------
def slice_static_like(i_start, i_end, j_start, j_end, orient_like: xr.Dataset, ps: int):
    cut = {}
    for k, arr in static_vars_full.items():
        if arr.ndim == 3:  # (C,H,W)
            arr_cut = arr[:, i_start:i_end, j_start:j_end]
        elif arr.ndim == 2:  # (H,W)
            arr_cut = arr[i_start:i_end, j_start:j_end]
        else:
            cut[k] = arr
            continue
        # mirror latitude flip done by ensure_model_orientation (if any)
        flipped_lat = not (np.all(np.diff(surf_win.latitude.values) < 0)) if surf_win.latitude.size >= 2 else False
        if flipped_lat:
            if arr_cut.ndim == 3:
                arr_cut = arr_cut[:, ::-1, :]
            else:
                arr_cut = arr_cut[::-1, :]
        # crop to multiples of ps
        if arr_cut.ndim == 3:
            Hc = (arr_cut.shape[1] // ps) * ps
            Wc = (arr_cut.shape[2] // ps) * ps
            arr_cut = arr_cut[:, :Hc, :Wc]
        elif arr_cut.ndim == 2:
            Hc = (arr_cut.shape[0] // ps) * ps
            Wc = (arr_cut.shape[1] // ps) * ps
            arr_cut = arr_cut[:Hc, :Wc]
        cut[k] = arr_cut
    return cut

static_vars_tile = slice_static_like(i_start, i_end, j_start, j_end, surf_ds, ps)

# ------------------ single full-window inference ------------------
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

surf_vars = {
    "2t":    torch.from_numpy(surf_ds["t2m" ].values[None]),
    "10u":   torch.from_numpy(surf_ds["u10" ].values[None]),
    "10v":   torch.from_numpy(surf_ds["v10" ].values[None]),
    "msl":   torch.from_numpy(surf_ds["msl" ].values[None]),
    "pm1":   torch.from_numpy(surf_ds["pm1" ].values[None]),
    "pm2p5": torch.from_numpy(surf_ds["pm2p5"].values[None]),
    "pm10":  torch.from_numpy(surf_ds["pm10"].values[None]),
    "tcco":  torch.from_numpy(surf_ds["tcco"].values[None]),
    "tc_no": torch.from_numpy(surf_ds["tc_no"].values[None]),
    "tcno2": torch.from_numpy(surf_ds["tcno2"].values[None]),
    "gtco3": torch.from_numpy(surf_ds["gtco3"].values[None]),
    "tcso2": torch.from_numpy(surf_ds["tcso2"].values[None]),
}
atmos_vars = {
    "t":   torch.from_numpy(atmo_ds["t" ].values[None]),
    "u":   torch.from_numpy(atmo_ds["u" ].values[None]),
    "v":   torch.from_numpy(atmo_ds["v" ].values[None]),
    "q":   torch.from_numpy(atmo_ds["q" ].values[None]),
    "z":   torch.from_numpy(atmo_ds["z"].values[None]),
    "co":  torch.from_numpy(atmo_ds["co"].values[None]),
    "no":  torch.from_numpy(atmo_ds["no"].values[None]),
    "no2": torch.from_numpy(atmo_ds["no2"].values[None]),
    "go3": torch.from_numpy(atmo_ds["go3"].values[None]),
    "so2": torch.from_numpy(atmo_ds["so2"].values[None]),
}

batch = Batch(
    surf_vars=surf_vars,
    static_vars={k: torch.from_numpy(v) for k, v in static_vars_tile.items()},
    atmos_vars=atmos_vars,
    metadata=Metadata(
        lat=torch.from_numpy(atmo_ds.latitude.values),   # strictly decreasing
        lon=torch.from_numpy(atmo_ds.longitude.values),  # strictly increasing
        time=safe_time_tuple(atmo_ds),
        atmos_levels=tuple(int(x) for x in atmo_ds.pressure_level.values),
    ),
)

with torch.inference_mode():
    pred = model(batch)

# PM2.5 prediction (kg/m^3 -> µg/m^3)
pm25_kgm3 = pred.surf_vars["pm2p5"][0, 0].cpu().numpy().astype(np.float32)
pm25_ugm3 = pm25_kgm3 / 1e-9

# ------------------ assemble CSV (no plotting) ------------------
# Consistent orientation for export: lat increasing for human readability
lat_vec = atmo_ds.latitude.values.copy()   # decreasing
lon_vec = atmo_ds.longitude.values.copy()  # increasing
lat_out = lat_vec.copy()
pm25_out = pm25_ugm3.copy()
flip_lat = lat_out[0] > lat_out[-1]
if flip_lat:
    lat_out = lat_out[::-1]
    pm25_out = pm25_out[::-1, :]

# Helper to extract + orient 2D fields like PM1/PM10 (surface) in µg/m^3
def surf_pm_as_ugm3(varname):
    arr = surf_ds[varname].values  # (time?, H, W) or (H,W)
    if arr.ndim == 3: arr = arr[0]
    arr = arr / 1e-9  # kg/m^3 -> µg/m^3
    if flip_lat: arr = arr[::-1, :]
    return arr

pm1_ugm3  = surf_pm_as_ugm3("pm1")
pm10_ugm3 = surf_pm_as_ugm3("pm10")

# Near-surface gases from the atmospheric cube:
# Pick the highest pressure level (closest to the surface)
plevs = atmo_ds.pressure_level.values
k_surface = int(np.argmax(plevs))  # highest pressure

def atmo_surface_field(name):
    v = atmo_ds[name].values  # (plev, H, W) or (time?,plev,H,W)
    if v.ndim == 4: v = v[0, k_surface]
    elif v.ndim == 3: v = v[k_surface]
    else: raise RuntimeError(f"Unexpected dims for {name}: {v.shape}")
    if flip_lat: v = v[::-1, :]
    return v

ozone = atmo_surface_field("go3")
nitrogen_dioxide = atmo_surface_field("no2")
nitrogen_monoxide = atmo_surface_field("no")
carbon_monoxide   = atmo_surface_field("co")
sulfur_dioxide    = atmo_surface_field("so2")

# Build 2D lon/lat grids
LON2D, LAT2D = np.meshgrid(lon_vec, lat_out)

df = pd.DataFrame({
    "latitude":           LAT2D.ravel(),
    "longitude":          LON2D.ravel(),
    "pm2p5_ugm3":         pm25_out.ravel(),
    "pm1_ugm3":           pm1_ugm3.ravel(),
    "pm10_ugm3":          pm10_ugm3.ravel(),
    "ozone":              ozone.ravel(),
    "nitrogen_dioxide":   nitrogen_dioxide.ravel(),
    "nitrogen_monoxide":  nitrogen_monoxide.ravel(),
    "carbon_monoxide":    carbon_monoxide.ravel(),
    "sulfur_dioxide":     sulfur_dioxide.ravel(),
})

# Timestamp (if present)
try:
    vt = np.asarray(atmo_ds.valid_time.values)
    if vt.ndim == 0: vt = vt[0:1]
    sel_time = pd.to_datetime(vt[0]).to_pydatetime()
except Exception:
    sel_time = None
if sel_time is not None:
    df.insert(0, "time", sel_time)

# Drop NaNs and save
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna().reset_index(drop=True)
df.to_csv(csv_path, index=False)
print(f"Saved CSV: {csv_path}")
print(f"Final working size (HxW): {pm25_kgm3.shape[0]} x {pm25_kgm3.shape[1]}  (ps={ps}, min required: {MIN_PATCHES_PER_DIM*ps})")
print(df.head(10))


patch_size: 3
Prelim slice size (HxW): 1 x 1
Window after crop (HxW): 24 x 24
Saved CSV: predictions/pollutant_prediction_utah.csv
Final working size (HxW): 24 x 24  (ps=3, min required: 24)
        time  latitude  longitude  pm2p5_ugm3  pm1_ugm3  pm10_ugm3  \
0 2025-10-20      36.4      243.2    6.144454  2.515953   4.962630   
1 2025-10-20      36.4      243.6    6.554677  3.670983   7.366880   
2 2025-10-20      36.4      244.0    7.013986  3.739067   7.631770   
3 2025-10-20      36.4      244.4    6.963577  3.815977   7.465333   
4 2025-10-20      36.4      244.8    7.344985  4.195292   7.300032   
5 2025-10-20      36.4      245.2    7.183073  3.435723   5.515376   
6 2025-10-20      36.4      245.6    6.622933  2.241485   3.401483   
7 2025-10-20      36.4      246.0    6.147038  1.801929   2.471069   
8 2025-10-20      36.4      246.4    5.931744  1.503060   2.060887   
9 2025-10-20      36.4      246.8    5.405423  1.307036   2.122733   

          ozone  nitrogen_dioxide  nit

In [16]:
#####################################

# UTAH DATA with LEAD TIME
# PM2.5 from CAMS via Microsoft Aurora — Utah/SLC subset
# Robust single full-window inference with >= 8 * patch_size per dim
# - Ensures sizes are multiples of patch_size and lat↓ / lon↑
# - Returns index slices to cut static grids consistently
# - NO PLOT; CSV includes pm1, pm10, ozone, NO2, NO, CO, SO2 columns

#####################################

import gc, math, pickle
from pathlib import Path

import numpy as np
import torch
import xarray as xr
from huggingface_hub import hf_hub_download
import pandas as pd

from aurora import Batch, Metadata, AuroraAirPollution

# ------------------ config / paths ------------------
download_path = Path("~/downloads/cams").expanduser()
surf_path = download_path / "2025-10-20-cams_24h_forecast_utah-surface-level.nc"
atmo_path = download_path / "2025-10-20-cams_24h_forecast_utah-atmospheric.nc"
static_path = hf_hub_download("microsoft/aurora", "aurora-0.4-air-pollution-static.pickle")

out_dir = Path("./predictions")
out_dir.mkdir(parents=True, exist_ok=True)
csv_path = out_dir / "pollutant_prediction_utah_hourly.csv"

# ------------------ REGION BOUNDS (degrees) ------------------
# Use 0..360 longitudes 
UTAH_BBOX = dict(lat_min=37.0, lat_max=42.1, lon_min=245.9, lon_max=251.0)
SLC_BBOX  = dict(lat_min=40.4, lat_max=41.1, lon_min=247.7, lon_max=248.4)

# Choose region here:
BBOX = SLC_BBOX   # or UTAH_BBOX

# ------------------ helpers ------------------
def detect_lon_domain(ds_lon: np.ndarray) -> str:
    lon_min = float(np.nanmin(ds_lon)); lon_max = float(np.nanmax(ds_lon))
    return "0_360" if lon_min >= 0 and lon_max > 180 else "-180_180"

def to_dataset_lon(lon_vals, target_domain: str):
    """Map longitude(s) into the dataset's domain. Accepts scalar/array."""
    arr = np.asarray(lon_vals, dtype=float)
    if target_domain == "0_360":
        arr = arr % 360.0
        arr = np.where(arr < 0, arr + 360.0, arr)
    else:
        arr = ((arr + 180.0) % 360.0) - 180.0
    return float(arr) if np.ndim(lon_vals) == 0 else arr

def ensure_model_orientation(ds: xr.Dataset) -> xr.Dataset:
    """Ensure lon strictly increasing, lat strictly decreasing (Aurora requirement)."""
    out = ds
    if out.longitude.size >= 2 and not np.all(np.diff(out.longitude.values) > 0):
        out = out.isel(longitude=np.argsort(out.longitude.values))
    if out.latitude.size >= 2 and not np.all(np.diff(out.latitude.values) < 0):
        out = out.isel(latitude=slice(None, None, -1))
    if out.latitude.size >= 2:
        assert np.all(np.diff(out.latitude.values) < 0), "lat must be strictly decreasing"
    if out.longitude.size >= 2:
        assert np.all(np.diff(out.longitude.values) > 0), "lon must be strictly increasing"
    return out

def slice_bbox_value(ds: xr.Dataset, bbox: dict) -> xr.Dataset:
    """Value-based bbox slice (handles lon domain + any lat order)."""
    lat = ds.latitude; lon = ds.longitude
    lon_domain = detect_lon_domain(lon.values)
    lon_min_ds, lon_max_ds = to_dataset_lon([bbox["lon_min"], bbox["lon_max"]], lon_domain)
    lat_min, lat_max = bbox["lat_min"], bbox["lat_max"]

    if lat[0] > lat[-1]:
        lat_slice = slice(lat_max, lat_min)  # descending
    else:
        lat_slice = slice(lat_min, lat_max)  # ascending

    if lon_min_ds <= lon_max_ds:
        lon_slice = slice(lon_min_ds, lon_max_ds)
        out = ds.sel(latitude=lat_slice, longitude=lon_slice)
    else:
        left  = ds.sel(latitude=lat_slice, longitude=slice(lon_min_ds, float(lon.values.max())))
        right = ds.sel(latitude=lat_slice, longitude=slice(float(lon.values.min()), lon_max_ds))
        out = xr.concat([left, right], dim="longitude")

    if out.sizes.get("latitude", 0) == 0 or out.sizes.get("longitude", 0) == 0:
        raise RuntimeError("BBox slice returned empty selection. Check bounds & lon domain.")
    return out

def nearest_index(vec, value):
    return int(np.abs(vec - value).argmin())

def roundup_to_multiple(n, m):
    """Smallest multiple of m that is >= n."""
    return ((int(n) + m - 1) // m) * m

def expand_region_indices(base_lat, base_lon, center_lat_raw, center_lon_raw, target_h, target_w):
    """
    Compute (i_start, i_end, j_start, j_end) on the ORIGINAL grid
    around (center_lat_raw, center_lon_raw) for a target_h x target_w window.
    """
    latv = base_lat; lonv = base_lon
    dom = detect_lon_domain(lonv)
    center_lon = float(to_dataset_lon(center_lon_raw, dom))
    center_lat = float(center_lat_raw)

    i0 = nearest_index(latv, center_lat)
    j0 = nearest_index(lonv, center_lon)

    half_h = target_h // 2
    half_w = target_w // 2

    i_start = max(0, i0 - half_h)
    i_end   = min(latv.size, i0 + half_h + (target_h % 2 != 0))
    if i_end - i_start < target_h:
        deficit = target_h - (i_end - i_start)
        i_start = max(0, i_start - deficit//2)
        i_end   = min(latv.size, i_end + math.ceil(deficit/2))

    j_start = max(0, j0 - half_w)
    j_end   = min(lonv.size, j0 + half_w + (target_w % 2 != 0))
    if j_end - j_start < target_w:
        deficit = target_w - (j_end - j_start)
        j_start = max(0, j_start - deficit//2)
        j_end   = min(lonv.size, j_end + math.ceil(deficit/2))

    # clamp + ensure non-empty
    i_start = max(0, min(i_start, latv.size-1))
    i_end   = max(i_start+1, min(i_end,   latv.size))
    j_start = max(0, min(j_start, lonv.size-1))
    j_end   = max(j_start+1, min(j_end,   lonv.size))
    return i_start, i_end, j_start, j_end

def sizes_multiple_of_ps(ds: xr.Dataset, ps: int) -> bool:
    H = ds.sizes["latitude"]; W = ds.sizes["longitude"]
    return (H % ps == 0) and (W % ps == 0)

def crop_to_multiple_of_ps(ds: xr.Dataset, ps: int) -> xr.Dataset:
    """Crop dataset so H and W are multiples of patch size (keeps top-left corner)."""
    H = ds.sizes["latitude"]; W = ds.sizes["longitude"]
    Hc = (H // ps) * ps
    Wc = (W // ps) * ps
    return ds.isel(latitude=slice(0, Hc), longitude=slice(0, Wc))

def safe_time_tuple(ds):
    if "valid_time" in ds:
        vt = np.asarray(ds.valid_time.values)
        if vt.ndim == 0: vt = vt[None]
        return tuple(pd.to_datetime(vt).to_pydatetime())
    return tuple()

# ------------------ load static vars ------------------
with open(static_path, "rb") as f:
    static_vars_full = pickle.load(f)

# ------------------ open datasets ------------------
surf_base = xr.open_dataset(surf_path, engine="netcdf4", decode_timedelta=True)
atmo_base = xr.open_dataset(atmo_path, engine="netcdf4", decode_timedelta=True)
if "forecast_period" in surf_base.dims: surf_base = surf_base.isel(forecast_period=0)
if "forecast_period" in atmo_base.dims: atmo_base = atmo_base.isel(forecast_period=0)

# ------------------ model & patch size ------------------
model = AuroraAirPollution()
model.load_checkpoint()
model.eval()
ps = int(model.patch_size)
print("patch_size:", ps)

# Swin3D uses multiple spatial downsamplings; require >= 8 patches per dim
MIN_PATCHES_PER_DIM = 8
min_cells = MIN_PATCHES_PER_DIM * ps

# ------------------ prelim slice (diag) ------------------
surf_reg0 = slice_bbox_value(surf_base, BBOX)
atmo_reg0 = slice_bbox_value(atmo_base, BBOX)
H0, W0 = surf_reg0.sizes["latitude"], surf_reg0.sizes["longitude"]
print(f"Prelim slice size (HxW): {H0} x {W0}")

# ------------------ compute target sizes ------------------
target_H = max(min_cells, roundup_to_multiple(max(H0, min_cells), ps))
target_W = max(min_cells, roundup_to_multiple(max(W0, min_cells), ps))

# ------------------ build single window by indices (and keep slices) ----------
lat_center_raw = 0.5 * (BBOX["lat_min"] + BBOX["lat_max"])
lon_center_raw = 0.5 * (BBOX["lon_min"] + BBOX["lon_max"])

# index expansion on ORIGINAL grids (use surf_base as reference for both)
i_start, i_end, j_start, j_end = expand_region_indices(
    surf_base.latitude.values, surf_base.longitude.values,
    lat_center_raw, lon_center_raw,
    target_H, target_W
)

surf_win = surf_base.isel(latitude=slice(i_start, i_end), longitude=slice(j_start, j_end))
atmo_win = atmo_base.isel(latitude=slice(i_start, i_end), longitude=slice(j_start, j_end))

# enforce orientation
surf_win = ensure_model_orientation(surf_win)
atmo_win = ensure_model_orientation(atmo_win)

# crop to multiples of ps
surf_ds = crop_to_multiple_of_ps(surf_win, ps)
atmo_ds = crop_to_multiple_of_ps(atmo_win, ps)

H, W = surf_ds.sizes["latitude"], surf_ds.sizes["longitude"]
print(f"Window after crop (HxW): {H} x {W}")
assert H >= min_cells and W >= min_cells, f"Need >= {min_cells} cells per dim (got {H}x{W}, ps={ps})"
assert sizes_multiple_of_ps(surf_ds, ps) and sizes_multiple_of_ps(atmo_ds, ps), "Not multiples of patch size."

# ------------------ cut static grids to the same indices ------------------
def slice_static_like(i_start, i_end, j_start, j_end, orient_like: xr.Dataset, ps: int):
    cut = {}
    for k, arr in static_vars_full.items():
        if arr.ndim == 3:  # (C,H,W)
            arr_cut = arr[:, i_start:i_end, j_start:j_end]
        elif arr.ndim == 2:  # (H,W)
            arr_cut = arr[i_start:i_end, j_start:j_end]
        else:
            cut[k] = arr
            continue
        # mirror latitude flip done by ensure_model_orientation (if any)
        flipped_lat = not (np.all(np.diff(surf_win.latitude.values) < 0)) if surf_win.latitude.size >= 2 else False
        if flipped_lat:
            if arr_cut.ndim == 3:
                arr_cut = arr_cut[:, ::-1, :]
            else:
                arr_cut = arr_cut[::-1, :]
        # crop to multiples of ps
        if arr_cut.ndim == 3:
            Hc = (arr_cut.shape[1] // ps) * ps
            Wc = (arr_cut.shape[2] // ps) * ps
            arr_cut = arr_cut[:, :Hc, :Wc]
        elif arr_cut.ndim == 2:
            Hc = (arr_cut.shape[0] // ps) * ps
            Wc = (arr_cut.shape[1] // ps) * ps
            arr_cut = arr_cut[:Hc, :Wc]
        cut[k] = arr_cut
    return cut

static_vars_tile = slice_static_like(i_start, i_end, j_start, j_end, surf_ds, ps)

# ------------------ single full-window inference ------------------
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

surf_vars = {
    "2t":    torch.from_numpy(surf_ds["t2m" ].values[None]),
    "10u":   torch.from_numpy(surf_ds["u10" ].values[None]),
    "10v":   torch.from_numpy(surf_ds["v10" ].values[None]),
    "msl":   torch.from_numpy(surf_ds["msl" ].values[None]),
    "pm1":   torch.from_numpy(surf_ds["pm1" ].values[None]),
    "pm2p5": torch.from_numpy(surf_ds["pm2p5"].values[None]),
    "pm10":  torch.from_numpy(surf_ds["pm10"].values[None]),
    "tcco":  torch.from_numpy(surf_ds["tcco"].values[None]),
    "tc_no": torch.from_numpy(surf_ds["tc_no"].values[None]),
    "tcno2": torch.from_numpy(surf_ds["tcno2"].values[None]),
    "gtco3": torch.from_numpy(surf_ds["gtco3"].values[None]),
    "tcso2": torch.from_numpy(surf_ds["tcso2"].values[None]),
}
atmos_vars = {
    "t":   torch.from_numpy(atmo_ds["t" ].values[None]),
    "u":   torch.from_numpy(atmo_ds["u" ].values[None]),
    "v":   torch.from_numpy(atmo_ds["v" ].values[None]),
    "q":   torch.from_numpy(atmo_ds["q" ].values[None]),
    "z":   torch.from_numpy(atmo_ds["z"].values[None]),
    "co":  torch.from_numpy(atmo_ds["co"].values[None]),
    "no":  torch.from_numpy(atmo_ds["no"].values[None]),
    "no2": torch.from_numpy(atmo_ds["no2"].values[None]),
    "go3": torch.from_numpy(atmo_ds["go3"].values[None]),
    "so2": torch.from_numpy(atmo_ds["so2"].values[None]),
}

batch = Batch(
    surf_vars=surf_vars,
    static_vars={k: torch.from_numpy(v) for k, v in static_vars_tile.items()},
    atmos_vars=atmos_vars,
    metadata=Metadata(
        lat=torch.from_numpy(atmo_ds.latitude.values),   # strictly decreasing
        lon=torch.from_numpy(atmo_ds.longitude.values),  # strictly increasing
        time=safe_time_tuple(atmo_ds),
        atmos_levels=tuple(int(x) for x in atmo_ds.pressure_level.values),
    ),
)

with torch.inference_mode():
    pred = model(batch)

# PM2.5 prediction (kg/m^3 -> µg/m^3)
pm25_kgm3 = pred.surf_vars["pm2p5"][0, 0].cpu().numpy().astype(np.float32)
pm25_ugm3 = pm25_kgm3 / 1e-9

# ------------------ assemble CSV (no plotting) ------------------
# Consistent orientation for export: lat increasing for human readability
lat_vec = atmo_ds.latitude.values.copy()   # decreasing
lon_vec = atmo_ds.longitude.values.copy()  # increasing
lat_out = lat_vec.copy()
pm25_out = pm25_ugm3.copy()
flip_lat = lat_out[0] > lat_out[-1]
if flip_lat:
    lat_out = lat_out[::-1]
    pm25_out = pm25_out[::-1, :]

# Helper to extract + orient 2D fields like PM1/PM10 (surface) in µg/m^3
def surf_pm_as_ugm3(varname):
    arr = surf_ds[varname].values  # (time?, H, W) or (H,W)
    if arr.ndim == 3: arr = arr[0]
    arr = arr / 1e-9  # kg/m^3 -> µg/m^3
    if flip_lat: arr = arr[::-1, :]
    return arr

pm1_ugm3  = surf_pm_as_ugm3("pm1")
pm10_ugm3 = surf_pm_as_ugm3("pm10")

# Near-surface gases from the atmospheric cube:
# Pick the highest pressure level (closest to the surface)
plevs = atmo_ds.pressure_level.values
k_surface = int(np.argmax(plevs))  # highest pressure

def atmo_surface_field(name):
    v = atmo_ds[name].values  # (plev, H, W) or (time?,plev,H,W)
    if v.ndim == 4: v = v[0, k_surface]
    elif v.ndim == 3: v = v[k_surface]
    else: raise RuntimeError(f"Unexpected dims for {name}: {v.shape}")
    if flip_lat: v = v[::-1, :]
    return v

ozone = atmo_surface_field("go3")
nitrogen_dioxide = atmo_surface_field("no2")
nitrogen_monoxide = atmo_surface_field("no")
carbon_monoxide   = atmo_surface_field("co")
sulfur_dioxide    = atmo_surface_field("so2")

# Build 2D lon/lat grids
LON2D, LAT2D = np.meshgrid(lon_vec, lat_out)

df = pd.DataFrame({
    "latitude":           LAT2D.ravel(),
    "longitude":          LON2D.ravel(),
    "pm2p5_ugm3":         pm25_out.ravel(),
    "pm1_ugm3":           pm1_ugm3.ravel(),
    "pm10_ugm3":          pm10_ugm3.ravel(),
    "ozone":              ozone.ravel(),
    "nitrogen_dioxide":   nitrogen_dioxide.ravel(),
    "nitrogen_monoxide":  nitrogen_monoxide.ravel(),
    "carbon_monoxide":    carbon_monoxide.ravel(),
    "sulfur_dioxide":     sulfur_dioxide.ravel(),
})

# Timestamp (if present)
try:
    vt = np.asarray(atmo_ds.valid_time.values)
    if vt.ndim == 0: vt = vt[0:1]
    sel_time = pd.to_datetime(vt[0]).to_pydatetime()
except Exception:
    sel_time = None
if sel_time is not None:
    df.insert(0, "time", sel_time)

# Drop NaNs and save
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna().reset_index(drop=True)
df.to_csv(csv_path, index=False)
print(f"Saved CSV: {csv_path}")
print(f"Final working size (HxW): {pm25_kgm3.shape[0]} x {pm25_kgm3.shape[1]}  (ps={ps}, min required: {MIN_PATCHES_PER_DIM*ps})")
print(df.head(10))


patch_size: 3
Prelim slice size (HxW): 2 x 1
Window after crop (HxW): 12 x 12


AssertionError: Need >= 24 cells per dim (got 12x12, ps=3)