In [1]:
# preprocess_LSTM_dask.py
# -----------------------------------------------------------
# ❶  Imports & Dask cluster
# -----------------------------------------------------------
import torch
import zarr
import xarray as xr
import geopandas as gpd
import pandas as pd
import numpy as np
import os, shutil, json, tempfile
from pathlib import Path
from collections import Counter, defaultdict
from dask import delayed, compute
from dask.distributed import LocalCluster, Client, performance_report
from IPython.display import display # Dask dashboard inside the notebook itself:
import pickle
import dask
import json

# -----------------------------------------------------------
# Robust downloader
# -----------------------------------------------------------
import urllib.error

def get_usgs_streamflow(site, start_date, end_date, min_end_date="2024-12-31"):
    """
    Fetch daily discharge for one USGS gauge from the NWIS daily-values service.

    The RDB response is parsed with its *format* row as the header, so the
    date column is named '20d' and the discharge column '14n' (cfs).

    Args:
        site: USGS site number, inserted verbatim into the request URL.
        start_date, end_date: request window, inserted verbatim into the URL.
        min_end_date: the newest valid record must be on or after this date,
            otherwise the gauge is rejected.

    Returns:
        pd.DataFrame indexed by 'date' (sorted ascending) with a single
        'streamflow_cms' column, or None when the request/parse fails, the
        expected columns are absent, no numeric flow remains, or the record
        ends before `min_end_date`.
    """
    cfs_to_cms = 0.0283168  # cubic feet per second -> cubic metres per second

    query = f"?format=rdb&sites={site}&startDT={start_date}&endDT={end_date}"
    url = "https://waterservices.usgs.gov/nwis/dv/" + query + "&parameterCd=00060&siteStatus=all"

    # Any failure here (network, HTTP, parsing, missing date column) means "skip this gauge".
    try:
        raw = pd.read_csv(url, comment="#", sep="\t", header=1, parse_dates=["20d"])
    except Exception as e:
        print(f"[{site}] failed to download: {e}; skipping")
        return None

    if not {"14n", "20d"}.issubset(raw.columns):
        print(f"[{site}] missing expected columns '20d' and '14n'; skipping")
        return None

    flows = raw.rename(columns={"14n": "streamflow_cfs", "20d": "date"})
    # Non-numeric entries become NaN and are dropped right after.
    flows["streamflow_cfs"] = pd.to_numeric(flows["streamflow_cfs"], errors="coerce")
    flows = flows.dropna(subset=["streamflow_cfs"])
    if flows.empty:
        print(f"[{site}] all streamflow data missing or invalid; skipping")
        return None

    # Reject gauges whose valid record stops before the required end date.
    last_date = flows["date"].max()
    if pd.to_datetime(last_date) < pd.to_datetime(min_end_date):
        print(f"[{site}] data ends at {last_date}, < {min_end_date}; skipping")
        return None

    flows["streamflow_cms"] = flows["streamflow_cfs"] * cfs_to_cms
    return flows[["date", "streamflow_cms"]].set_index("date").sort_index()
# # Example usage:
# # site_id = gaugeID  # Example gauge ID
# site_id = '09085000'

# start = '2015-01-01'
# end = '2024-12-31'

# streamflow_data = get_usgs_streamflow(site_id, start, end)
# print(streamflow_data.tail())

def get_or_download_streamflows(df, start_date="2015-01-01", end_date="2024-12-31"):
    """
    Return per-gauge streamflow frames, using an on-disk cache when available.

    The cache lives in FINAL_OUT as 'streamflows.pkl' (dict: gauge id ->
    DataFrame) and 'skipped_gauges.txt' (one gauge id per line). When both
    files exist they are loaded and nothing is downloaded; otherwise every
    gauge in `df` is requested via get_usgs_streamflow and both files are
    (re)written.

    Note: the cache is not keyed on `start_date` / `end_date` / `df`; delete
    the two files to force a fresh download with different arguments.

    Args:
        df: table with a 'gauge_id' string column; the part after the last
            '_' is used as the USGS site number.
        start_date, end_date: request window passed to get_usgs_streamflow.

    Returns:
        (streamflows, skipped_gauges): dict of gauge id -> DataFrame, and the
        list of gauge ids for which no usable data was obtained.
    """
    streamflow_file = FINAL_OUT / "streamflows.pkl"
    skipped_file    = FINAL_OUT / "skipped_gauges.txt"

    if streamflow_file.exists() and skipped_file.exists():
        print("🔁 Loading cached streamflows and skipped gauges...")
        # SECURITY: pickle.load can execute arbitrary code; only safe because
        # this file is produced by this notebook. Do not point it at untrusted data.
        with open(streamflow_file, "rb") as f:
            streamflows = pickle.load(f)
        with open(skipped_file, "r") as f:
            # Ignore blank lines so they do not turn into '' gauge ids.
            skipped_gauges = [line.strip() for line in f if line.strip()]
        return streamflows, skipped_gauges

    # Create the output folder *before* the long download loop, so that a
    # missing directory fails fast instead of discarding all downloads.
    FINAL_OUT.mkdir(parents=True, exist_ok=True)

    print("⬇️  Downloading streamflows from USGS...")
    streamflows = {}
    skipped_gauges = []
    gauge_ids = df["gauge_id"].str.split("_").str[-1].tolist()

    for g in gauge_ids:
        dfQ = get_usgs_streamflow(g, start_date, end_date)
        if dfQ is None:
            skipped_gauges.append(g)
        else:
            streamflows[g] = dfQ

    # Save results
    with open(streamflow_file, "wb") as f:
        pickle.dump(streamflows, f)
    with open(skipped_file, "w") as f:
        f.write("\n".join(skipped_gauges))
    print(f"✅ Saved streamflows to {streamflow_file}")
    print(f"❌ Saved skipped gauges to {skipped_file}")

    return streamflows, skipped_gauges


