In [None]:
# -*- coding: utf-8 -*-
"""
Real-time evaluation of CFSv2 precipitation initialization error
over HUC2=07 (Upper Mississippi) and HUC2=10 (Missouri) basins.

Requirements:
  python -m pip install xarray cfgrib==0.9.14.1 eccodes requests pandas numpy geopandas shapely pyproj rioxarray tqdm pytz

Notes:
- CFSv2 operational archive is a 7-day rolling window. Run this daily or more often.
- We read CFSv2 6-hr "flux" (flxf*.grb2) files, variable shortName='prate'.
- We build 6-hr accumulations from the rate and align with MRMS hourly QPE aggregated to 6-hr windows.
- Initialization error is reported by lead bin and by cycle/member.
"""

import os
import io
import re
import sys
import gzip
import json
import math
import time
import ftplib
import pytz
import queue
import shutil
import zipfile
import logging
import datetime as dt
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import xarray as xr
import rioxarray  # for raster masking
import geopandas as gpd
import requests
from shapely.geometry import shape, mapping
from shapely.ops import unary_union
from tqdm import tqdm
import tempfile
import os

def _load_cfgrib(grib_bytes: bytes, fkeys: dict | None = None) -> xr.Dataset:
    """
    Open in-memory GRIB bytes with the cfgrib engine and return a loaded Dataset.

    The bytes are written to a temporary file (cfgrib is given a file path),
    the dataset is read fully into memory, and the temporary file is removed.

    Parameters:
      grib_bytes: raw GRIB message bytes.
      fkeys: optional cfgrib ``filter_by_keys`` mapping.

    Returns:
      The loaded xarray Dataset.

    Raises:
      RuntimeError: if the opened dataset contains no data variables.
      Any exception raised by xarray/cfgrib while opening the file.
    """
    with tempfile.NamedTemporaryFile(suffix=".grib2", delete=False) as tf:
        tf.write(grib_bytes)
        tmp_path = tf.name  # file is flushed and closed when the `with` exits
    try:
        # An empty indexpath tells cfgrib not to write or read an on-disk index.
        # (":memory:" is not special to cfgrib: it was used as a literal file
        # name, which produced "Ignoring index file ':memory:' ..." warnings.)
        backend_kwargs = {"indexpath": ""}
        if fkeys:
            backend_kwargs["filter_by_keys"] = fkeys
        with xr.open_dataset(tmp_path, engine="cfgrib", backend_kwargs=backend_kwargs) as ds:
            ds = ds.load()
        if not ds.data_vars:
            raise RuntimeError(f"cfgrib open produced empty dataset for filter={fkeys}")
        return ds
    finally:
        # Best-effort cleanup; a leftover temp file must not mask the real result.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

# -----------------------
# Configuration
# -----------------------
# Root of the operational CFS directory tree on NOMADS.
NOMADS_BASE = "https://nomads.ncep.noaa.gov/pub/data/nccf/com/cfs/prod"
# We will pull "flxf" 6-hourly flux files from: {NOMADS_BASE}/cfs.YYYYMMDD/CC/6hrly_grib_0M/
# where M = 1..4 is ensemble member, CC in {00,06,12,18}

# Directory holding the MRMS hourly Pass2 QPE files (one .grib2.gz per hour).
MRMS_BASE = "https://mrms.ncep.noaa.gov/2D/MultiSensor_QPE_01H_Pass2"

# HUC2 polygons via USGS WBD MapServer layer 1 (HUC2)
WBD_HUC2_ENDPOINT = "https://hydro.nationalmap.gov/arcgis/rest/services/wbd/MapServer/1/query"

# Spatial resampling method for basin averaging
# NOTE(review): neither resampling constant is referenced by the functions
# visible in this file — confirm whether they are still needed.
CFS_RESAMPLING = "nearest"    # options: "nearest", "bilinear" (xarray's "linear" via rioxarray)
MRMS_RESAMPLING = "nearest"

# Timezone for outputs
# NOTE(review): not referenced by the visible code; timestamps are handled in UTC.
TZ = "America/Chicago"

# Max forecast days to evaluate (NWM LRF uses first 30 days)
MAX_FCST_DAYS = 30

