In [None]:
# preprocess_LSTM_dask.py
# -----------------------------------------------------------
# ❶  Imports & Dask cluster
# -----------------------------------------------------------
import torch
import zarr
import xarray as xr
import geopandas as gpd
import pandas as pd
import numpy as np
import os, shutil, json, tempfile
from pathlib import Path
from collections import Counter, defaultdict
from dask import delayed, compute
from dask.distributed import LocalCluster, Client, performance_report
from IPython.display import display # Dask dashboard inside the notebook itself:
import pickle

# -----------------------------------------------------------
# Robust downloader
# -----------------------------------------------------------
import urllib.error

def get_usgs_streamflow(site, start_date, end_date, min_end_date="2024-12-31"):
    """
    Fetch daily discharge for one USGS gauge from the NWIS daily-values service.

    The RDB response is parsed with its format row as the header, which is why
    the date column is called '20d' and the discharge column '14n' (cfs).

    Args:
        site: USGS site number, inserted verbatim into the request URL.
        start_date: first day requested (passed through as ``startDT``).
        end_date: last day requested (passed through as ``endDT``).
        min_end_date: the record must reach at least this date to be accepted.

    Returns:
        pd.DataFrame indexed by ``date`` (sorted) with a single
        ``streamflow_cms`` column, or None when the request fails, the
        expected columns are absent, no numeric flow remains, or the record
        ends before ``min_end_date``. Every rejection prints its reason.
    """
    query = "&".join([
        "format=rdb",
        f"sites={site}",
        f"startDT={start_date}",
        f"endDT={end_date}",
        "parameterCd=00060",
        "siteStatus=all",
    ])
    url = "https://waterservices.usgs.gov/nwis/dv/?" + query

    try:
        raw = pd.read_csv(url, sep="\t", comment="#", header=1, parse_dates=["20d"])
    except Exception as err:
        print(f"[{site}] failed to download: {err}; skipping")
        return None

    if not {"20d", "14n"}.issubset(raw.columns):
        print(f"[{site}] missing expected columns '20d' and '14n'; skipping")
        return None

    flows = raw.rename(columns={"20d": "date", "14n": "streamflow_cfs"})
    # Non-numeric entries become NaN here and are discarded right after.
    flows = flows.assign(
        streamflow_cfs=pd.to_numeric(flows["streamflow_cfs"], errors="coerce")
    )
    flows = flows.dropna(subset=["streamflow_cfs"])
    if flows.empty:
        print(f"[{site}] all streamflow data missing or invalid; skipping")
        return None

    # Reject gauges whose record stops before the required end date.
    last_date = flows["date"].max()
    if pd.to_datetime(last_date) < pd.to_datetime(min_end_date):
        print(f"[{site}] data ends at {last_date}, < {min_end_date}; skipping")
        return None

    # cfs -> cubic metres per second
    flows = flows.assign(streamflow_cms=flows["streamflow_cfs"] * 0.0283168)
    return flows.set_index("date")[["streamflow_cms"]].sort_index()
# # Example usage:
# # site_id = gaugeID  # Example gauge ID
# site_id = '09085000'

# start = '2015-01-01'
# end = '2024-12-31'

# streamflow_data = get_usgs_streamflow(site_id, start, end)
# print(streamflow_data.tail())