# CAMELS basins: basin polygons read from the Caravan shapefile.
# Later cells use `df` as a module-level global; its `gauge_id` column is split on "_"
# and the last piece is passed to the USGS downloader as the site number.
df=gpd.read_file('/Projects/HydroMet/currierw/Caravan-Jan25-csv/shapefiles/camels/camels_basin_shapes.shp')

In [2]:
import os, multiprocessing, torch, psutil

# Quick hardware inventory: logical CPU cores, GPU (if any) and total RAM.
print("Logical CPU cores :", multiprocessing.cpu_count())

if not torch.cuda.is_available():
    print("No CUDA‑capable GPU detected")
else:
    # Only the first visible device (index 0) is reported.
    print("CUDA device      :", torch.cuda.get_device_name(0))
    print("GPU capability   :", torch.cuda.get_device_capability(0))

# Total physical memory in decimal gigabytes (bytes / 1e9), one decimal place.
print("RAM (GB total)   :", round(psutil.virtual_memory().total / 1e9, 1))

Logical CPU cores : 256
CUDA device      : NVIDIA A100-SXM4-40GB
GPU capability   : (8, 0)
RAM (GB total)   : 1082.0


In [3]:
# -----------------------------------------------------------
# ❷  Constants & helpers
# -----------------------------------------------------------
# Root folders holding the ERA5-Land (observations) and HRES (forecast) Zarr stores.
BASE_OBS  = Path('/Projects/HydroMet/currierw/ERA5_LAND')
BASE_FCST = Path('/Projects/HydroMet/currierw/HRES')
# Per-block NetCDF files are written here; main() deletes and recreates this folder on every run.
SCRATCH   = Path('/Projects/HydroMet/currierw/HRES_processed_tmp')  # will be recreated
# Final split-level NetCDFs and the streamflow cache (streamflows.pkl, skipped_gauges.txt) live here.
FINAL_OUT = Path('/Projects/HydroMet/currierw/HRES_processed')

# Forecast issue dates for each data split, one every 5 days within the split's date range.
FORECAST_BLOCKS = {
    "train":      pd.date_range('2016-01-01', '2020-09-30', freq='5D'),
    "validation": pd.date_range('2020-10-01', '2022-09-30', freq='5D'),
    "test":       pd.date_range('2022-10-01', '2024-09-30', freq='5D'),
}

# NOTE(review): REQUIRED_KEYS is not referenced by any other visible cell — confirm it is still needed.
REQUIRED_KEYS = ['precip', 'temp', 'net_solar', 'flow', 'target']
# Samples whose time axis is not exactly this long are discarded by build_sample.
# Presumably 35 weekly + 61 daily + 10 forecast steps for the met inputs — TODO confirm.
EXPECTED_LEN  = 106          # enforce the length we know is correct

# ERA5_ZARR = None
# HRES_ZARR = None

# Both stores are opened lazily (dask-backed) at cell execution time with 365-step chunks along
# `date`; the actual hand-off to the workers happens later, inside main().
print("Opening Zarr datasets once and broadcasting to workers...")
ERA5_ZARR = xr.open_zarr(BASE_OBS / 'camels_rechunked.zarr', consolidated=True, chunks={'date': 365})
HRES_ZARR = xr.open_zarr(BASE_FCST / 'camels_rechunked.zarr', consolidated=True, decode_timedelta=True, chunks={'date': 365})

def standardize_tensor(arr, mean, std):
    """
    Z-score `arr`: subtract `mean`, then divide by `std`.

    Works on anything supporting `-` and `/` (scalars, NumPy arrays, tensors);
    normal broadcasting rules of the operand types apply. No guard against
    `std == 0` is applied.
    """
    centered = arr - mean
    return centered / std

