In [1]:
# preprocess_LSTM_dask.py
# -----------------------------------------------------------
# ❶  Imports & Dask cluster
# -----------------------------------------------------------
import torch
import zarr
import xarray as xr
import geopandas as gpd
import pandas as pd
import numpy as np
import os, shutil, json, tempfile
from pathlib import Path
from collections import Counter, defaultdict
from dask import delayed, compute
from dask.distributed import LocalCluster, Client, performance_report
from IPython.display import display # Dask dashboard inside the notebook itself:

# -----------------------------------------------------------
# Robust downloader
# -----------------------------------------------------------
import urllib.error

def get_usgs_streamflow(site, start_date, end_date, min_end_date="2024-12-31"):
    """
    Download daily streamflow data from USGS NWIS for a given site and date range.
    Assumes columns '20d' (date) and '14n' (flow in cfs).
    
    Returns:
        pd.DataFrame or None if download fails or structure is unexpected
    """
    url = (
        "https://waterservices.usgs.gov/nwis/dv/"
        f"?format=rdb&sites={site}&startDT={start_date}&endDT={end_date}"
        "&parameterCd=00060&siteStatus=all"
    )

    try:
        df = pd.read_csv(url, comment="#", sep="\t", header=1, parse_dates=["20d"])
    except Exception as e:
        print(f"[{site}] failed to download: {e}; skipping")
        return None

    if "14n" not in df.columns or "20d" not in df.columns:
        print(f"[{site}] missing expected columns '20d' and '14n'; skipping")
        return None

    df = df.rename(columns={"14n": "streamflow_cfs", "20d": "date"})
    df["streamflow_cfs"] = pd.to_numeric(df["streamflow_cfs"], errors="coerce")

    # Remove rows with NaNs
    df = df.dropna(subset=["streamflow_cfs"])
    if df.empty:
        print(f"[{site}] all streamflow data missing or invalid; skipping")
        return None

    # Check time coverage
    if pd.to_datetime(df["date"].max()) < pd.to_datetime(min_end_date):
        print(f"[{site}] data ends at {df['date'].max()}, < {min_end_date}; skipping")
        return None

    # Convert to cubic meters per second (cms)
    df["streamflow_cms"] = df["streamflow_cfs"] * 0.0283168
    df = df[["date", "streamflow_cms"]].set_index("date").sort_index()

    return df
# # Example usage:
# # site_id = gaugeID  # Example gauge ID
# site_id = '09085000'

# start = '2015-01-01'
# end = '2024-12-31'

# streamflow_data = get_usgs_streamflow(site_id, start, end)
# print(streamflow_data.tail())

# CAMELS basins
df=gpd.read_file('/Projects/HydroMet/currierw/Caravan-Jan25-csv/shapefiles/camels/camels_basin_shapes.shp')

In [2]:
import os, multiprocessing, torch, psutil

print("Logical CPU cores :", multiprocessing.cpu_count())

if torch.cuda.is_available():
    print("CUDA device      :", torch.cuda.get_device_name(0))
    print("GPU capability   :", torch.cuda.get_device_capability(0))
else:
    print("No CUDA‑capable GPU detected")

print("RAM (GB total)   :", round(psutil.virtual_memory().total / 1e9, 1))

Logical CPU cores : 256
CUDA device      : NVIDIA A100-SXM4-40GB
GPU capability   : (8, 0)
RAM (GB total)   : 1082.0


In [3]:
# -----------------------------------------------------------
# ❷  Constants & helpers
# -----------------------------------------------------------
BASE_OBS  = Path('/Projects/HydroMet/currierw/ERA5_LAND')
BASE_FCST = Path('/Projects/HydroMet/currierw/HRES')
SCRATCH   = Path('/Projects/HydroMet/currierw/HRES_processed_tmp')  # will be recreated
FINAL_OUT = Path('/Projects/HydroMet/currierw/HRES_processed')

FORECAST_BLOCKS = {
    "train":      pd.date_range('2016-01-01', '2020-09-30', freq='5D'),
    "validation": pd.date_range('2020-10-01', '2022-09-30', freq='5D'),
    "test":       pd.date_range('2022-10-01', '2024-09-30', freq='5D'),
}