def get_or_download_streamflows(df, start_date="2015-01-01", end_date="2024-12-31"):
    """
    Return per-gauge streamflow frames, using an on-disk cache when available.

    If both ``streamflows.pkl`` and ``skipped_gauges.txt`` exist in FINAL_OUT
    they are loaded and returned as-is; note that the cache is not keyed on
    ``start_date`` / ``end_date``. Otherwise every gauge in ``df`` is
    downloaded with get_usgs_streamflow() and the results are written to those
    two files.

    Args:
        df: table with a ``gauge_id`` string column; the part after the last
            underscore is used as the USGS site number.
        start_date: first day requested from USGS (download path only).
        end_date: last day requested from USGS (download path only).

    Returns:
        (streamflows, skipped_gauges): a dict mapping site number to its
        streamflow DataFrame, and a list of site numbers for which
        get_usgs_streamflow() returned None.
    """
    streamflow_file = FINAL_OUT / "streamflows.pkl"
    skipped_file = FINAL_OUT / "skipped_gauges.txt"

    if streamflow_file.exists() and skipped_file.exists():
        print("🔁 Loading cached streamflows and skipped gauges...")
        # NOTE: pickle.load can execute arbitrary code; only use cache files
        # that this pipeline wrote itself.
        with open(streamflow_file, "rb") as f:
            streamflows = pickle.load(f)
        with open(skipped_file, "r") as f:
            # Ignore blank lines so a trailing newline never yields a '' gauge.
            skipped_gauges = [line.strip() for line in f if line.strip()]
        return streamflows, skipped_gauges

    # Create the output folder (and any missing parents) BEFORE the long
    # download loop, so a bad path fails immediately instead of after all
    # gauges have been fetched.
    FINAL_OUT.mkdir(parents=True, exist_ok=True)

    print("⬇️  Downloading streamflows from USGS...")
    streamflows = {}
    skipped_gauges = []
    gauge_ids = df["gauge_id"].str.split("_").str[-1].tolist()

    for g in gauge_ids:
        dfQ = get_usgs_streamflow(g, start_date, end_date)
        if dfQ is None:
            skipped_gauges.append(g)
        else:
            streamflows[g] = dfQ

    # Save results
    with open(streamflow_file, "wb") as f:
        pickle.dump(streamflows, f)
    with open(skipped_file, "w") as f:
        f.write("\n".join(skipped_gauges))
    print(f"✅ Saved streamflows to {streamflow_file}")
    print(f"❌ Saved skipped gauges to {skipped_file}")

    return streamflows, skipped_gauges


# CAMELS basins: basin shapefile from the Caravan (Jan 2025) download, read into
# a GeoDataFrame. Kept as a module-level global because main() passes it to
# get_or_download_streamflows(), which reads its "gauge_id" column.
df=gpd.read_file('/Projects/HydroMet/currierw/Caravan-Jan25-csv/shapefiles/camels/camels_basin_shapes.shp')

In [None]:
import os, multiprocessing, torch, psutil

# Quick inventory of the machine this notebook runs on: CPU count, GPU (if
# any) and total RAM.
print("Logical CPU cores :", multiprocessing.cpu_count())

# GPU details are only queried when CUDA is usable; device index 0 is reported.
if not torch.cuda.is_available():
    print("No CUDA‑capable GPU detected")
else:
    print("CUDA device      :", torch.cuda.get_device_name(0))
    print("GPU capability   :", torch.cuda.get_device_capability(0))

# psutil reports bytes; convert to decimal gigabytes with one decimal place.
print(
    "RAM (GB total)   :",
    round(psutil.virtual_memory().total / 1e9, 1),
)

In [None]:
# -----------------------------------------------------------
# ❷  Constants & helpers
# -----------------------------------------------------------
# Input roots: each is expected to contain a 'camels_rechunked.zarr' store
# (opened below).
BASE_OBS  = Path('/Projects/HydroMet/currierw/ERA5_LAND')
BASE_FCST = Path('/Projects/HydroMet/currierw/HRES')
# Per-block NetCDF files are written here by process_block(); main() deletes
# and recreates this folder on every run.
SCRATCH   = Path('/Projects/HydroMet/currierw/HRES_processed_tmp')  # will be recreated
# Final per-split NetCDFs plus the streamflow cache files live here.
FINAL_OUT = Path('/Projects/HydroMet/currierw/HRES_processed')

# Forecast issue dates per data split, spaced every 5 days within each period.
FORECAST_BLOCKS = {
    "train":      pd.date_range('2016-01-01', '2020-09-30', freq='5D'),
    "validation": pd.date_range('2020-10-01', '2022-09-30', freq='5D'),
    "test":       pd.date_range('2022-10-01', '2024-09-30', freq='5D'),
}