# Batches
# Output root; created at import time so later writes can assume it exists.
DOWNLOAD_DIR = Path("./data_rt")
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

# -----------------------
# Utility helpers
# -----------------------

def latest_cfs_cycle() -> Tuple[str, str]:
    """
    Discover the latest available CFSv2 cycle and its date folder by scraping the index.

    Returns:
      (yyyyMMdd, cycle) where cycle is one of '00', '06', '12', '18'.

    Raises:
      requests.HTTPError: if either index page returns an error status.
      RuntimeError: if no date directories or no cycle directories are listed.
    """
    # Same timeout / status handling as the other HTTP helpers in this file, so a
    # hung connection or an error page is reported as such instead of being
    # parsed as an (empty) directory listing.
    root = requests.get(NOMADS_BASE + "/", timeout=60)
    root.raise_for_status()
    # find cfs.YYYYMMDD/ entries; YYYYMMDD strings sort chronologically
    dates = sorted(re.findall(r'href="cfs\.(\d{8})/', root.text))
    if not dates:
        raise RuntimeError("No cfs.YYYYMMDD directories found on NOMADS.")
    last_day = dates[-1]

    day = requests.get(f"{NOMADS_BASE}/cfs.{last_day}/", timeout=60)
    day.raise_for_status()
    cycles = sorted(re.findall(r'href="(00|06|12|18)/"', day.text))
    if not cycles:
        raise RuntimeError(f"No cycles found for cfs.{last_day}")
    # Use the latest present cycle (its files may still be uploading).
    return last_day, cycles[-1]

def list_cfs_member_dir(yyyymmdd: str, cycle: str, member: int) -> str:
    """Return the NOMADS directory URL (with trailing slash) for one member's 6-hourly GRIBs."""
    day_part = f"cfs.{yyyymmdd}"
    # member folders are 6hrly_grib_01..04
    member_part = f"6hrly_grib_{member:02d}"
    return "/".join((NOMADS_BASE, day_part, f"{cycle}", member_part)) + "/"

def get_huc2_polygon(huc2: str) -> gpd.GeoSeries:
    """
    Query the USGS WBD service for one HUC2 code and return its outline.

    All returned feature geometries are merged into a single geometry, which is
    wrapped in a one-element GeoSeries tagged EPSG:4326.

    Raises:
      requests.HTTPError: on a non-success HTTP status.
      RuntimeError: if the response contains no feature with a geometry.
    """
    query = dict(where=f"HUC2='{huc2}'", outFields="*", f="geojson", outSR=4326)
    response = requests.get(WBD_HUC2_ENDPOINT, params=query, timeout=60)
    response.raise_for_status()

    geometries = []
    for feature in response.json().get("features", []):
        # skip features that carry no geometry
        if feature.get("geometry"):
            geometries.append(shape(feature["geometry"]))

    if not geometries:
        raise RuntimeError(f"No polygon returned for HUC2={huc2}")

    merged = unary_union(geometries)
    return gpd.GeoSeries([merged], crs="EPSG:4326")