# -----------------------------------------------------------
# ❸  Per–gauge worker
# -----------------------------------------------------------
# @delayed
import time
def _build_sample_eager(gauge_id, fcst_date, ERA5_ZARR, HRES_ZARR, df_streamflow):
    """
    Assemble one training sample for a gauge and forecast issue date.

    Time axis layout (oldest first):
      * weekly means  : fcst_date-305d .. fcst_date-61d, resampled to 7-day means
      * daily values  : fcst_date-60d .. fcst_date-1d for streamflow; the
                        meteorological slice additionally includes fcst_date
      * forecast part : first 10 HRES lead times for the met inputs;
                        fcst_date .. fcst_date+10d for streamflow

    NOTE(review): the met and flow series switch from "daily" to "forecast"
    one step apart (61+10 vs 60+11). This lines up only if the first HRES
    lead time is +1 day — confirm against the HRES store.

    Args:
        gauge_id: gauge id without the 'camels_' prefix.
        fcst_date: pd.Timestamp of the forecast issue date.
        ERA5_ZARR, HRES_ZARR: xarray Datasets with a 'basin' coordinate
            holding 'camels_<gauge_id>' labels; may be lazy (dask-backed),
            only the needed slices are loaded.
        df_streamflow: DataFrame indexed by date with a 'streamflow_cms' column.

    Returns:
        dict with float32 arrays 'precip', 'temp', 'net_solar', 'flow',
        'target', an int8 'flag' array (0 weekly / 1 daily / 2 forecast),
        plus 'basin_id' and 'forecast_date' ('YYYY-MM-DD'); or None when the
        assembled series do not have EXPECTED_LEN steps or any step fails.
    """
    try:
        # Load obs data slices for this gauge & needed dates
        ds_obs = ERA5_ZARR.sel(basin=f'camels_{gauge_id}')
        ds_fcst = HRES_ZARR.sel(basin=f'camels_{gauge_id}')

        start_weekly = fcst_date - pd.Timedelta(days=305)
        end_weekly = fcst_date - pd.Timedelta(days=60) - pd.Timedelta(days=1)
        start_daily = fcst_date - pd.Timedelta(days=60)
        end_daily = fcst_date - pd.Timedelta(days=1)
        start_fore = fcst_date
        end_fore = fcst_date + pd.Timedelta(days=10)

        # Streamflow slices
        q_weekly = df_streamflow.loc[start_weekly:end_weekly]['streamflow_cms'].resample('7D').mean()
        q_daily  = df_streamflow.loc[start_daily:end_daily]['streamflow_cms']
        q_fore   = df_streamflow.loc[start_fore:end_fore]['streamflow_cms']
        q_combined = pd.concat([q_weekly, q_daily, q_fore]).to_xarray()
        q_combined.name = 'streamflow'

        # Load obs variables just for needed dates and load eagerly (small slices)
        obs_weekly_p = ds_obs['era5land_total_precipitation'].sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean().load()
        obs_weekly_t = ds_obs['era5land_temperature_2m'].sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean().load()
        obs_weekly_s = ds_obs['era5land_surface_net_solar_radiation'].sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean().load()

        # The daily met window deliberately runs one day past end_daily (i.e. up to fcst_date).
        obs_daily_p  = ds_obs['era5land_total_precipitation'].sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1))).load()
        obs_daily_t  = ds_obs['era5land_temperature_2m'].sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1))).load()
        obs_daily_s  = ds_obs['era5land_surface_net_solar_radiation'].sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1))).load()

        # Load forecast for this date and expand lead time
        # NOTE(review): method='nearest' silently picks a neighbouring issue date
        # when fcst_date itself is absent from the store — confirm this is intended.
        tmp = ds_fcst.sel(date=fcst_date, method='nearest').load()
        fcst_dates_expand = pd.Timestamp(tmp.date.values) + pd.to_timedelta(tmp.lead_time.values)
        tmp = tmp.assign_coords(date=('lead_time', fcst_dates_expand))
        fcst = tmp.swap_dims({'lead_time':'date'}).drop_vars('lead_time').isel(date=slice(0,10))

        fcst_p = fcst['hres_total_precipitation']
        fcst_t = fcst['hres_temperature_2m']
        fcst_s = fcst['hres_surface_net_solar_radiation']

        # Combine obs and forecast
        precip = xr.concat([obs_weekly_p, obs_daily_p, fcst_p], dim='date')
        temp   = xr.concat([obs_weekly_t, obs_daily_t, fcst_t], dim='date')
        nsrad  = xr.concat([obs_weekly_s, obs_daily_s, fcst_s], dim='date')

        # Gaps in either source change the length; such samples are dropped.
        if precip.shape[0] != EXPECTED_LEN or q_combined.shape[0] != EXPECTED_LEN:
            return None

        flags = np.concatenate([
            np.full(obs_weekly_p.date.size, 0),
            np.full(obs_daily_p.date.size, 1),
            np.full(fcst_p.date.size, 2)
        ])

        sample = {
            'precip': precip.values.astype(np.float32),
            'temp': temp.values.astype(np.float32),
            'net_solar': nsrad.values.astype(np.float32),
            'flag': flags.astype(np.int8),
            'flow': q_combined.values.astype(np.float32),
            'target': q_combined.values.astype(np.float32),
            'basin_id': gauge_id,
            'forecast_date': fcst_date.strftime('%Y-%m-%d')
        }
        return sample

    except Exception as e:
        # Best effort: one bad date must not abort the whole block.
        print(f"[{gauge_id}] skip {fcst_date:%Y-%m-%d}: {e}")
        return None