# NOTE(review): REQUIRED_KEYS is not referenced anywhere in the visible code;
# presumably the sample keys a consumer must find — confirm or remove.
REQUIRED_KEYS = ['precip', 'temp', 'net_solar', 'flow', 'target']
# Sequence length every sample must have; process_block() drops samples whose
# precipitation or streamflow series has a different length.
EXPECTED_LEN  = 106          # enforce the length we know is correct

# ERA5_ZARR = None
# HRES_ZARR = None

# The stores are opened once in the parent process; main() later scatters the
# resulting Dataset objects to every Dask worker.
print("Opening Zarr datasets once and broadcasting to workers...")
ERA5_ZARR = xr.open_zarr(BASE_OBS / 'camels_rechunked.zarr', consolidated=True, chunks={})
# decode_timedelta=True: process_block() adds lead_time to the issue date, so
# it needs to be decoded as a timedelta.
HRES_ZARR = xr.open_zarr(BASE_FCST / 'camels_rechunked.zarr', consolidated=True, decode_timedelta=True, chunks={})

def standardize_tensor(arr, mean, std):
    """Return ``arr`` shifted by ``mean`` and scaled by ``std`` (a z-score).

    No guard is applied for ``std == 0``; the division follows the normal
    semantics of whatever types are passed in.
    """
    centered = arr - mean
    return centered / std