def open_cfs_precip_6h(grib_bytes: bytes) -> xr.DataArray:
    """
    Open a CFS flxf GRIB and return 6-hour precipitation as a DataArray named
    'cfs_prcp_6h' with units attribute 'mm/6h'.

    Lookup order (first match wins):
      1. PRATE, surface, stepType=avg   -> rate * 6*3600
      2. PRATE with no other filter     -> rate * 6*3600
      3. APCP / TP accumulation (strict keys, then shortName only)
         -> metres converted to mm, otherwise passed through unchanged

    Parameters:
      grib_bytes: raw bytes of one GRIB file.

    Raises:
      ValueError: if the bytes do not start with the 'GRIB' magic.
      RuntimeError: if the file cannot be opened at all, or if it opens but no
        precipitation field is found (the message lists the variables present).
    """
    if not (len(grib_bytes) >= 4 and grib_bytes[:4] == b"GRIB"):
        raise ValueError("Bytes are not a GRIB file (missing GRIB header).")

    log = logging.getLogger(__name__)
    six_hours_s = 6 * 3600.0  # kg m-2 s-1 -> mm per 6h

    def _finish(da_mm: xr.DataArray) -> xr.DataArray:
        # Common naming/units for every successful path.
        out = da_mm.rename("cfs_prcp_6h")
        out.attrs["units"] = "mm/6h"
        return out

    # PRATE: strict surface/avg keys first (most common in flxf), then shortName only.
    # The multiplier is always the 6-hour window. cfgrib's 'step' coordinate is the
    # forecast lead time, not the averaging interval, so it must not be used here
    # (at lead 240 h it would inflate the depth 40-fold).
    prate_filters = (
        {"shortName": "prate", "typeOfLevel": "surface", "stepType": "avg"},
        {"shortName": "prate"},
    )
    for fkeys in prate_filters:
        try:
            ds = _load_cfgrib(grib_bytes, fkeys)
            return _finish(ds["prate"] * six_hours_s)
        except Exception as exc:  # deliberate best-effort: fall through to next variant
            log.debug("PRATE lookup with %s failed: %s", fkeys, exc)

    # Accumulated precip variants (APCP). Depending on tables, shortName may be 'apcp' or 'tp'.
    for sn in ("apcp", "tp"):
        accum_filters = (
            {"shortName": sn, "typeOfLevel": "surface", "stepType": "accum"},
            {"shortName": sn},
        )
        ds = None
        for fkeys in accum_filters:
            try:
                ds = _load_cfgrib(grib_bytes, fkeys)
                break
            except Exception as exc:  # deliberate best-effort
                log.debug("Accumulation lookup with %s failed: %s", fkeys, exc)
        if ds is None or len(ds.data_vars) == 0:
            continue

        varname = sn if sn in ds.data_vars else list(ds.data_vars)[0]
        da = ds[varname]
        # Units can be 'm' for ECMWF-like 'tp'; convert to mm
        units = da.attrs.get("units", "").lower()
        if units in ("m", "meter", "metre", "meters", "metres"):
            da = da * 1000.0
        # NOTE(review): the accumulation is assumed to span 6 hours — confirm.
        return _finish(da)

    # Final fall-back: open without filters so we can report what's inside.
    try:
        anyds = _load_cfgrib(grib_bytes, None)
    except Exception as e:
        raise RuntimeError(
            "Could not open GRIB with cfgrib at all. Verify eccodes and cfgrib installations."
        ) from e
    # Raised outside the try so this diagnostic is not swallowed and re-labelled
    # as an installation problem.
    raise RuntimeError(
        f"Could not locate precip in GRIB. Variables present: {list(anyds.data_vars)}"
    )

def _payload_matches_suffix(url: str, payload: bytes) -> bool:
    """Return True if the payload's magic bytes agree with the URL's file suffix."""
    if url.endswith((".grb2", ".grib2")):
        return payload[:4] == b"GRIB"
    if url.endswith(".gz"):
        return payload[:2] == b"\x1f\x8b"
    # Any other suffix: accept whatever a successful response delivered.
    return True


def fetch_url(url: str, retry=5, sleep=5) -> bytes:
    """
    Download a URL and return its body, retrying transient failures.

    Parameters:
      url: address to fetch. URLs ending in .grb2/.grib2 must return GRIB bytes
        and URLs ending in .gz must return gzip bytes; other suffixes are
        returned as-is on a successful response.
      retry: maximum number of attempts.
      sleep: base delay in seconds; the delay before attempt k+1 is sleep * 1.5**k.

    Returns:
      The response body.

    Raises:
      RuntimeError: if the file does not exist (HTTP 404/410, raised immediately
        without retrying) or if no attempt produced the expected content.
    """
    detail = "no attempt made"
    for k in range(retry):
        try:
            r = requests.get(url, timeout=120)
        except requests.RequestException as exc:
            # Connection resets / timeouts are transient: retry instead of aborting.
            detail = f"request error: {exc}"
        else:
            if r.ok:
                if _payload_matches_suffix(url, r.content):
                    return r.content
                detail = "response body did not match the expected file type"
            elif r.status_code in (404, 410):
                # The file is simply not there. Backing off five times costs
                # more than a minute per missing file, so fail fast.
                raise RuntimeError(
                    f"Failed to fetch expected content from {url} (HTTP {r.status_code})"
                )
            else:
                detail = f"HTTP {r.status_code}"
        if k < retry - 1:  # no point sleeping after the final attempt
            time.sleep(sleep * (1.5 ** k))
    raise RuntimeError(f"Failed to fetch expected content from {url} ({detail})")