# Lazy variant kept for backward compatibility: calling build_sample(...) still
# returns a dask Delayed. Avoid passing dask-backed Datasets to it: dask
# finalises collection arguments of delayed calls, i.e. it would load the
# *entire* ERA5/HRES datasets before running the function.
build_sample = dask.delayed(_build_sample_eager)


def process_block(gauge_id, df_streamflow, split, fcst_dates, ERA5_ZARR, HRES_ZARR):
    """
    Build all samples of one (gauge, split, date-chunk) block and write them
    to a single NetCDF file in SCRATCH.

    Samples are built one after another with the eager builder. Parallelism
    comes from running many blocks at once (one block per submitted task), and
    this keeps the lazy Zarr datasets lazy so only the needed slices are read.

    Args:
        gauge_id: gauge id without the 'camels_' prefix.
        df_streamflow: DataFrame indexed by date with 'streamflow_cms'.
        split: split name, used as the output file prefix.
        fcst_dates: non-empty sequence of pd.Timestamp issue dates; the first
            one is used in the output file name.
        ERA5_ZARR, HRES_ZARR: xarray Datasets handed to the sample builder.

    Returns:
        List with the written file path (as str), or an empty list when no
        valid sample was produced or an error occurred (errors are printed,
        not raised).
    """
    t0 = time.time()
    out_files = []
    try:
        samples = []
        for fcst_date in fcst_dates:
            sample = _build_sample_eager(gauge_id, fcst_date, ERA5_ZARR, HRES_ZARR, df_streamflow)
            # None marks a skipped / failed forecast date.
            if sample is not None:
                samples.append(sample)

        if not samples:
            print(f"[{gauge_id}] no valid samples for {split} block")
            return []

        ds = samples_to_xarray(samples)
        outfile = SCRATCH / f'{split}_{gauge_id}_{fcst_dates[0].strftime("%Y%m%d")}.nc'
        ds.to_netcdf(outfile)
        print(f"[{gauge_id}] wrote file: {outfile}")
        out_files.append(str(outfile))

        print(f"[{gauge_id}] block processed in {time.time() - t0:.2f}s")

    except Exception as e:
        # Best effort: report and return whatever was written so far.
        print(f"[{gauge_id}] failed with error: {e}")

    return out_files



# -----------------------------------------------------------
# ❹  Small helper: samples → xarray (single gauge, single split)
# -----------------------------------------------------------
def samples_to_xarray(samples):
    """
    Pack a list of sample dicts into one xarray Dataset.

    Each sample must provide 'precip', 'temp', 'net_solar', 'flag', 'target'
    (each assignable to a length-EXPECTED_LEN row) plus 'basin_id' and
    'forecast_date'.

    Returns a Dataset with
      * dynamic_inputs (sample, time, feature) float32; features in the order
        precip, temp, net_solar, flag
      * targets (sample, time, target) float32
      * basin_id, forecast_date (sample) as fixed-width strings (max 20 chars)
    and a JSON 'flag_description' attribute explaining the flag codes.
    """
    feature_names = ["precip", "temp", "net_solar", "flag"]
    n_samples = len(samples)

    inputs = np.zeros((n_samples, EXPECTED_LEN, 4), np.float32)
    flow_targets = np.zeros((n_samples, EXPECTED_LEN, 1), np.float32)
    basins = np.empty(n_samples, 'U20')
    issue_dates = np.empty(n_samples, 'U20')

    for row, sample in enumerate(samples):
        # The three meteorological channels are copied as-is ...
        for col, key in enumerate(feature_names[:3]):
            inputs[row, :, col] = sample[key]
        # ... while the integer source flag is stored as float for the model.
        inputs[row, :, 3] = sample['flag'].astype(np.float32)
        flow_targets[row, :, 0] = sample['target']
        basins[row] = sample['basin_id']
        issue_dates[row] = sample['forecast_date']

    data_vars = {
        "dynamic_inputs": (
            ["sample", "time", "feature"],
            inputs,
            {"feature": feature_names},
        ),
        "targets": (
            ["sample", "time", "target"],
            flow_targets,
            {"target": ["streamflow"]},
        ),
        "basin_id": (["sample"], basins),
        "forecast_date": (["sample"], issue_dates),
    }
    packed = xr.Dataset(data_vars)

    flag_meaning = {
        "0": "weekly reanalysis (ERA5)",
        "1": "daily reanalysis (ERA5)",
        "2": "forecast (HRES)"
    }
    packed.attrs["flag_description"] = json.dumps(flag_meaning)
    return packed