REQUIRED_KEYS = ['precip', 'temp', 'net_solar', 'flow', 'target']
EXPECTED_LEN  = 115          # enforce the length we know is correct

def standardize_tensor(arr, mean, std):
    return (arr - mean) / std

# -----------------------------------------------------------
# ❸  Per–gauge worker
# -----------------------------------------------------------
@delayed
def process_gauge(gauge_id, df_streamflow):
    print(f"[{gauge_id}] Starting gauge")
    """
    Reads ERA5‑LAND + HRES forecasts for one gauge,
    builds samples, writes temporary NetCDF files,
    and returns running statistics for scaling.
    """
    stats = {k: {'sum': 0.0, 'sumsq': 0.0, 'n': 0} for k in REQUIRED_KEYS}
    out_files = []

    try:
        ds_obs = xr.open_zarr(BASE_OBS / 'timeseries.zarr',
                              consolidated=True,
                              chunks={}).sel(basin=f'camels_{gauge_id}')

        ds_fcst = xr.open_zarr(BASE_FCST / 'timeseries.zarr',
                               consolidated=True,
                               decode_timedelta=True,
                               chunks={}).sel(basin=f'camels_{gauge_id}')
    except FileNotFoundError:
        print(f"[{gauge_id}] missing Zarr store – skipping")
        return out_files, stats   # empty

    # Slice obs once (cheap view)
    ds_obs_p = ds_obs['era5land_total_precipitation'].sel(date=slice('2015','2024-09-30'))
    ds_obs_t = ds_obs['era5land_temperature_2m'].sel(date=slice('2015','2024-09-30'))
    ds_obs_s = ds_obs['era5land_surface_net_solar_radiation'].sel(date=slice('2015','2024-09-30'))

    for split, fcst_dates in FORECAST_BLOCKS.items():
        samples = []
        for fcst_date in fcst_dates:
            try:
                # Windows
                start_weekly  = fcst_date - pd.Timedelta(days=294)
                end_weekly    = fcst_date - pd.Timedelta(days=60) - pd.Timedelta(days=1)
                start_daily   = fcst_date - pd.Timedelta(days=60)
                end_daily     = fcst_date - pd.Timedelta(days=1)
                start_fore    = fcst_date
                end_fore      = fcst_date + pd.Timedelta(days=10)

                # Streamflow (caller passes df_streamflow to avoid re‑downloading)
                q_weekly = (df_streamflow.loc[start_weekly:end_weekly]
                            ['streamflow_cms']
                            .resample('W-SUN', label='left', closed='left')
                            .mean())
                q_daily  = df_streamflow.loc[start_daily:end_daily]['streamflow_cms']
                q_fore   = df_streamflow.loc[start_fore:end_fore]['streamflow_cms']
                q_combined = pd.concat([q_weekly, q_daily, q_fore]).to_xarray()
                q_combined.name = 'streamflow'

                # ERA5 weekly / daily
                obs_weekly_p = ds_obs_p.sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean()
                obs_weekly_t = ds_obs_t.sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean()
                obs_weekly_s = ds_obs_s.sel(date=slice(start_weekly, end_weekly)).resample(date='7D').mean()

                obs_daily_p  = ds_obs_p.sel(date=slice(start_daily, end_daily))
                obs_daily_t  = ds_obs_t.sel(date=slice(start_daily, end_daily))
                obs_daily_s  = ds_obs_s.sel(date=slice(start_daily, end_daily))

                # HRES forecast
                tmp  = ds_fcst.sel(date=fcst_date, method='nearest')
                fcst_dates_expand = pd.Timestamp(tmp.date.values) + pd.to_timedelta(tmp.lead_time)
                tmp  = tmp.assign_coords(date=('lead_time', fcst_dates_expand))
                fcst = (tmp.swap_dims({'lead_time':'date'})
                           .drop_vars('lead_time')
                           .isel(date=slice(0,10)))
                fcst_p = fcst['hres_total_precipitation']
                fcst_t = fcst['hres_temperature_2m']
                fcst_s = fcst['hres_surface_net_solar_radiation']

                precip = xr.concat([obs_weekly_p, obs_daily_p, fcst_p], dim='date')
                temp   = xr.concat([obs_weekly_t, obs_daily_t, fcst_t], dim='date')
                nsrad  = xr.concat([obs_weekly_s, obs_daily_s, fcst_s], dim='date')

                if precip.shape[0] != EXPECTED_LEN:   # guard rail
                    continue
                if q_combined.shape[0] != EXPECTED_LEN:   # guard rail
                    continue
                    
                # Flags
                flags = np.concatenate([
                    np.full(obs_weekly_p.date.size, 0),
                    np.full(obs_daily_p.date.size, 1),
                    np.full(forecast_data_p.date.size, 2)
                ])
                
                sample = {
                    'precip': precip.values.astype(np.float32),
                    'temp':   temp.values.astype(np.float32),
                    'net_solar': nsrad.values.astype(np.float32),
                    'flag':   flags.astype(np.int8),
                    'flow':   q_combined.values.astype(np.float32),
                    'target': q_combined.values.astype(np.float32),
                    'basin_id': gauge_id,
                    'forecast_date': fcst_date.strftime('%Y-%m-%d')
                }
                samples.append(sample)

                # update stats incrementally  (mean later = sum/n)
                for k in REQUIRED_KEYS:
                    arr = sample[k].ravel().astype(np.float64)
                    stats[k]['sum']   += arr.sum()
                    stats[k]['sumsq'] += np.square(arr).sum()
                stats[k]['n'] += arr.size

            except Exception as e:
                print(f"[{gauge_id}] skip {fcst_date:%Y-%m-%d}: {e}")

        # write split‑level NetCDF if any samples
        if samples:
            print(f"[{gauge_id}] Writing to NetCDF...")
            ds = samples_to_xarray(samples)           # uses EXPECTED_LEN
            outfile = SCRATCH / f'{split}_{gauge_id}.nc'
            ds.to_netcdf(outfile)
            out_files.append(str(outfile))

    return out_files, stats