def expected_flxf_filenames(init_yyyymmdd: str, init_cycle: str, member: int) -> List[str]:
    """
    Build the list of 6-hourly flux file names expected for one init/member.

    Names have the form  flxf<VALID>.<MM>.<INIT>.grb2  where <VALID> and <INIT>
    are YYYYmmddHH stamps and <MM> is the zero-padded member number. Valid times
    run from the init time itself to init + MAX_FCST_DAYS days, every 6 hours,
    both ends included (NOMADS may lag, so late files may not exist yet).
    """
    init = dt.datetime.strptime(init_yyyymmdd + init_cycle, "%Y%m%d%H").replace(
        tzinfo=dt.timezone.utc
    )
    init_tag = init.strftime("%Y%m%d%H")
    n_steps = int((MAX_FCST_DAYS * 24) / 6)

    valid_tags = (
        (init + dt.timedelta(hours=6 * k)).strftime("%Y%m%d%H")
        for k in range(n_steps + 1)
    )
    return [f"flxf{tag}.{member:02d}.{init_tag}.grb2" for tag in valid_tags]


def mask_to_basin(da: xr.DataArray, basin: gpd.GeoSeries) -> xr.DataArray:
    """
    Mask a lat/lon raster to a basin polygon.

    Cells outside the polygon become NaN; the grid extent is kept (drop=False).

    Parameters:
      da: raster with 'latitude'/'longitude' or 'lat'/'lon' coordinates,
        interpreted as geographic EPSG:4326.
      basin: GeoSeries holding the basin geometry (reprojected to EPSG:4326).

    Returns:
      A new DataArray with 'lat'/'lon' coordinates; the input is not modified.
    """
    if "latitude" in da.coords and "longitude" in da.coords:
        da = da.rename({"latitude": "lat", "longitude": "lon"})

    # Grids that store longitude as 0..360 never overlap polygons expressed in
    # -180..180, so the clip would select nothing. Wrap and re-sort first.
    if "lon" in da.dims and float(da["lon"].max()) > 180.0:
        wrapped = ((da["lon"] + 180.0) % 360.0) - 180.0
        da = da.assign_coords(lon=wrapped).sortby("lon")

    # Name the spatial dims explicitly ('lat'/'lon' are not rioxarray's default
    # names) and attach the CRS on a copy rather than on the caller's array.
    da = da.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=False)
    da = da.rio.write_crs("EPSG:4326")

    # vector to same CRS
    basin = basin.to_crs("EPSG:4326")
    return da.rio.clip(basin.geometry.apply(mapping), basin.crs, drop=False)

def hour_end_from_mrms_name(url: str) -> pd.Timestamp:
    """
    Extract the '_YYYYMMDD-HH0000' stamp from an MRMS file name or URL.

    Returns the stamp as a UTC-aware Timestamp; raises RuntimeError when the
    pattern is absent.
    """
    found = re.search(r"_(\d{8})-(\d{2})0000", url)
    if found is None:
        raise RuntimeError(f"Cannot parse MRMS hour from {url}")
    ymd, hh = found.groups()
    return pd.to_datetime(ymd + hh, format="%Y%m%d%H", utc=True)

def area_average_mm(da_masked: xr.DataArray) -> pd.Series:
    """
    Cosine-lat weighted mean over the valid (non-NaN) cells of a masked raster.

    Parameters:
      da_masked: raster with 'lat' and 'lon' coordinates; cells outside the
        region of interest are expected to be NaN.

    Returns:
      - a pandas Series indexed by time if a 'time' dim exists
      - a length-1 Series if there is no 'time' dim
      The value is NaN when no valid cell is present.

    Raises:
      ValueError: if 'lat' or 'lon' coordinates are missing.
    """
    arr = da_masked.where(np.isfinite(da_masked))

    if "lat" not in arr.coords or "lon" not in arr.coords:
        raise ValueError("Expected 'lat' and 'lon' coordinates.")

    # weights ~ cos(lat). xarray's weighted mean drops the weight of every NaN
    # cell, so the denominator covers only the basin. Summing the weights over
    # the full grid (as a plain broadcast would) dilutes the mean by roughly the
    # ratio of basin area to grid area.
    w_lat = np.cos(np.deg2rad(arr["lat"]))
    mean_da = arr.weighted(w_lat).mean(dim=("lat", "lon"), skipna=True)

    if "time" in mean_da.dims:
        s = mean_da.to_series()
        s.name = arr.name
        return s

    # single timestamp file -> scalar; return length-1 Series to match callers
    return pd.Series([float(mean_da.values)], name=arr.name)