# -----------------------------------------------------------
# ❸  Per–gauge worker
# -----------------------------------------------------------
# @delayed
import time
def process_block(gauge_id, df_streamflow, split, fcst_dates, ERA5_ZARR, HRES_ZARR):
    """
    Build LSTM samples for one gauge and one block of forecast issue dates.

    For every date in ``fcst_dates`` a sequence is assembled from three
    consecutive segments: 7-day means over the window from 305 to 61 days
    before the issue date, daily values for the most recent ~60 days, and the
    forecast window. Sequences whose precipitation or streamflow length is not
    EXPECTED_LEN are dropped without a message. Surviving samples are packed
    with samples_to_xarray() and written to one NetCDF file in SCRATCH.

    Args:
        gauge_id: gauge identifier without prefix; 'camels_' is prepended to
            select the basin in both datasets.
        df_streamflow: DataFrame with a 'streamflow_cms' column, sliced by
            date via ``.loc`` and resampled (so presumably a DatetimeIndex —
            confirm against get_usgs_streamflow()).
        split: split label, used only in the output file name.
        fcst_dates: indexable sequence of timestamps (issue dates).
        ERA5_ZARR: observation dataset with 'basin' and 'date' coordinates.
            Shadows the module-level global of the same name.
        HRES_ZARR: forecast dataset with 'basin', 'date' and 'lead_time'
            coordinates. Shadows the module-level global of the same name.

    Returns:
        list[str]: paths of the files written — empty when no sample survived
        or when an error occurred, otherwise a single path.

    Errors are printed, never raised: a failure for one issue date skips that
    date; a failure outside the loop aborts the block.
    """
    t0 = time.time()

    out_files = []
    try:
        # Restrict both datasets to this basin once, outside the date loop.
        ds_obs = ERA5_ZARR.sel(basin=f'camels_{gauge_id}')
        ds_fcst = HRES_ZARR.sel(basin=f'camels_{gauge_id}')

        # The three observed forcings, trimmed to the period the splits cover.
        ds_obs_p = ds_obs['era5land_total_precipitation'].sel(date=slice('2015','2024-09-30'))
        ds_obs_t = ds_obs['era5land_temperature_2m'].sel(date=slice('2015','2024-09-30'))
        ds_obs_s = ds_obs['era5land_surface_net_solar_radiation'].sel(date=slice('2015','2024-09-30'))

        samples = []
        for fcst_date in fcst_dates:
            try:
                # Window boundaries relative to the issue date:
                #   weekly   : [fcst-305d, fcst-61d]
                #   daily    : [fcst-60d,  fcst-1d]
                #   forecast : [fcst,      fcst+10d]
                start_weekly = fcst_date - pd.Timedelta(days=305)
                end_weekly = fcst_date - pd.Timedelta(days=60) - pd.Timedelta(days=1)
                start_daily = fcst_date - pd.Timedelta(days=60)
                end_daily = fcst_date - pd.Timedelta(days=1)
                start_fore = fcst_date
                end_fore = fcst_date + pd.Timedelta(days=10)

                # Observed streamflow over all three windows (weekly part is a
                # 7-day mean), concatenated into one series.
                q_weekly = df_streamflow.loc[start_weekly:end_weekly]['streamflow_cms'].resample('7D').mean()
                q_daily  = df_streamflow.loc[start_daily:end_daily]['streamflow_cms']
                q_fore   = df_streamflow.loc[start_fore:end_fore]['streamflow_cms']
                q_combined = pd.concat([q_weekly, q_daily, q_fore]).to_xarray()
                q_combined.name = 'streamflow'

                # Weekly-mean forcings over the same weekly window.
                obs_weekly_p = ds_obs_p.sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean()
                obs_weekly_t = ds_obs_t.sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean()
                obs_weekly_s = ds_obs_s.sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean()

                # Daily forcings. The upper bound is end_daily + 1 day, i.e. the
                # issue date itself, so this slice reaches one day further than
                # the daily streamflow window above.
                obs_daily_p  = ds_obs_p.sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1)))
                obs_daily_t  = ds_obs_t.sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1)))
                obs_daily_s  = ds_obs_s.sel(date=slice(start_daily, end_daily +  pd.Timedelta(days=1)))

                # Forecast issued nearest to fcst_date; re-index it from
                # lead_time to valid date (issue date + lead_time) and keep the
                # first 10 lead times.
                tmp  = ds_fcst.sel(date=fcst_date, method='nearest')
                fcst_dates_expand = pd.Timestamp(tmp.date.values) + pd.to_timedelta(tmp.lead_time)
                tmp  = tmp.assign_coords(date=('lead_time', fcst_dates_expand))
                fcst = (tmp.swap_dims({'lead_time':'date'}).drop_vars('lead_time').isel(date=slice(0,10)))
                fcst_p = fcst['hres_total_precipitation']
                fcst_t = fcst['hres_temperature_2m']
                fcst_s = fcst['hres_surface_net_solar_radiation']

                # Stitch weekly + daily + forecast segments along time.
                precip = xr.concat([obs_weekly_p, obs_daily_p, fcst_p], dim='date')
                temp   = xr.concat([obs_weekly_t, obs_daily_t, fcst_t], dim='date')
                nsrad  = xr.concat([obs_weekly_s, obs_daily_s, fcst_s], dim='date')

                # Drop incomplete samples. Only precip and streamflow lengths
                # are checked; temp and nsrad are not.
                if precip.shape[0] != EXPECTED_LEN or q_combined.shape[0] != EXPECTED_LEN:
                    continue

                # Per-timestep source flag: 0 = weekly obs, 1 = daily obs,
                # 2 = forecast (sizes taken from the precipitation segments).
                flags = np.concatenate([
                    np.full(obs_weekly_p.date.size, 0),
                    np.full(obs_daily_p.date.size, 1),
                    np.full(fcst_p.date.size, 2)
                ])

                # 'flow' and 'target' hold the same streamflow series.
                sample = {
                    'precip': precip.values.astype(np.float32),
                    'temp':   temp.values.astype(np.float32),
                    'net_solar': nsrad.values.astype(np.float32),
                    'flag':   flags.astype(np.int8),
                    'flow':   q_combined.values.astype(np.float32),
                    'target': q_combined.values.astype(np.float32),
                    'basin_id': gauge_id,
                    'forecast_date': fcst_date.strftime('%Y-%m-%d')
                }
                samples.append(sample)

            except Exception as e:
                # Best effort: report and move on to the next issue date.
                print(f"[{gauge_id}] skip {fcst_date:%Y-%m-%d}: {e}")
        # Timing covers sample assembly only; the file write below is excluded.
        print(f"[{gauge_id}] block processed in {time.time() - t0:.2f}s")

        if samples:
            ds = samples_to_xarray(samples)
            # File name is keyed on split, gauge and the block's first issue date.
            outfile = SCRATCH / f'{split}_{gauge_id}_{fcst_dates[0].strftime("%Y%m%d")}.nc'
            ds.to_netcdf(outfile)
            out_files.append(str(outfile))
    except Exception as e:
        # Any failure outside the per-date loop (basin lookup, packing, write).
        print(f"[{gauge_id}] failed with error: {e}")
    return out_files