# -----------------------------------------------------------
# ❺  Parent / driver
# -----------------------------------------------------------
def main(max_gauges=8):
    """
    Run the full preprocessing pipeline.

    Steps: recreate SCRATCH, load/download streamflow, start a local Dask
    cluster, submit one `process_block` task per (gauge, split, date chunk),
    then concatenate the per-block NetCDFs into one file per split in
    FINAL_OUT. The cluster and client are always shut down on exit, also when
    an error occurs.

    Args:
        max_gauges: process only the first `max_gauges` gauges that have
            streamflow data (default 8, the previous hard-coded value).
            Pass None to process all of them.
    """
    import sys  # only needed for the size diagnostic below

    # ---------- (re-)create scratch folder ----------
    if SCRATCH.exists():
        shutil.rmtree(SCRATCH)
    SCRATCH.mkdir(parents=True)

    # ---------- build gauge list ----------
    gauge_ids = df["gauge_id"].str.split("_").str[-1].tolist()
    print('have gauge list')
    # ---------- download streamflow (skip bad gauges) ----------
    streamflows, skipped_gauges = get_or_download_streamflows(df)
    if streamflows:
        # Size diagnostic for one frame; uses the first available gauge
        # instead of a hard-coded id that may be missing.
        first_frame = next(iter(streamflows.values()))
        print(str(sys.getsizeof(first_frame)*1e-6)+' mb')
    print(f"✅ {len(streamflows)} gauges ready, ❌ {len(skipped_gauges)} skipped")
    print('Got the Gauges: starting parallelization')

    CHUNK = 100  # try 20, 40, 60 etc.

    def chunks(seq, n):
        for i in range(0, len(seq), n):
            yield seq[i:i + n]

    # Slicing with None keeps the full list.
    some_gauges = list(streamflows.keys())[:max_gauges]

    # ---------- start Dask cluster ----------
    # Context managers guarantee worker processes are torn down on exit.
    with LocalCluster(n_workers=8, threads_per_worker=1,
                      processes=True, memory_limit="4GB") as cluster:
        with Client(cluster, timeout="120s") as client:

            display(client)  # shows dashboard link in Jupyter

            # ---------- dispatch only valid gauges ----------
            futures = []

            # Scatter only the frames that will actually be used.
            scattered_streamflows = {
                g: client.scatter(streamflows[g], broadcast=False) for g in some_gauges
            }

            # broadcast=True: ensures all workers get a copy.
            # hash=False: prevents Dask from trying to hash the big object
            # (xarray datasets can be large and non-hashable).
            ERA5_ZARR_scattered = client.scatter(ERA5_ZARR, broadcast=True, hash=False)
            HRES_ZARR_scattered = client.scatter(HRES_ZARR, broadcast=True, hash=False)

            for g in some_gauges:
                dfQ = scattered_streamflows[g]
                for split, dates in FORECAST_BLOCKS.items():
                    for sub_dates in chunks(dates, CHUNK):
                        print(f"[submit] gauge: {g}, split: {split}, {sub_dates[0]} to {sub_dates[-1]}")
                        fut = client.submit(process_block, g, dfQ, split, sub_dates, ERA5_ZARR_scattered, HRES_ZARR_scattered, pure=False)
                        futures.append(fut)

            print("All submits done")
            with performance_report(filename="report.html"):
                results = client.gather(futures)
                print("Files created:", results)
            print("Finished Dask computation")

            # ---------- concatenate split‑level files ----------
            print('concatenating split-level files from parallelization')
            for split in ['train','validation','test']:
                files = sorted(SCRATCH.glob(f'{split}_*.nc'))
                if not files:
                    continue
                # `with` closes the underlying file handles after writing.
                with xr.open_mfdataset(files, combine='nested',
                                       concat_dim='sample',
                                       parallel=True) as ds:
                    ds.to_netcdf(
                        FINAL_OUT / f"{split}_data_ERA5_HRES_CAMELS_unstandardized.nc",
                        encoding={
                            "dynamic_inputs": {"zlib": True, "complevel": 4},
                            "targets":        {"zlib": True, "complevel": 4}
                        },
                    )
                # One file per (gauge, date chunk), so this counts block files, not gauges.
                print(f"[✓] wrote {split} set ({len(files)} block files)")

    # ---------- at the very end ----------
    if skipped_gauges:
        print("Skipped gauges:", ", ".join(skipped_gauges))
        with open(FINAL_OUT / "skipped_gauges.txt", "w") as fp:
            fp.write("\n".join(skipped_gauges))

    print("All done.  See NetCDFs in", FINAL_OUT)


# if __name__ == "__main__":
#     main()
#     # This is only needed if you're running the script directly from a terminal or !python — not inside a notebook.
#     # This design allows the file to serve two purposes:
#     # Be run directly as a standalone program.
#     # Be imported as a library/module into another script without executing the main logic.



Opening Zarr datasets once and broadcasting to workers...


In [None]:
# Run the whole pipeline: streamflow loading, Dask cluster start-up, per-gauge block
# processing and split-level concatenation (see main() for details).
main()

have gauge list
🔁 Loading cached streamflows and skipped gauges...
0.058463999999999995 mb
✅ 587 gauges ready, ❌ 84 skipped
Got the Gauges: starting parallelization


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 8,Total memory: 29.80 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:34997,Workers: 8
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 29.80 GiB

0,1
Comm: tcp://127.0.0.1:41015,Total threads: 1
Dashboard: http://127.0.0.1:41599/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:36939,
Local directory: /tmp/dask-scratch-space-4201/worker-kd_6gc8o,Local directory: /tmp/dask-scratch-space-4201/worker-kd_6gc8o