def list_mrms_hourlies(start_utc: pd.Timestamp, end_utc: pd.Timestamp) -> List[str]:
    """
    Build the list of MRMS hourly file URLs covering [start, end].

    The start is floored and the end is ceiled to the whole hour, and one URL is
    produced per hour, both ends included.
    File pattern: MRMS_MultiSensor_QPE_01H_Pass2_00.00_YYYYMMDD-HH0000.grib2.gz
    """
    # Lower-case "h" is the current pandas hour alias; upper-case "H" emits a
    # FutureWarning on recent pandas releases.
    hours = pd.date_range(start_utc.floor("h"), end_utc.ceil("h"), freq="h", tz="UTC")
    return [
        f"{MRMS_BASE}/MRMS_MultiSensor_QPE_01H_Pass2_00.00_{t.strftime('%Y%m%d-%H0000')}.grib2.gz"
        for t in hours
    ]

def open_mrms_hourly_mm_gz(gz_bytes: bytes) -> xr.DataArray:
    """
    Open gzip-compressed MRMS hourly QPE GRIB2 bytes as a DataArray.

    Parameters:
      gz_bytes: contents of one MRMS *.grib2.gz file.

    Returns:
      The first data variable of the file, loaded into memory, renamed
      'mrms_qpe_1h', tagged with units 'mm', with 'lat'/'lon' coordinates and
      negative values replaced by NaN.

    Raises:
      OSError / EOFError: if the bytes are not valid gzip data.
      RuntimeError: if cfgrib yields no data variable.
    """
    # The cfgrib engine needs a real file path, not a file-like object, so
    # decompress in memory and go through the shared temp-file loader.
    ds = _load_cfgrib(gzip.decompress(gz_bytes))

    # MRMS uses 'unknown' variable names sometimes; pull the first data variable.
    vname = next(iter(ds.data_vars))
    da = ds[vname].rename("mrms_qpe_1h")

    renames = {
        old: new
        for old, new in (("latitude", "lat"), ("longitude", "lon"))
        if old in da.coords
    }
    if renames:
        da = da.rename(renames)

    # Precipitation depth cannot be negative. Negative values are presumably
    # MRMS missing / no-coverage sentinels (TODO confirm against the product
    # table) and would bias basin means if averaged.
    da = da.where(da >= 0)
    da.attrs["units"] = "mm"
    return da

def aggregate_mrms_to_6h(mm_1h_series: Dict[pd.Timestamp, float], min_hours: int = 6) -> pd.Series:
    """
    Sum hourly values into 6-hour windows labelled by the window end.

    Windows are right-closed, so the window labelled 06:00 sums the hours
    stamped 01:00 through 06:00.

    Parameters:
      mm_1h_series: mapping of hour timestamp -> hourly value.
      min_hours: minimum number of non-missing hourly values a window needs to
        receive a total. Windows with fewer become NaN, so a partially observed
        window is never reported as a full 6-hour total. Pass 1 to accept any
        non-empty window.

    Returns:
      Series named 'mrms_prcp_6h' with a UTC DatetimeIndex (empty for empty input).
    """
    s = pd.Series(mm_1h_series, dtype=float)
    if s.empty:
        return pd.Series(
            [], index=pd.DatetimeIndex([], tz="UTC"), dtype=float, name="mrms_prcp_6h"
        )
    s.index = pd.to_datetime(s.index, utc=True)
    # Define windows with labels at window end. "6h" is the current pandas alias;
    # "6H" emits a FutureWarning on recent pandas releases.
    s6 = (
        s.sort_index()
        .resample("6h", label="right", closed="right")
        .sum(min_count=min_hours)
    )
    s6.name = "mrms_prcp_6h"
    return s6