# -----------------------------------------------------------
# ❹  Small helper: samples → xarray (single gauge, single split)
# -----------------------------------------------------------
def samples_to_xarray(samples):
    """
    Pack a list of sample dicts into a single xarray Dataset.

    Args:
        samples: list of dicts as built by process_block(). Each must provide
            'precip', 'temp', 'net_solar', 'flag' and 'target' sequences of
            length EXPECTED_LEN, plus 'basin_id' and 'forecast_date' strings
            (stored with at most 20 characters).

    Returns:
        xr.Dataset with
            dynamic_inputs (sample, time, feature) float32 — features are
                precip, temp, net_solar, flag, in that order;
            targets (sample, time, target) float32 — streamflow;
            basin_id (sample) and forecast_date (sample) strings.
        The global attribute ``flag_description`` is a JSON string mapping
        each flag value to its meaning.
    """
    n = len(samples)
    dyn_inputs = np.zeros((n, EXPECTED_LEN, 4), np.float32)
    targets = np.zeros((n, EXPECTED_LEN, 1), np.float32)
    basin_ids = np.empty(n, 'U20')
    forecast_dates = np.empty(n, 'U20')

    for i, s in enumerate(samples):
        dyn_inputs[i, :, 0] = s['precip']
        dyn_inputs[i, :, 1] = s['temp']
        dyn_inputs[i, :, 2] = s['net_solar']
        dyn_inputs[i, :, 3] = np.asarray(s['flag'], dtype=np.float32)  # must be a float for LSTM
        targets[i, :, 0] = s['target']
        basin_ids[i] = s['basin_id']
        forecast_dates[i] = s['forecast_date']

    ds = xr.Dataset(
        {
            "dynamic_inputs": (
                ["sample", "time", "feature"],
                dyn_inputs,
                {"feature": ["precip", "temp", "net_solar", "flag"]}
            ),
            "targets": (
                ["sample", "time", "target"],
                targets,
                {"target": ["streamflow"]}
            ),
            "basin_id": (["sample"], basin_ids),
            "forecast_date": (["sample"], forecast_dates)
        }
    )
    # Metadata describing the meaning of the flag values. Stored as a JSON
    # string because netCDF attributes cannot hold a dict: xarray's to_netcdf()
    # rejects dict-valued attrs with a TypeError.
    ds.attrs["flag_description"] = json.dumps({
        "0": "weekly reanalysis (ERA5)",
        "1": "daily reanalysis (ERA5)",
        "2": "forecast (HRES)"
    })
    return ds