0,1
Comm: tcp://127.0.0.1:35773,Total threads: 1
Dashboard: http://127.0.0.1:35587/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:34451,
Local directory: /tmp/dask-scratch-space-4201/worker-4r56bthe,Local directory: /tmp/dask-scratch-space-4201/worker-4r56bthe

0,1
Comm: tcp://127.0.0.1:37507,Total threads: 1
Dashboard: http://127.0.0.1:40649/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:35365,
Local directory: /tmp/dask-scratch-space-4201/worker-k_iero15,Local directory: /tmp/dask-scratch-space-4201/worker-k_iero15

0,1
Comm: tcp://127.0.0.1:38121,Total threads: 1
Dashboard: http://127.0.0.1:35555/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:34665,
Local directory: /tmp/dask-scratch-space-4201/worker-y0qmjb7f,Local directory: /tmp/dask-scratch-space-4201/worker-y0qmjb7f

0,1
Comm: tcp://127.0.0.1:44765,Total threads: 1
Dashboard: http://127.0.0.1:37217/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:41855,
Local directory: /tmp/dask-scratch-space-4201/worker-3sgkp40x,Local directory: /tmp/dask-scratch-space-4201/worker-3sgkp40x

0,1
Comm: tcp://127.0.0.1:45583,Total threads: 1
Dashboard: http://127.0.0.1:36835/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:33009,
Local directory: /tmp/dask-scratch-space-4201/worker-uszh9unc,Local directory: /tmp/dask-scratch-space-4201/worker-uszh9unc

0,1
Comm: tcp://127.0.0.1:38167,Total threads: 1
Dashboard: http://127.0.0.1:33791/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:35109,
Local directory: /tmp/dask-scratch-space-4201/worker-a5tgzjhc,Local directory: /tmp/dask-scratch-space-4201/worker-a5tgzjhc

0,1
Comm: tcp://127.0.0.1:41415,Total threads: 1
Dashboard: http://127.0.0.1:33579/status,Memory: 3.73 GiB
Nanny: tcp://127.0.0.1:33041,
Local directory: /tmp/dask-scratch-space-4201/worker-01fr6fhq,Local directory: /tmp/dask-scratch-space-4201/worker-01fr6fhq


[submit] gauge: 01013500, split: train, 2016-01-01 00:00:00 to 2017-05-10 00:00:00
[submit] gauge: 01013500, split: train, 2017-05-15 00:00:00 to 2018-09-22 00:00:00
[submit] gauge: 01013500, split: train, 2018-09-27 00:00:00 to 2020-02-04 00:00:00
[submit] gauge: 01013500, split: train, 2020-02-09 00:00:00 to 2020-09-26 00:00:00
[submit] gauge: 01013500, split: validation, 2020-10-01 00:00:00 to 2022-02-08 00:00:00
[submit] gauge: 01013500, split: validation, 2022-02-13 00:00:00 to 2022-09-26 00:00:00
[submit] gauge: 01013500, split: test, 2022-10-01 00:00:00 to 2024-02-08 00:00:00
[submit] gauge: 01013500, split: test, 2024-02-13 00:00:00 to 2024-09-30 00:00:00
[submit] gauge: 01022500, split: train, 2016-01-01 00:00:00 to 2017-05-10 00:00:00
[submit] gauge: 01022500, split: train, 2017-05-15 00:00:00 to 2018-09-22 00:00:00
[submit] gauge: 01022500, split: train, 2018-09-27 00:00:00 to 2020-02-04 00:00:00
[submit] gauge: 01022500, split: train, 2020-02-09 00:00:00 to 2020-09-26 00:00

In [None]:
# import pandas as pd
# import xarray as xr
# from pathlib import Path
# import shutil
# import time

# # preprocess_LSTM_dask.py
# # -----------------------------------------------------------
# # ❶  Imports & Dask cluster
# # -----------------------------------------------------------
# import torch
# import zarr
# import geopandas as gpd
# import numpy as np
# import os, shutil, json, tempfile
# from collections import Counter, defaultdict
# from dask import delayed, compute
# from dask.distributed import LocalCluster, Client, performance_report
# from IPython.display import display # Dask dashboard inside the notebook itself:
# import pickle
# import dask

# # ----- Constants & paths -----
# BASE_OBS  = Path('/Projects/HydroMet/currierw/ERA5_LAND')
# BASE_FCST = Path('/Projects/HydroMet/currierw/HRES')
# SCRATCH   = Path('/Projects/HydroMet/currierw/HRES_processed_tmp')
# FINAL_OUT = Path('/Projects/HydroMet/currierw/HRES_processed')

# FORECAST_BLOCKS = {
#     "train":      pd.date_range('2016-01-01', '2020-09-30', freq='5D'),
#     "validation": pd.date_range('2020-10-01', '2022-09-30', freq='5D'),
#     "test":       pd.date_range('2022-10-01', '2024-09-30', freq='5D'),
# }

# EXPECTED_LEN = 106  # length check for precip and flow arrays

# # ----- Your process_block and build_sample functions here -----
# # (Copy the full definitions of build_sample and process_block exactly from your code)

# # (For brevity, I'm assuming you have them defined in this script, otherwise copy-paste them here)