def evaluate_init_error(cfs_6h: pd.Series, mrms_6h: pd.Series, init_utc: pd.Timestamp) -> pd.DataFrame:
    """
    Pair forecast and observed series and compute the per-time error.

    Rows where either value is missing are dropped. The result has columns
    'cfs', 'obs', 'lead_hours' (hours between the row's timestamp and init_utc)
    and 'err' (cfs - obs), limited to leads within [0, MAX_FCST_DAYS * 24].
    """
    paired = pd.concat([cfs_6h.rename("cfs"), mrms_6h.rename("obs")], axis=1)
    paired = paired.dropna()

    lead = paired.index.tz_convert("UTC") - init_utc
    paired["lead_hours"] = lead.total_seconds() / 3600.0

    within_horizon = paired["lead_hours"].between(0, MAX_FCST_DAYS * 24)
    paired = paired[within_horizon]
    paired["err"] = paired["cfs"] - paired["obs"]
    return paired

def bin_metrics(df: pd.DataFrame, bins_hours=(0, 6, 12, 24, 48, 72, 120, 240, 360, 720)) -> pd.DataFrame:
    """
    Compute mean error (bias), MAE, RMSE by lead bins.

    Parameters:
      df: frame with columns 'lead_hours', 'err', 'obs' and 'cfs'.
      bins_hours: increasing bin edges in hours. Bins are right-closed and the
        first bin also includes its lower edge.

    Returns:
      One row per bin (bins without data are kept, with n == 0 and NaN metrics)
      and columns lead_bin, n, bias_mm, mae_mm, rmse_mm, obs_mm, fcst_mm.
    """
    cats = pd.cut(df["lead_hours"], bins=bins_hours, right=True, include_lowest=True)

    # Pre-compute |err| and err**2 so every statistic is a built-in mean. This
    # avoids running Python lambdas on empty bins.
    work = df.assign(_abs_err=df["err"].abs(), _sq_err=df["err"] ** 2)

    # observed=False is stated explicitly: it keeps empty bins in the output and
    # does not depend on pandas' changing default for categorical groupers.
    out = work.groupby(cats, observed=False).agg(
        n=("err", "count"),
        bias_mm=("err", "mean"),
        mae_mm=("_abs_err", "mean"),
        rmse_mm=("_sq_err", "mean"),
        obs_mm=("obs", "mean"),
        fcst_mm=("cfs", "mean"),
    )
    out["rmse_mm"] = np.sqrt(out["rmse_mm"])
    return out.reset_index().rename(columns={"lead_hours": "lead_bin"})

# -----------------------
# Main driver
# -----------------------
# Number of consecutive failed CFS file downloads after which run_one_cycle
# stops scanning the remaining expected files for that member.
MAX_MISS = 8

def valid_time_from_name(fn: str) -> pd.Timestamp:
    """
    Read the first (left) YYYYmmddHH stamp of a 'flxf<stamp>.<MM>.<stamp>.grb2'
    file name and return it as a UTC-aware Timestamp.

    Raises RuntimeError when the name does not end with that pattern.
    """
    found = re.search(r"flxf(\d{10})\.\d{2}\.(\d{10})\.grb2$", fn)
    if found is None:
        raise RuntimeError(f"Cannot parse valid time from {fn}")
    left_stamp = found.group(1)
    return pd.to_datetime(left_stamp, format="%Y%m%d%H", utc=True)


def _cfs_member_basin_means(yyyymmdd: str, cycle: str, member: int, basins: dict):
    """Download one member's flxf files; return (valid_times, {basin: [6-h basin means]})."""
    member_dir = list_cfs_member_dir(yyyymmdd, cycle, member)
    valid_times = []
    area_means = {bname: [] for bname in basins}
    miss_streak = 0
    for fn in tqdm(expected_flxf_filenames(yyyymmdd, cycle, member), ncols=100):
        try:
            grib = fetch_url(member_dir + fn)
        except Exception:
            # Files appear progressively on NOMADS; give up on this member only
            # after MAX_MISS consecutive misses.
            miss_streak += 1
            if miss_streak >= MAX_MISS:
                break
            continue
        miss_streak = 0

        sixh = open_cfs_precip_6h(grib)  # mm per 6h
        # Reduce every basin before recording the time, so the time list and
        # the per-basin value lists always have the same length.
        step_means = {
            bname: float(area_average_mm(mask_to_basin(sixh, poly)).iloc[0])
            for bname, poly in basins.items()
        }
        valid_times.append(valid_time_from_name(fn))  # filename is the source of truth
        for bname, value in step_means.items():
            area_means[bname].append(value)
    return valid_times, area_means