# -----------------------------------------------------------
# ❺  Parent / driver
# -----------------------------------------------------------
def main(n_workers=64, max_gauges=64, chunk_size=100):
    """
    Drive the whole preprocessing run.

    Steps: recreate SCRATCH, load or download streamflow, start a local Dask
    cluster, submit one process_block() task per (gauge, split, date chunk),
    then concatenate the per-block NetCDF files into one file per split in
    FINAL_OUT.

    Args:
        n_workers: number of single-threaded Dask worker processes.
        max_gauges: process only the first ``max_gauges`` gauges that have
            streamflow; pass None to process all of them.
        chunk_size: number of forecast issue dates handled by one task.

    The defaults reproduce the previously hard-coded values (64, 64, 100).
    """
    import sys

    # ---------- (re-)create scratch folder ----------
    if SCRATCH.exists():
        shutil.rmtree(SCRATCH)
    SCRATCH.mkdir(parents=True)
    # The final outputs are written here; make sure it exists up front.
    FINAL_OUT.mkdir(parents=True, exist_ok=True)

    # ---------- download streamflow (skip bad gauges) ----------
    streamflows, skipped_gauges = get_or_download_streamflows(df)
    if streamflows:
        # Size of one representative frame. Uses whichever gauge comes first
        # rather than a fixed ID, which may be missing if that gauge was skipped.
        example = next(iter(streamflows.values()))
        print(str(sys.getsizeof(example)*1e-6)+' mb')
    print(f"✅ {len(streamflows)} gauges ready, ❌ {len(skipped_gauges)} skipped")
    print('Got the Gauges: starting parallelization')

    def chunks(seq, n):
        """Yield consecutive slices of ``seq`` holding at most ``n`` items."""
        for i in range(0, len(seq), n):
            yield seq[i:i + n]

    # ---------- dispatch only valid gauges ----------
    some_gauges = list(streamflows)[:max_gauges]

    # ---------- start Dask cluster ----------
    # Context managers guarantee the workers are shut down even if a step
    # below raises. The concatenation stays inside the block so that
    # open_mfdataset(parallel=True) still runs with the client active.
    with LocalCluster(n_workers=n_workers, threads_per_worker=1,
                      processes=True, memory_limit="4GB") as cluster, \
            Client(cluster) as client:
        display(client)  # shows dashboard link in Jupyter

        # Ship only the frames that will actually be used.
        scattered_streamflows = {
            g: client.scatter(streamflows[g], broadcast=False) for g in some_gauges
        }

        # broadcast=True: ensures all workers get a copy.
        # hash=False: prevents Dask from trying to hash the big object
        # (xarray datasets can be large and non-hashable).
        ERA5_ZARR_scattered = client.scatter(ERA5_ZARR, broadcast=True, hash=False)
        HRES_ZARR_scattered = client.scatter(HRES_ZARR, broadcast=True, hash=False)

        futures = []
        for g in some_gauges:
            dfQ = scattered_streamflows[g]
            for split, dates in FORECAST_BLOCKS.items():
                for sub_dates in chunks(dates, chunk_size):
                    print(f"[submit] gauge: {g}, split: {split}, {sub_dates[0]} to {sub_dates[-1]}")
                    fut = client.submit(process_block, g, dfQ, split, sub_dates,
                                        ERA5_ZARR_scattered, HRES_ZARR_scattered, pure=False)
                    futures.append(fut)

        print("All submits done")
        with performance_report(filename="report.html"):
            client.gather(futures)

        print("Finished Dask computation")

        # ---------- concatenate split-level files ----------
        print('concatenating split-level files from parallelization')
        for split in ['train', 'validation', 'test']:
            files = sorted(SCRATCH.glob(f'{split}_*.nc'))
            if not files:
                continue
            # "with" closes the underlying file handles once the split is written.
            with xr.open_mfdataset(files, combine='nested',
                                   concat_dim='sample',
                                   parallel=True) as ds:
                ds.to_netcdf(
                    FINAL_OUT / f"{split}_data_ERA5_HRES_CAMELS_unstandardized.nc",
                    encoding={
                        "dynamic_inputs": {"zlib": True, "complevel": 4},
                        "targets":        {"zlib": True, "complevel": 4}
                    },
                )
            # One file per (gauge, date chunk), so this counts files, not gauges.
            print(f"[✓] wrote {split} set ({len(files)} block files)")

    # ---------- at the very end ----------
    if skipped_gauges:
        print("Skipped gauges:", ", ".join(skipped_gauges))
        with open(FINAL_OUT / "skipped_gauges.txt", "w") as fp:
            fp.write("\n".join(skipped_gauges))

    print("All done.  See NetCDFs in", FINAL_OUT)


# if __name__ == "__main__":
#     main()
#     # This is only needed if you're running the script directly from a terminal or !python — not inside a notebook.
#     # This design allows the file to serve two purposes:
#     # Be run directly as a standalone program.
#     # Be imported as a library/module into another script without executing the main logic.



In [None]:
# Run the full pipeline from this notebook cell. Note that main() deletes and
# recreates the SCRATCH folder before doing anything else.
main()