# # ----- Helper to load or download streamflow data for your gauge -----
# def get_or_download_streamflows(df, start_date="2015-01-01", end_date="2024-12-31"):
#     streamflow_file = FINAL_OUT / "streamflows.pkl"
#     skipped_file    = FINAL_OUT / "skipped_gauges.txt"

#     if streamflow_file.exists() and skipped_file.exists():
#         print("🔁 Loading cached streamflows and skipped gauges...")
#         with open(streamflow_file, "rb") as f:
#             streamflows = pickle.load(f)
#         with open(skipped_file, "r") as f:
#             skipped_gauges = [line.strip() for line in f]
#     else:
#         print("⬇️  Downloading streamflows from USGS...")
#         streamflows = {}
#         skipped_gauges = []
#         gauge_ids = df["gauge_id"].str.split("_").str[-1].tolist()

#         for g in gauge_ids:
#             dfQ = get_usgs_streamflow(g, start_date, end_date)
#             if dfQ is None:
#                 skipped_gauges.append(g)
#             else:
#                 streamflows[g] = dfQ

#         # Save results
#         FINAL_OUT.mkdir(exist_ok=True)
#         with open(streamflow_file, "wb") as f:
#             pickle.dump(streamflows, f)
#         with open(skipped_file, "w") as f:
#             f.write("\n".join(skipped_gauges))
#         print(f"✅ Saved streamflows to {streamflow_file}")
#         print(f"❌ Saved skipped gauges to {skipped_file}")
        
#     return streamflows, skipped_gauges


# # CAMELS basins
# df=gpd.read_file('/Projects/HydroMet/currierw/Caravan-Jan25-csv/shapefiles/camels/camels_basin_shapes.shp')

# def build_sample(gauge_id, fcst_date, ERA5_ZARR, HRES_ZARR, df_streamflow):
#     try:
#         # Load obs data slices for this gauge & needed dates
#         ds_obs = ERA5_ZARR.sel(basin=f'camels_{gauge_id}')
#         ds_fcst = HRES_ZARR.sel(basin=f'camels_{gauge_id}')
        
#         start_weekly = fcst_date - pd.Timedelta(days=305)
#         end_weekly = fcst_date - pd.Timedelta(days=60) - pd.Timedelta(days=1)
#         start_daily = fcst_date - pd.Timedelta(days=60)
#         end_daily = fcst_date - pd.Timedelta(days=1)
#         start_fore = fcst_date
#         end_fore = fcst_date + pd.Timedelta(days=10)
        
#         # Streamflow slices
#         q_weekly = df_streamflow.loc[start_weekly:end_weekly]['streamflow_cms'].resample('7D').mean()
#         q_daily  = df_streamflow.loc[start_daily:end_daily]['streamflow_cms']
#         q_fore   = df_streamflow.loc[start_fore:end_fore]['streamflow_cms']
#         q_combined = pd.concat([q_weekly, q_daily, q_fore]).to_xarray()
#         q_combined.name = 'streamflow'
        
#         # Load obs variables just for needed dates and load eagerly (small slices)
#         obs_weekly_p = ds_obs['era5land_total_precipitation'].sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean().load()
#         obs_weekly_t = ds_obs['era5land_temperature_2m'].sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean().load()
#         obs_weekly_s = ds_obs['era5land_surface_net_solar_radiation'].sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean().load()
        
#         obs_daily_p  = ds_obs['era5land_total_precipitation'].sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1))).load()
#         obs_daily_t  = ds_obs['era5land_temperature_2m'].sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1))).load()
#         obs_daily_s  = ds_obs['era5land_surface_net_solar_radiation'].sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1))).load()
        
#         # Load forecast for this date and expand lead time
#         tmp = ds_fcst.sel(date=fcst_date, method='nearest').load()
#         fcst_dates_expand = pd.Timestamp(tmp.date.values) + pd.to_timedelta(tmp.lead_time.values)
#         tmp = tmp.assign_coords(date=('lead_time', fcst_dates_expand))
#         fcst = tmp.swap_dims({'lead_time':'date'}).drop_vars('lead_time').isel(date=slice(0,10))
        
#         fcst_p = fcst['hres_total_precipitation']
#         fcst_t = fcst['hres_temperature_2m']
#         fcst_s = fcst['hres_surface_net_solar_radiation']
        
#         # Combine obs and forecast
#         precip = xr.concat([obs_weekly_p, obs_daily_p, fcst_p], dim='date')
#         temp   = xr.concat([obs_weekly_t, obs_daily_t, fcst_t], dim='date')
#         nsrad  = xr.concat([obs_weekly_s, obs_daily_s, fcst_s], dim='date')
        
#         if precip.shape[0] != EXPECTED_LEN or q_combined.shape[0] != EXPECTED_LEN:
#             return None
        
#         flags = np.concatenate([
#             np.full(obs_weekly_p.date.size, 0),
#             np.full(obs_daily_p.date.size, 1),
#             np.full(fcst_p.date.size, 2)
#         ])
        