# -----------------------------------------------------------
# ❹  Small helper: samples → xarray (single gauge, single split)
# -----------------------------------------------------------
def samples_to_xarray(samples):
    n = len(samples)
    dyn = np.zeros((n, EXPECTED_LEN, 4), np.float32)
    tgt = np.zeros((n, EXPECTED_LEN, 1), np.float32)
    bas = np.empty(n, 'U20')
    fct = np.empty(n, 'U20')

    for i, s in enumerate(samples):
        dyn[i,:,0] = s['precip']
        dyn[i,:,1] = s['temp']
        dyn[i,:,2] = s['net_solar']
        dyn_inputs[i, :, 3] = s['flag'].astype(np.float32)  # must be a float for LSTM
        tgt[i,:,0] = s['target']
        bas[i] = s['basin_id']
        fct[i] = s['forecast_date']
    ds = xr.Dataset(
        {
            "dynamic_inputs": (
                ["sample", "time", "feature"],
                dyn_inputs,
                {"feature": ["precip", "temp", "net_solar", "flag"]}
            ),
            "targets": (
                ["sample", "time", "target"],
                targets,
                {"target": ["streamflow"]}
            ),
            "basin_id": (["sample"], basin_ids),
            "forecast_date": (["sample"], forecast_dates)
        }
    )
    # Add metadata describing the meaning of flag values
    ds.attrs["flag_description"] = {
        "0": "weekly reanalysis (ERA5)",
        "1": "daily reanalysis (ERA5)",
        "2": "forecast (HRES)"
    }
    return ds