def _fill_mrms_hourly_cache(urls, basins: dict, cache: dict) -> None:
    """Download each not-yet-cached MRMS hourly once and store its per-basin means in cache."""
    for u in tqdm(urls, ncols=100):
        t_end = hour_end_from_mrms_name(u)  # hour-end timestamp from filename
        if t_end in cache:
            continue
        try:
            da = open_mrms_hourly_mm_gz(fetch_url(u))  # mm per hour
            cache[t_end] = {
                bname: float(
                    mask_to_basin(da, poly).mean(dim=("lat", "lon"), skipna=True).values
                )
                for bname, poly in basins.items()
            }
        except Exception:
            # Best-effort: remember the failure so the same hour is not fetched
            # again for the remaining members.
            cache[t_end] = None


def run_one_cycle(yyyymmdd: str = None, cycle: str = None):
    """
    Evaluate one CFSv2 cycle against MRMS for the HUC2 07 and 10 basins.

    For each of the four ensemble members the 6-hourly basin-mean forecast
    precipitation is compared with MRMS hourly QPE summed to the same 6-hour
    windows. Error metrics are binned by lead time and written as two CSV files
    under DOWNLOAD_DIR / cfs_<yyyymmdd><cycle>.

    Parameters:
      yyyymmdd: init date; if this or `cycle` is None, the latest cycle listed
        on NOMADS is used for both.
      cycle: init hour, one of '00', '06', '12', '18'.

    Raises:
      SystemExit: if no metrics could be computed for any member/basin.
    """
    if yyyymmdd is None or cycle is None:
        yyyymmdd, cycle = latest_cfs_cycle()

    init_utc = pd.to_datetime(f"{yyyymmdd}{cycle}", format="%Y%m%d%H", utc=True)

    # Get basins
    basins = {
        "HUC2_07_UpperMiss": get_huc2_polygon("07"),
        "HUC2_10_Missouri": get_huc2_polygon("10"),
    }

    all_metrics = []
    # hour-end timestamp -> {basin: hourly mean} (None when the hour failed).
    # All members share the same valid times, so each MRMS file is downloaded
    # once instead of once per member and basin.
    mrms_cache = {}
    # Observations cannot exist for future valid times; never request them.
    now_hour = pd.Timestamp.now(tz="UTC").floor("h")

    for member in [1, 2, 3, 4]:
        print(f"\nMember {member:02d}: downloading CFSv2 PRATE as 6h accumulations...")
        valid_times, cfs_area_means = _cfs_member_basin_means(yyyymmdd, cycle, member, basins)

        if not valid_times:
            print(f"Member {member:02d}: no files found, skipping.")
            continue

        times = pd.DatetimeIndex(valid_times, tz="UTC")

        # MRMS: gather hourlies spanning the same window then aggregate to 6h
        obs_start = times.min() - pd.Timedelta(hours=6)
        obs_end = min(times.max(), now_hour)
        if obs_end > obs_start:
            mrms_urls = list_mrms_hourlies(obs_start, obs_end)
            print(f"Member {member:02d}: downloading MRMS hourlies "
                  f"({len(mrms_urls)} files, shared across members)...")
            _fill_mrms_hourly_cache(mrms_urls, basins, mrms_cache)

        for bname in basins:
            cfs_ser = pd.Series(cfs_area_means[bname], index=times, name=f"cfs_{bname}_mm6h")

            hourly = {t: means[bname] for t, means in mrms_cache.items() if means is not None}
            if not hourly:
                print(f"Member {member:02d} {bname}: no MRMS observations available, skipping.")
                continue

            # align to CFS valid times
            mrms6 = aggregate_mrms_to_6h(hourly).reindex(cfs_ser.index)

            # Score the pair. Without this step nothing was ever added to
            # all_metrics and every run ended in SystemExit.
            df = evaluate_init_error(cfs_ser, mrms6, init_utc)
            if df.empty:
                print(f"Member {member:02d} {bname}: no overlapping forecast/observation times.")
                continue
            binned = bin_metrics(df)
            binned.insert(0, "basin", bname)
            binned.insert(1, "member", member)
            binned.insert(2, "init_utc", init_utc.strftime("%Y%m%d%H"))
            all_metrics.append(binned)

    if not all_metrics:
        raise SystemExit("No metrics computed. Likely the cycle has not fully populated yet.")

    metrics = pd.concat(all_metrics, ignore_index=True)
    # Tidy lead bin labels
    metrics["lead_bin_h"] = metrics["lead_bin"].astype(str).str.replace(r"\(|\]","",regex=True)
    # Save artifacts
    out_dir = DOWNLOAD_DIR / f"cfs_{yyyymmdd}{cycle}"
    out_dir.mkdir(parents=True, exist_ok=True)
    metrics.to_csv(out_dir / "init_error_metrics_by_bin.csv", index=False)

    # Also produce a quick wide summary: mean over members by basin and lead bin
    summary = (metrics
               .groupby(["basin","lead_bin_h"], as_index=False)
               .agg(bias_mm=("bias_mm","mean"),
                    mae_mm=("mae_mm","mean"),
                    rmse_mm=("rmse_mm","mean"),
                    n=("n","sum")))
    summary.to_csv(out_dir / "init_error_summary_mean_over_members.csv", index=False)
    print(f"\nWrote:\n  {out_dir/'init_error_metrics_by_bin.csv'}\n  {out_dir/'init_error_summary_mean_over_members.csv'}")
    print("Done.")