#         sample = {
#             'precip': precip.values.astype(np.float32),
#             'temp': temp.values.astype(np.float32),
#             'net_solar': nsrad.values.astype(np.float32),
#             'flag': flags.astype(np.int8),
#             'flow': q_combined.values.astype(np.float32),
#             'target': q_combined.values.astype(np.float32),
#             'basin_id': gauge_id,
#             'forecast_date': fcst_date.strftime('%Y-%m-%d')
#         }
#         return sample

#     except Exception as e:
#         print(f"[{gauge_id}] skip {fcst_date:%Y-%m-%d}: {e}")
#         return None



# def process_block(gauge_id, df_streamflow, split, fcst_dates, ERA5_ZARR, HRES_ZARR):
#     t0 = time.time()
#     out_files = []
#     try:
#         delayed_samples = [
#             build_sample(gauge_id, fcst_date, ERA5_ZARR, HRES_ZARR, df_streamflow)
#             for fcst_date in fcst_dates
#         ]
#         # Compute all samples in parallel (here it’s just sequential call since no dask)
#         samples = dask.compute(*delayed_samples)
#         samples = [s for s in samples if s is not None]
        
#         print(f"[{gauge_id}][{split}] {len(samples)} valid samples out of {len(fcst_dates)} dates")

#         if not samples:
#             print(f"[{gauge_id}] no valid samples for {split} block")
#             return []

#         ds = samples_to_xarray(samples)
#         print(f"[{gauge_id}] Dataset created with shape {ds['dynamic_inputs'].shape}")

#         outfile = SCRATCH / f'{split}_{gauge_id}_{fcst_dates[0].strftime("%Y%m%d")}.nc'
#         ds.to_netcdf(outfile)
#         out_files.append(str(outfile))
        
#         print(f"[{gauge_id}] block processed in {time.time() - t0:.2f}s")

#     except Exception as e:
#         print(f"[{gauge_id}] failed with error: {e}")

#     return out_files


# def samples_to_xarray(samples):
#     n = len(samples)
#     dyn_inputs = np.zeros((n, EXPECTED_LEN, 4), np.float32)  # features: precip, temp, net_solar, flag
#     targets = np.zeros((n, EXPECTED_LEN, 1), np.float32)     # target: streamflow
#     basin_ids = np.empty(n, 'U20')
#     forecast_dates = np.empty(n, 'U20')

#     for i, s in enumerate(samples):
#         dyn_inputs[i, :, 0] = s['precip']
#         dyn_inputs[i, :, 1] = s['temp']
#         dyn_inputs[i, :, 2] = s['net_solar']
#         dyn_inputs[i, :, 3] = s['flag'].astype(np.float32)  # flag as float for model
#         targets[i, :, 0] = s['target']
#         basin_ids[i] = s['basin_id']
#         forecast_dates[i] = s['forecast_date']

#     ds = xr.Dataset(
#         {
#             "dynamic_inputs": (
#                 ["sample", "time", "feature"],
#                 dyn_inputs,
#                 {"feature": ["precip", "temp", "net_solar", "flag"]}
#             ),
#             "targets": (
#                 ["sample", "time", "target"],
#                 targets,
#                 {"target": ["streamflow"]}
#             ),
#             "basin_id": (["sample"], basin_ids),
#             "forecast_date": (["sample"], forecast_dates)
#         }
#     )
#     import json

#     ds.attrs["flag_description"] = json.dumps({
#         "0": "weekly reanalysis (ERA5)",
#         "1": "daily reanalysis (ERA5)",
#         "2": "forecast (HRES)"
#     })
#     return ds
    
# # ----- Main test routine -----
# import dask

# # Example dataframe df with gauge IDs -- replicate or load your real one
# df = pd.DataFrame({
#     'gauge_id': ['camels_01013500']  # Add more gauges as needed for test
# })

# # Prepare scratch directory
# if SCRATCH.exists():
#     shutil.rmtree(SCRATCH)
# SCRATCH.mkdir(parents=True)

# # Open zarr datasets once (adjust path if needed)
# print("Opening ERA5 Zarr dataset...")
# ERA5_ZARR = xr.open_zarr(BASE_OBS / 'camels_rechunked.zarr', consolidated=True, chunks={'date': 365})

# print("Opening HRES Zarr dataset...")
# HRES_ZARR = xr.open_zarr(BASE_FCST / 'camels_rechunked.zarr', consolidated=True, decode_timedelta=True, chunks={'date': 365})

# # Get streamflows dictionary
# streamflows, skipped_gauges = get_or_download_streamflows(df)
# print(f"Loaded streamflows for {len(streamflows)} gauges, skipped {len(skipped_gauges)}")

# # Pick one gauge for testing
# gauge_id = '01013500'  # from your example
# df_streamflow = streamflows[gauge_id]

# # Pick a split and first 10 forecast dates
# split = 'train'
# fcst_dates = list(FORECAST_BLOCKS[split])[:10]

# # Run process_block locally (no Dask)
# out_files = process_block(gauge_id, df_streamflow, split, fcst_dates, ERA5_ZARR, HRES_ZARR)

# print(f"Output NetCDF files for gauge {gauge_id}: {out_files}")