# -----------------------------------------------------------
# ❺  Parent / driver
# -----------------------------------------------------------
def main():
    # ---------- (re-)create scratch folder ----------
    if SCRATCH.exists():
        shutil.rmtree(SCRATCH)
    SCRATCH.mkdir(parents=True)

    # ---------- build gauge list ----------
    gauge_ids = df["gauge_id"].str.split("_").str[-1].tolist()
    print('have gauge list')
    # ---------- download streamflow (skip bad gauges) ----------
    streamflows   = {}
    skipped_gauges = []

    for g in gauge_ids:
        dfQ = get_usgs_streamflow(g, "2015-01-01", "2024-12-31")
        if dfQ is None:
            skipped_gauges.append(g)
        else:
            streamflows[g] = dfQ
            
    print('Got the Gauges: starting parallelization')
    print(f"✅ {len(streamflows)} gauges ready, ❌ {len(skipped_gauges)} skipped")

    # ---------- start Dask cluster ----------
    cluster = LocalCluster(n_workers=8, threads_per_worker=2,
                           processes=True, memory_limit="8GB")
    client  = Client(cluster)
    display(client)  # shows dashboard link in Jupyter

    # ---------- dispatch only valid gauges ----------
    # Scatter the streamflows dict ahead of time:
    # Without scatter: every task pickles its df_streamflow argument → the whole DataFrame travels across the network (or process boundary) every time.
    # With scatter: you send each DataFrame once to each worker; tasks get a lightweight future (a pointer). Graphs shrink, network traffic drops.
    streamflows_future = client.scatter(streamflows, broadcast=True)

    futures = [client.submit(process_gauge, g, streamflows_future[g]) for g in streamflows.keys()] # only use the gauges that weren't skipped
    # futures = [process_gauge(g, streamflows[g]) for g in streamflows.keys()]

    # Below: This is telling Dask to record a performance profile during the execution of the tasks, and save it at dask_profile.html.
    # with performance_report(filename="dask_profile.html"):
    #     results = compute(*futures)
    print('wrote report')

    # ---------- merge statistics ----------
    print('computing statistics')
    reducer = defaultdict(lambda: {'sum':0.0,'sumsq':0.0,'n':0})
    for _, stat in results:
        for k in REQUIRED_KEYS: # ['precip', 'temp', 'net_solar', 'flow', 'target']
            reducer[k]['sum']   += stat[k]['sum']
            reducer[k]['sumsq'] += stat[k]['sumsq']
            reducer[k]['n']     += stat[k]['n']

    scaler = {}
    for k,v in reducer.items():
        mean = v['sum'] / v['n']
        var  = v['sumsq']/v['n'] - mean**2
        scaler[k] = (float(mean), float(np.sqrt(var)))

    (FINAL_OUT).mkdir(exist_ok=True) # checks path
    json.dump(scaler, open(FINAL_OUT/'scaler.json','w'), indent=2) # dump the scalers that we can use to normalize

    # ---------- concatenate split‑level files ----------
    print('concatenating split-level files from parallelization')
    for split in ['train','validation','test']:
        files = sorted(SCRATCH.glob(f'{split}_*.nc'))
        if not files: continue
        ds = xr.open_mfdataset(files, combine='nested',
                               concat_dim='sample',
                               parallel=True)
        ds.to_netcdf(
            FINAL_OUT / f"{split}_data_ERA5_HRES_CAMELS_unstandardized.nc",
            encoding={
                "dynamic_inputs": {"zlib": True, "complevel": 4},
                "targets":        {"zlib": True, "complevel": 4}
            },
        )
        print(f"[✓] wrote {split} set ({len(files)} gauges)")

    # ---------- at the very end ----------
    if skipped_gauges:
        print("Skipped gauges:", ", ".join(skipped_gauges))
        with open(FINAL_OUT / "skipped_gauges.txt", "w") as fp:
            fp.write("\n".join(skipped_gauges))
            
    print("All done.  See scaler.json and NetCDFs in", FINAL_OUT)


# if __name__ == "__main__":
#     main()
#     # This is only needed if you're running the script directly from a terminal or !python — not inside a notebook.
#     # This design allows the file to serve two purposes:
#     # Be run directly as a standalone program.
#     # Be imported as a library/module into another script without executing the main logic.



In [4]:
main()