if __name__ == "__main__":
    # Optional CLI args: YYYYMMDD CC
    ymd = sys.argv[1] if len(sys.argv) > 1 else None
    cyc = sys.argv[2] if len(sys.argv) > 2 else None

    # Only forward arguments that look like a date and a cycle. When this code
    # runs inside a notebook kernel, sys.argv holds the kernel launcher's own
    # flags, which must not be interpreted as an init time.
    args_ok = (
        ymd is not None
        and cyc is not None
        and re.fullmatch(r"\d{8}", ymd) is not None
        and cyc in ("00", "06", "12", "18")
    )
    if args_ok:
        run_one_cycle(ymd, cyc)
    else:
        if ymd is not None or cyc is not None:
            print("Ignoring command-line arguments (expected: YYYYMMDD CC); using the latest cycle.")
        run_one_cycle()



Member 01: downloading CFSv2 PRATE as 6h accumulations...


  0%|                                                                       | 0/121 [00:00<?, ?it/s]Ignoring index file ':memory:' older than GRIB file
Ignoring index file ':memory:' older than GRIB file
  1%|▌                                                              | 1/121 [00:01<02:18,  1.15s/it]Ignoring index file ':memory:' older than GRIB file
Ignoring index file ':memory:' older than GRIB file
  2%|█                                                              | 2/121 [00:02<02:17,  1.15s/it]Ignoring index file ':memory:' older than GRIB file
Ignoring index file ':memory:' older than GRIB file
  2%|█▌                                                             | 3/121 [00:03<02:16,  1.16s/it]Ignoring index file ':memory:' older than GRIB file
Ignoring index file ':memory:' older than GRIB file
  3%|██                                                             | 4/121 [00:04<02:16,  1.17s/it]Ignoring index file ':memory:' older than GRIB file
Ignoring index file ':memory:' o

Member 01 HUC2_07_UpperMiss: downloading MRMS hourlies (727 files)...


100%|██████████████████████████████████████████████████████████| 727/727 [13:11:17<00:00, 65.31s/it]
  s6 = s.resample("6H", label="right", closed="right").sum(min_count=1)
  hours = pd.date_range(start_utc.floor("H"), end_utc.ceil("H"), freq="H", tz="UTC")


Member 01 HUC2_10_Missouri: downloading MRMS hourlies (727 files)...


 34%|███████████████████▌                                     | 249/727 [4:10:25<8:58:23, 67.58s/it]