have gauge list
[01142500] data ends at 2024-12-22 00:00:00, < 2024-12-31; skipping
[01415000] data ends at 2024-12-21 00:00:00, < 2024-12-31; skipping
[01594950] data ends at 2020-07-01 00:00:00, < 2024-12-31; skipping
[02053200] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02077200] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02081500] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02082950] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02092500] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02096846] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02102908] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02111180] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02111500] data ends at 2019-09-30 00:00:00, < 2024-12-31; skipping
[02112120] missing expected columns '20d' and '14n'; skipping
[02112360] missing expected columns '20d' and '14n'; skipping
[02118500] data ends at 2019-09-30 00:00:00,

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 16,Total memory: 59.60 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:35487,Workers: 8
Dashboard: http://127.0.0.1:8787/status,Total threads: 16
Started: Just now,Total memory: 59.60 GiB

0,1
Comm: tcp://127.0.0.1:46863,Total threads: 2
Dashboard: http://127.0.0.1:41021/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:45287,
Local directory: /tmp/dask-scratch-space-4201/worker-6ovrg1ha,Local directory: /tmp/dask-scratch-space-4201/worker-6ovrg1ha

0,1
Comm: tcp://127.0.0.1:38089,Total threads: 2
Dashboard: http://127.0.0.1:34973/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:33291,
Local directory: /tmp/dask-scratch-space-4201/worker-ajwo5_zg,Local directory: /tmp/dask-scratch-space-4201/worker-ajwo5_zg

0,1
Comm: tcp://127.0.0.1:34245,Total threads: 2
Dashboard: http://127.0.0.1:39357/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:45125,
Local directory: /tmp/dask-scratch-space-4201/worker-0sowxadv,Local directory: /tmp/dask-scratch-space-4201/worker-0sowxadv

0,1
Comm: tcp://127.0.0.1:46363,Total threads: 2
Dashboard: http://127.0.0.1:38247/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:42173,
Local directory: /tmp/dask-scratch-space-4201/worker-qfygz7ee,Local directory: /tmp/dask-scratch-space-4201/worker-qfygz7ee

0,1
Comm: tcp://127.0.0.1:46177,Total threads: 2
Dashboard: http://127.0.0.1:35689/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:38795,
Local directory: /tmp/dask-scratch-space-4201/worker-95vjkpmc,Local directory: /tmp/dask-scratch-space-4201/worker-95vjkpmc

0,1
Comm: tcp://127.0.0.1:38813,Total threads: 2
Dashboard: http://127.0.0.1:39367/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:36369,
Local directory: /tmp/dask-scratch-space-4201/worker-_5v5cd6o,Local directory: /tmp/dask-scratch-space-4201/worker-_5v5cd6o

0,1
Comm: tcp://127.0.0.1:35407,Total threads: 2
Dashboard: http://127.0.0.1:41643/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:38365,
Local directory: /tmp/dask-scratch-space-4201/worker-5xnup6o7,Local directory: /tmp/dask-scratch-space-4201/worker-5xnup6o7

0,1
Comm: tcp://127.0.0.1:38725,Total threads: 2
Dashboard: http://127.0.0.1:40917/status,Memory: 7.45 GiB
Nanny: tcp://127.0.0.1:35129,
Local directory: /tmp/dask-scratch-space-4201/worker-w_q5g1jo,Local directory: /tmp/dask-scratch-space-4201/worker-w_q5g1jo


2025-07-02 13:49:20,798 - distributed.core - ERROR - Exception while handling op performance_report
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/distributed/core.py", line 970, in _handle_comm
    result = await result
  File "/usr/local/lib/python3.10/dist-packages/distributed/scheduler.py", line 7924, in performance_report
    compute, scheduler, workers = map(
  File "/usr/local/lib/python3.10/dist-packages/distributed/scheduler.py", line 7921, in profile_to_figure
    figure, source = profile.plot_figure(data, sizing_mode="stretch_both")
  File "/usr/local/lib/python3.10/dist-packages/distributed/profile.py", line 482, in plot_figure
    from bokeh.models import HoverTool
ModuleNotFoundError: No module named 'bokeh'


ModuleNotFoundError: No module named 'bokeh'

In [None]:
# from collections import Counter

# def to_xarray_dataset(samples, standardize=False, scaler=None):
#     # Step 1: Count all time lengths (based on 'precip')
#     lengths = [s['precip'].shape[0] for s in samples]
#     length_counts = Counter(lengths)

#     # Step 2: Infer most common sequence length
#     EXPECTED_LEN, _ = length_counts.most_common(1)[0]

#     REQUIRED_KEYS = ['precip', 'temp', 'net_solar', 'target', 'flag']

#     # 🔍 Optional: Print samples with mismatched array lengths
#     for s in samples:
#         lengths = {k: s[k].shape[0] for k in REQUIRED_KEYS}
#         if len(set(lengths.values())) > 1:
#             print(f"⚠️ Mismatched lengths for sample {s['forecast_date']} / {s['basin_id']}: {lengths}")

#     # Step 3: Keep only samples where ALL arrays match EXPECTED_LEN
#     clean_samples = [
#         s for s in samples
#         if all(s[k].shape[0] == EXPECTED_LEN for k in REQUIRED_KEYS)
#     ]

#     dropped = len(samples) - len(clean_samples)
#     if dropped > 0:
#         print(f"⚠️ Dropped {dropped} of {len(samples)} samples due to unexpected time length.")

#     if not clean_samples:
#         raise ValueError("No valid samples with consistent length")

#     # ✅ Missing before — now added:
#     n_samples = len(clean_samples)
#     n_time = EXPECTED_LEN
#     dyn_inputs = np.zeros((n_samples, n_time, 3), dtype=np.float32)
#     targets = np.zeros((n_samples, n_time, 1), dtype=np.float32)
#     basin_ids = np.empty(n_samples, dtype='U20')
#     forecast_dates = np.empty(n_samples, dtype='U20')

#     for i, s in enumerate(clean_samples):
#         p = s['precip']
#         t2 = s['temp']
#         ns = s['net_solar']
#         f = s['flag']
#         t = s['target']

#         if standardize:
#             p  = standardize_tensor(p, *scaler['precip'])
#             t2 = standardize_tensor(t2, *scaler['temp'])
#             ns = standardize_tensor(ns, *scaler['net_solar'])
#             f = standardize_tensor(ns, *scaler['flag'])
#             t  = standardize_tensor(t, *scaler['target'])

#         dyn_inputs[i, :, 0] = p
#         dyn_inputs[i, :, 1] = t2
#         dyn_inputs[i, :, 2] = ns
#         dyn_inputs[i, :, 3] = flag
#         targets[i, :, 0] = t
#         basin_ids[i] = s['basin_id']
#         forecast_dates[i] = s['forecast_date']

#     return xr.Dataset(
#         {
#             "dynamic_inputs": (
#                 ["sample", "time", "feature"],
#                 dyn_inputs,
#                 {"feature": ["precip", "temp", "net_solar","flag"]}
#             ),
#             "targets": (
#                 ["sample", "time", "target"],
#                 targets,
#                 {"target": ["streamflow"]}
#             ),
#             "basin_id": (["sample"], basin_ids),
#             "forecast_date": (["sample"], forecast_dates)
#         }
#     )

In [None]:
# # Write 2 NetCDF files (train + validation) unstandardized
# for split in ["train", "validation", "test"]:
# # split="test"
#     std = False
#     ds = to_xarray_dataset(
#         dataset_store[split],
#         standardize=False,
#         scaler=scaler
#     )
#     suffix = "standardized" if std else "unstandardized"
#     fname = f"{split}_data_ERA5_HRES_CAMELS_{suffix}.nc"
#     ds.to_netcdf('/Projects/HydroMet/currierw/HRES_processed/'+fname)
#     print(f"Saved: {fname}")

In [None]:
# print(dataset_store['train'][0]['precip'].shape[0])
# print(dataset_store['train'][0]['temp'].shape[0])
# print(dataset_store['train'][0]['net_solar'].shape[0])
# print(dataset_store['train'][0]['flow'].shape[0])
# print(dataset_store['train'][0]['target'].shape[0])


In [None]:
# ds=xr.open_dataset('/Projects/HydroMet/currierw/HRES_processed/train_data_ERA5_HRES_CAMELS_unstandardized.nc')
# ds['dynamic_inputs']