In [None]:
import numpy as np
import xarray as xr
from typing import Tuple
import logging
import geopandas as gpd
import json
import time
import multiprocessing
import pandas as pd
from exactextract import exact_extract
from exactextract.raster import NumPyRasterSource
from pathlib import Path
from rich.progress import (
    Progress,
    BarColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
import psutil
from math import ceil
from multiprocessing import shared_memory
from functools import partial
import warnings
from dask.distributed import Client, LocalCluster
import shutil

In [None]:
logger = logging.getLogger(__name__)
warnings.filterwarnings(
    "ignore", 
    message="'DataFrame.swapaxes' is deprecated", 
    category=FutureWarning
)
warnings.filterwarnings(
    "ignore", 
    message="'GeoDataFrame.swapaxes' is deprecated", 
    category=FutureWarning
)
warnings.filterwarnings(
    "ignore", 
    message="A value is trying to be set on a copy of a slice from a DataFrame", 
    category=FutureWarning
)

In [None]:
def validate_time_range(dataset: xr.Dataset, 
                        start_time: str, 
                        end_time: str) -> Tuple[str, str]:
    """Clamp a requested time window to the span covered by ``dataset.time``.

    Args:
        dataset: dataset with a ``time`` coordinate; assumes it is sorted
            ascending (first element = earliest) — TODO confirm for all inputs.
        start_time: requested start, anything ``np.datetime64`` accepts.
        end_time: requested end, anything ``np.datetime64`` accepts.

    Returns:
        ``(start_time, end_time)``. An endpoint that falls outside the dataset
        is replaced by the dataset's own first/last time value (a numpy value
        from ``dataset.time``, not a ``str``) and a warning is logged.
    """
    end_time_in_dataset = dataset.time.isel(time=-1).values
    start_time_in_dataset = dataset.time.isel(time=0).values
    if np.datetime64(start_time) < start_time_in_dataset:
        # every piece of the message must be an f-string, otherwise the
        # placeholders are logged literally
        logger.warning(
            f"provided start {start_time} is before the start of the dataset "
            f"{start_time_in_dataset}, selecting from "
            f"{start_time_in_dataset}"
        )
        start_time = start_time_in_dataset
    if np.datetime64(end_time) > end_time_in_dataset:
        logger.warning(
            f"provided end {end_time} is after the end of the dataset "
            f"{end_time_in_dataset}, selecting until {end_time_in_dataset}"
        )
        end_time = end_time_in_dataset
    return start_time, end_time

def clip_dataset_to_bounds(
    dataset: xr.Dataset, 
    bounds: Tuple[float, float, float, float], 
    start_time: str, 
    end_time: str
) -> xr.Dataset:
    """Clip the dataset to specified geographical bounds and a time window.

    Args:
        dataset: dataset with ``x``, ``y`` and ``time`` coordinates.
        bounds: ``(xmin, ymin, xmax, ymax)`` in the dataset's x/y coordinates
            (the ordering used by ``GeoDataFrame.total_bounds``).
        start_time: start of the time window.
        end_time: end of the time window.

    Returns:
        The label-based ``.sel`` selection. NOTE(review): ``slice(ymin, ymax)``
        only selects data if ``y`` is stored ascending — confirm for the inputs.
    """
    # Validate here as well so this helper is safe to import and use on its
    # own, without the rest of the module.
    start_time, end_time = validate_time_range(dataset, start_time, end_time)
    x_window = slice(bounds[0], bounds[2])
    y_window = slice(bounds[1], bounds[3])
    time_window = slice(start_time, end_time)
    clipped = dataset.sel(x=x_window, y=y_window, time=time_window)
    logger.info("Selected time range and clipped to bounds")
    return clipped

In [None]:
def head_gdf_selection(headwater, gdb):
    """Return the row of ``gdb`` labelled ``int(headwater)`` as a one-row GeoDataFrame.

    Selecting with a list (``gdb.loc[[label]]``) keeps the result
    two-dimensional, so the geometry column, CRS and all attribute columns
    survive. A scalar ``gdb.loc[label]`` returns a Series, and wrapping that in
    ``GeoDataFrame`` builds a transposed frame (one column, one row per
    attribute) with no usable geometry, so ``total_bounds`` fails.

    NOTE(review): ``.loc`` selects by index *label*. ``gpd.read_file`` returns a
    default RangeIndex, in which case this picks a row by position rather than
    by catchment ID. Confirm that ``gdb`` is indexed by catchment ID.

    Raises:
        KeyError: if the label is not in ``gdb.index``.
    """
    return gpd.GeoDataFrame(gdb.loc[[int(headwater)]])

def tail_gdf_selection(headwater, tailwater, gdb):
    """Return a one-row GeoDataFrame for ``tailwater`` carrying the headwater geometry.

    The attribute columns come from row ``int(tailwater)`` of ``gdb``. The
    ``geometry`` column is replaced by the geometry of row ``int(headwater)``.
    The selected row is copied, so ``gdb`` itself is never modified.

    NOTE(review): only the headwater geometry ends up in the result. If the
    intent is the combined headwater + tailwater area, this should probably
    use the union of both geometries. TODO confirm.

    Raises:
        KeyError: if either label is not in ``gdb.index``.
    """
    head_geometry = gdb.loc[int(headwater), "geometry"]
    # list selector keeps a one-row GeoDataFrame (a scalar label would give a
    # Series, which GeoDataFrame() turns into a transposed frame); copy so the
    # assignment below cannot write through to gdb
    tail_gdf = gpd.GeoDataFrame(gdb.loc[[int(tailwater)]]).copy()
    tail_gdf["geometry"] = gpd.GeoSeries(
        [head_geometry], index=tail_gdf.index, crs=tail_gdf.crs
    )
    return tail_gdf

In [None]:
def remove_terminals(dict):
    """Return a new mapping holding only the entries whose value is positive.

    The input mapping is left untouched. Presumably non-positive values mark
    terminal catchments with no downstream neighbour. TODO confirm.
    """
    return {key: value for key, value in dict.items() if value > 0}

In [None]:
def get_cell_weights(raster, gdf, wkt):
    """Return the raster cells and coverage fractions that make up each divide.

    Args:
        raster: a single time step with ``x``/``y`` coordinates and a
            ``RAINRATE`` variable. Presumably only its grid matters for the
            ``cell_id``/``coverage`` operations.
        gdf: polygons with a ``divide_id`` column.
        wkt: WKT of the spatial reference the raster is declared in.

    Returns:
        The exactextract pandas output (``cell_id`` and ``coverage`` per
        divide) indexed by ``divide_id``.

    NOTE(review): the extent is taken from the first/last coordinate values,
    which are presumably cell centres. If NumPyRasterSource expects outer cell
    edges this is off by half a cell, and a descending ``y`` would give
    ymin > ymax. Confirm.
    """
    x_coords = raster.x
    y_coords = raster.y
    extent = dict(
        xmin=x_coords[0],
        xmax=x_coords[-1],
        ymin=y_coords[0],
        ymax=y_coords[-1],
    )
    source = NumPyRasterSource(raster["RAINRATE"], srs_wkt=wkt, **extent)
    cell_table = exact_extract(
        source,
        gdf,
        ["cell_id", "coverage"],
        include_cols=["divide_id"],
        output="pandas",
    )
    return cell_table.set_index("divide_id")

In [None]:

def get_cell_weights_parallel(gdf, input_forcings, num_partitions):
    """Compute ``get_cell_weights`` for ``gdf`` in parallel worker processes.

    Args:
        gdf: polygons with a CRS and a ``divide_id`` column.
        input_forcings: dataset with a ``time`` dimension. Only its first time
            step is loaded and sent to the workers.
        num_partitions: requested number of row partitions / worker processes.
            Clamped to ``[1, len(gdf)]`` so the split never fails and never
            produces empty partitions (for non-empty ``gdf``).

    Returns:
        The per-partition results concatenated into one DataFrame.
    """
    num_partitions = max(1, min(num_partitions, len(gdf)))
    # Split on row positions: np.array_split applied to a (Geo)DataFrame goes
    # through the deprecated DataFrame.swapaxes.
    gdf_chunks = [
        gdf.iloc[rows]
        for rows in np.array_split(np.arange(len(gdf)), num_partitions)
    ]
    wkt = gdf.crs.to_wkt()
    one_timestep = input_forcings.isel(time=0).compute()
    args = [(one_timestep, gdf_chunk, wkt) for gdf_chunk in gdf_chunks]
    # one worker per partition; more would sit idle
    with multiprocessing.Pool(num_partitions) as pool:
        catchments = pool.starmap(get_cell_weights, args)
    return pd.concat(catchments)

In [None]:
def get_units(dataset: "xr.Dataset") -> dict:
    """Map each data variable to its ``units`` attribute.

    Variables without a ``units`` attribute, or with an empty one, are left
    out rather than raising ``KeyError``. Callers such as ``write_outputs``
    already handle variables that are missing from this mapping.
    """
    units = {}
    for var in dataset.data_vars:
        unit = dataset[var].attrs.get("units")
        if unit:
            units[var] = unit
    return units

In [None]:
def get_index_chunks(data: "xr.DataArray",
                     max_chunk_bytes=None) -> list[tuple[int, int]]:
    """Split axis 0 of ``data`` into ``(start, end)`` index ranges that fit in memory.

    Args:
        data: array-like exposing ``nbytes`` and ``shape``. Chunks are taken
            along its first axis.
        max_chunk_bytes: memory budget per chunk. Defaults to 80% of the
            currently available memory, capped at 20 GiB.

    Returns:
        Consecutive half-open ranges covering ``[0, data.shape[0])``. Each
        range holds at least one index, and ``end`` never exceeds
        ``data.shape[0]``. An empty axis yields ``[]``.
    """
    if max_chunk_bytes is None:
        free_memory = psutil.virtual_memory().available * 0.8  # 80% of available memory
        # limit the chunk to 20gb, makes things more stable
        max_chunk_bytes = min(free_memory, 20 * 1024 * 1024 * 1024)
    max_index = data.shape[0]
    # at least one chunk, even for zero-sized data
    num_chunks = max(1, ceil(data.nbytes / max_chunk_bytes))
    # floor keeps each chunk within the budget; the minimum of 1 avoids a zero
    # step when there are more chunks than indices
    stride = max(1, max_index // num_chunks)
    return [(start, min(start + stride, max_index))
            for start in range(0, max_index, stride)]

In [None]:
def create_shared_memory(lazy_array):
    """Copy a 3-D ``(time, dim_1, dim_2)`` array into a new float32 shared-memory block.

    The data is copied chunk by chunk along the first axis (see
    ``get_index_chunks``) so the source never has to be fully loaded at once.

    Returns:
        ``(shm, shape, dtype)``. ``shm`` is the ``SharedMemory`` object, which
        the caller must ``close()`` and ``unlink()``. ``shape`` is the flattened
        layout ``(time, dim_1 * dim_2)`` that workers should use to view the
        buffer. ``dtype`` is float32.

    Raises:
        ValueError: if ``lazy_array`` is not 3-D.
    """
    n_time, dim_1, dim_2 = lazy_array.shape
    dtype = np.dtype(np.float32)
    # Size the block for the float32 data actually written into it, not for
    # the source dtype (a narrower source dtype would make the block too small).
    # SharedMemory rejects size 0, hence the minimum of one byte.
    size = max(1, n_time * dim_1 * dim_2 * dtype.itemsize)
    logger.debug(f"Creating shared memory size {size/ 10**6} Mb.")
    shm = shared_memory.SharedMemory(create=True, size=size)
    try:
        shared_array = np.ndarray(lazy_array.shape, dtype=dtype, buffer=shm.buf)
        # if your data is not float32 it is converted during the copy, which
        # consumes extra memory; forcings downloaded with this tool are float32
        for start, end in get_index_chunks(lazy_array):
            # copy data from lazy to shared memory one chunk at a time
            shared_array[start:end] = lazy_array[start:end]
    except BaseException:
        # don't leak the OS-level block if loading the data fails
        shm.close()
        shm.unlink()
        raise

    return shm, (n_time, dim_1 * dim_2), dtype

In [None]:
def weighted_sum_of_cells(flat_raster: np.ndarray, 
                          cell_ids: np.ndarray, 
                          factors: np.ndarray):
    """Coverage-weighted mean of selected raster cells at every time step.

    Args:
        flat_raster: 2-D array ``raster[time][cell]`` (spatial axes flattened).
        cell_ids: flat indices of the cells belonging to one catchment.
        factors: weight (e.g. coverage fraction) for each entry of ``cell_ids``.

    Returns:
        1-D array of length ``flat_raster.shape[0]``:
        ``sum(raster[t, cell_ids] * factors) / sum(factors)`` for every ``t``.
        The division is not in place, so integer inputs work too.
    """
    weights = np.asarray(factors)
    # (time, n) @ (n,) -> (time,): a matrix-vector product instead of building
    # the element-wise (time, n) product array and then summing it
    weighted_sums = flat_raster[:, cell_ids] @ weights
    return weighted_sums / weights.sum()

In [None]:
def process_chunk_shared(variable, times, shm_name, shape, dtype, chunk):
    """Compute catchment means for one chunk of catchments from shared memory.

    Used as a multiprocessing worker. It attaches to the block ``shm_name``,
    views it as an array of ``shape``/``dtype`` (the ``(time, cells)`` layout
    returned by ``create_shared_memory``) and, for each catchment in
    ``chunk.index``, computes the coverage-weighted mean of its cells at every
    time step.

    Args:
        variable: variable name, used to name the intermediate DataArrays.
        times: time coordinate values for the rows of the raster.
        shm_name: name of an existing shared-memory block.
        shape: shape to view the block with.
        dtype: dtype to view the block with.
        chunk: DataFrame indexed by catchment id with ``cell_id`` and
            ``coverage`` columns.

    Returns:
        DataArray with dims ``("catchment", "time")``.
    """
    existing_shm = shared_memory.SharedMemory(name=shm_name)
    try:
        raster = np.ndarray(shape, dtype=dtype, buffer=existing_shm.buf)
        results = []
        for catchment in chunk.index.unique():
            row = chunk.loc[catchment]
            mean_at_timesteps = weighted_sum_of_cells(
                raster, row["cell_id"], row["coverage"]
            )
            temp_da = xr.DataArray(
                mean_at_timesteps,
                dims=["time"],
                coords={"time": times},
                name=f"{variable}_{catchment}",
            )
            results.append(temp_da.assign_coords(catchment=catchment))
    finally:
        # detach even if a catchment fails; the creating process unlinks it
        existing_shm.close()
    return xr.concat(results, dim="catchment")

In [None]:
def add_APCP_SURFACE_to_dataset(dataset: xr.Dataset) -> xr.Dataset:
    """Add ``APCP_surface`` = ``precip_rate`` * 3600 and return the same dataset.

    Unit notes from the original author:
    - precip_rate is mm/s (technically kg/m^2/s; at 1 kg = 1 l that is mm/s).
    - cfe's input atmosphere_water__liquid_equivalent_precipitation_rate is mm/h.
    - nom's prcpnonc input is mm/s.
    - nom's qinsur output is m/s; hopefully ngen converts it to mm/h.

    The dataset is modified in place. The new variable gets ``units`` and
    ``source_note`` attributes.
    """
    seconds_per_hour = 3600
    dataset["APCP_surface"] = dataset["precip_rate"] * seconds_per_hour
    apcp_attrs = dataset["APCP_surface"].attrs
    # ^-1 notation copied from source data
    apcp_attrs["units"] = "mm h^-1"
    apcp_attrs["source_note"] = (
        "This is just the precip_rate variable converted to mm/h by "
        "multiplying by 3600"
    )
    return dataset

In [None]:
def write_outputs(forcings_dir, variables, units):
    """Merge the per-variable temp files into ``forcings_dir/forcings.nc``.

    Every ``*.nc`` in ``forcings_dir/temp`` is opened (dask-chunked) and merged.
    The function then:
    - attaches the ``units`` attributes;
    - renames variables via ``variables`` (source name -> output name);
    - adds ``APCP_surface`` and casts all data variables to float32;
    - writes the legacy layout described below.
    The temp directory is removed afterwards.

    Args:
        forcings_dir: ``pathlib.Path`` of the working directory.
        variables: mapping of source variable names to output names.
        units: mapping of source variable names to unit strings.

    Raises:
        FileNotFoundError: if ``forcings_dir/temp`` contains no ``*.nc`` files.
    """
    # start a dask cluster if there isn't one already running; the local
    # name holds a reference to the client while this function runs
    try:
        client = Client.current()
    except ValueError:
        cluster = LocalCluster()
        client = Client(cluster)
    temp_forcings_dir = forcings_dir / "temp"
    # Combine all variables into a single dataset using dask
    results = [xr.open_dataset(file, chunks="auto")
               for file in temp_forcings_dir.glob("*.nc")]
    if not results:
        raise FileNotFoundError(
            f"No temporary forcing files found in {temp_forcings_dir}"
        )
    try:
        final_ds = xr.merge(results)
        for var in final_ds.data_vars:
            if var in units:
                final_ds[var].attrs["units"] = units[var]
            else:
                logger.warning(f"Variable {var} has no units")

        rename_dict = {key: value for key, value in variables.items()
                       if key in final_ds}
        final_ds = final_ds.rename_vars(rename_dict)
        final_ds = add_APCP_SURFACE_to_dataset(final_ds)

        # this step halves the storage size of the forcings
        for var in final_ds.data_vars:
            final_ds[var] = final_ds[var].astype(np.float32)

        logger.info("Saving to disk")
        # The format for the netcdf is to support a legacy format
        # which is why it's a little "unorthodox"
        # There are no coordinates, just dimensions, catchment ids are stored
        # in a 1d data var and time is stored in a 2d data var with the same
        # time array for every catchment
        # time is stored as unix timestamps, units have to be set

        # add the catchment ids as a 1d data var
        final_ds["ids"] = final_ds["catchment"].astype(str)
        # Convert the raw numpy values to whole seconds since the epoch.
        # Casting the xarray coordinate instead is version dependent: older
        # xarray coerced it back to nanoseconds, newer releases can keep
        # seconds, which made a fixed "// 10**9" wrong on one or the other.
        time_array = (
            final_ds["time"].values.astype("datetime64[s]").astype(np.int64)
        )
        time_array = time_array.astype(np.int32)  ## convert to int32 to save space
        # drop the original time and catchment vars
        final_ds = final_ds.drop_vars(["catchment", "time"])
        # rename the catchment dimension
        final_ds = final_ds.rename_dims({"catchment": "catchment-id"})
        # add the time as a 2d data var (one identical row per catchment),
        # yes this is wasting disk space.
        final_ds["Time"] = (("catchment-id", "time"),
                            np.tile(time_array, (len(final_ds["ids"]), 1)))
        # set the time unit
        final_ds["Time"].attrs["units"] = "s"
        # not needed but suppresses the ngen warning
        final_ds["Time"].attrs["epoch_start"] = "01/01/1970 00:00:00"

        final_ds.to_netcdf(forcings_dir / "forcings.nc", engine="netcdf4")
        final_ds.close()
    finally:
        # close the source datasets even if merging/writing failed
        for result in results:
            result.close()

    # clean up the temp files
    for file in temp_forcings_dir.glob("*.*"):
        file.unlink()
    temp_forcings_dir.rmdir()

In [None]:
def _process_variable(variable, merged_data, cat_chunks, num_partitions,
                      temp_dir, progress):
    """Compute per-catchment means of one variable into ``temp_dir/<variable>.nc``.

    The time axis is split with ``get_index_chunks`` so each piece fits in
    memory. For each piece:
    - the raster is copied into shared memory;
    - the catchment chunks are reduced in parallel by ``process_chunk_shared``;
    - the result is written to its own temporary file.
    The pieces are then concatenated along ``time`` and their files deleted.
    """
    time_chunks = get_index_chunks(merged_data[variable])
    chunk_task = progress.add_task("[purple] processing chunks",
                                   total=len(time_chunks))
    chunk_files = []
    for i, (start, end) in enumerate(time_chunks):
        # select the chunk of time we want to process
        data_chunk = merged_data[variable].isel(time=slice(start, end))
        # put it in shared memory
        shm, shape, dtype = create_shared_memory(data_chunk)
        try:
            # create a partial function to pass to the multiprocessing pool
            partial_process_chunk = partial(process_chunk_shared,
                                            variable,
                                            data_chunk.time.values,
                                            shm.name,
                                            shape,
                                            dtype)
            logger.debug(f"Processing variable: {variable}")
            # process the chunks of catchments in parallel
            with multiprocessing.Pool(num_partitions) as pool:
                variable_data = pool.map(partial_process_chunk, cat_chunks)
        finally:
            # release the shared memory even if a worker raised
            shm.close()
            shm.unlink()
        logger.debug(f"Processed variable: {variable}")
        concatenated_da = xr.concat(variable_data, dim="catchment")
        # delete the data to free up memory
        del variable_data
        logger.debug(f"Concatenated variable: {variable}")
        # write this to disk now to save memory
        # xarray will monitor memory usage, but it doesn't account for the
        # shared memory used to store the raster
        # This reduces memory usage by about 60%
        chunk_file = temp_dir / f"{variable}_{i}.nc"
        concatenated_da.to_dataset(name=variable).to_netcdf(chunk_file)
        chunk_files.append(chunk_file)
        progress.update(chunk_task, advance=1)

    # Merge the chunks back together
    datasets = [xr.open_dataset(chunk_file) for chunk_file in chunk_files]
    try:
        result = xr.concat(datasets, dim="time")
        result.to_netcdf(temp_dir / f"{variable}.nc")
        result.close()
    finally:
        for dataset in datasets:
            dataset.close()
    # delete exactly the files written above (not a glob that could match
    # other variables' outputs)
    for chunk_file in chunk_files:
        chunk_file.unlink()
    progress.remove_task(chunk_task)


def compute_zonal_stats(
    gdf: gpd.GeoDataFrame, merged_data: xr.Dataset, forcings_dir: Path
) -> None:
    """Compute catchment-mean forcings for ``gdf`` and write ``forcings_dir/forcings.nc``.

    Cell weights are computed once per catchment. Each known forcing variable
    present in ``merged_data`` is then reduced with ``_process_variable`` into
    ``forcings_dir/temp``, and ``write_outputs`` assembles the final file.
    Variables missing from ``merged_data`` are skipped with a warning.

    Args:
        gdf: catchment polygons (with CRS and ``divide_id`` column).
        merged_data: gridded forcings with a ``time`` dimension.
        forcings_dir: working directory (``str`` or ``Path``). ``temp`` is
            created inside it if missing.
    """
    logger.info("Computing zonal stats in parallel for all timesteps")
    timer_start = time.time()
    forcings_dir = Path(forcings_dir)
    temp_dir = forcings_dir / "temp"
    temp_dir.mkdir(parents=True, exist_ok=True)
    # leave one core free, but always use at least one worker and never more
    # workers than catchments
    num_partitions = max(1, min(multiprocessing.cpu_count() - 1, len(gdf)))

    catchments = get_cell_weights_parallel(gdf, merged_data, num_partitions)

    units = get_units(merged_data)

    variables = {
        "LWDOWN": "DLWRF_surface",
        "PSFC": "PRES_surface",
        "Q2D": "SPFH_2maboveground",
        "RAINRATE": "precip_rate",
        "SWDOWN": "DSWRF_surface",
        "T2D": "TMP_2maboveground",
        "U2D": "UGRD_10maboveground",
        "V2D": "VGRD_10maboveground",
    }

    # split on row positions; np.array_split on a DataFrame relies on the
    # deprecated DataFrame.swapaxes
    cat_chunks = [
        catchments.iloc[rows]
        for rows in np.array_split(np.arange(len(catchments)), num_partitions)
    ]

    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        TextColumn("{task.completed}/{task.total}"),
        "•",
        TextColumn(" Elapsed Time:"),
        TimeElapsedColumn(),
        TextColumn(" Remaining Time:"),
        TimeRemainingColumn(),
    )

    timer = time.perf_counter()
    variable_task = progress.add_task(
        "[cyan]Processing variables...", total=len(variables), elapsed=0
    )
    # the context manager stops the live display even if processing fails
    with progress:
        for variable in variables:
            progress.update(variable_task, advance=1,
                            description=f"Processing {variable}")

            if variable not in merged_data.data_vars:
                logger.warning(f"Variable {variable} not in forcings, skipping")
                continue

            _process_variable(variable, merged_data, cat_chunks,
                              num_partitions, temp_dir, progress)
        progress.update(
            variable_task,
            description=f"Forcings processed in "
                        f"{time.perf_counter() - timer:.2f} seconds",
        )
    logger.info(
        f"Forcing generation complete! Zonal stats computed in "
        f"{time.time() - timer_start:.2f} seconds"
    )
    write_outputs(forcings_dir, variables, units)



In [None]:
# Read the forcing grid's projection string (the ``esri_pe_string`` attribute
# of the file's ``crs`` variable) from an example forcing file.
# TODO: point forc_path at a real forcing/grid file.
forc_path = 'path/to/example/file'
# context manager so the file handle is released once the string is read
with xr.open_dataset(forc_path, engine="h5netcdf") as grid_ds:
    projection = grid_ds.crs.esri_pe_string
# use the notebook's logger (module-level logging.debug would log via the root
# logger and implicitly call logging.basicConfig)
logger.debug("Got projection from grid file")

In [None]:
# Hydrofabric: read the "nwm_catchments_conus" layer of the NWM v3 FileGDB.
# The path is relative to the kernel's current working directory.
gdb_path = "../NWM_v3_hydrofabric.gdb/"
gdb = gpd.read_file(gdb_path, driver="FileGDB", layer="nwm_catchments_conus")
# Time window for the forcings. clip_dataset_to_bounds turns these into a
# label-based slice, so both ends are included when present in the data.
start_time = "2008-01-01 01:00"
end_time = "2008-01-02 01:00"

In [None]:
# Clip the forcing files to the Alabama bounding box.
# TODO: point forcing_glob at the forcing NetCDF files to process.
forcing_glob = "path/to/forcing/files/*.nc"
# Alabama bounding box as (min lon, min lat, max lon, max lat); presumably
# geographic coordinates (EPSG:4326). Confirm.
AL_BOUNDS_LONLAT = (-88.4745951503515, 30.222501133601334,
                    -84.89247974539745, 35.008322669916694)

# The forcing grid is presumably indexed by x/y in ``projection`` (read from
# the grid file above), so reproject the box before slicing. Points are
# sampled along all four edges, not just the corners, because straight
# lon/lat edges are curved in a projected CRS.
lon_min, lat_min, lon_max, lat_max = AL_BOUNDS_LONLAT
edge_lons = np.linspace(lon_min, lon_max, 50)
edge_lats = np.linspace(lat_min, lat_max, 50)
box_lons = np.concatenate([edge_lons, edge_lons,
                           np.full(50, lon_min), np.full(50, lon_max)])
box_lats = np.concatenate([np.full(50, lat_min), np.full(50, lat_max),
                           edge_lats, edge_lats])
al_bounds = (
    gpd.GeoSeries(gpd.points_from_xy(box_lons, box_lats), crs="EPSG:4326")
    .to_crs(projection)
    .total_bounds
)

forcings = xr.open_mfdataset(forcing_glob, engine="h5netcdf")
al_forcs = clip_dataset_to_bounds(forcings, al_bounds, start_time, end_time)

In [None]:
# headwater catchment ID (str) -> tailwater catchment ID (int)
small_dict = dict([
    ("445308", 445314),
    ("445322", 445326),
    ("445328", 445336),
])

In [None]:
def process_dict(dict, gdb, forcs, start_time, end_time, output_file):
    """Generate a forcings file for each headwater catchment in ``dict``.

    For each ``headwater -> tailwater`` pair:
    1. Select the headwater catchment.
    2. Clip ``forcs`` to its bounds and the time window.
    3. Compute zonal stats in ``test/<headwater>-working-dir``.
    4. Copy the resulting ``forcings.nc`` to the output path.
    5. Delete the working directory.

    Args:
        dict: mapping of headwater ID to tailwater ID.
        gdb: catchment GeoDataFrame (see ``head_gdf_selection`` for the
            assumption about its index).
        forcs: gridded forcing dataset.
        start_time: start of the time window.
        end_time: end of the time window.
        output_file: destination path. It may contain ``{headwater}`` and/or
            ``{tailwater}`` placeholders (filled per entry via ``str.format``)
            so each catchment gets its own file. Without placeholders every
            entry writes to the same path and only the last one is kept.

    NOTE(review): ``head_gdf.total_bounds`` is in ``gdb``'s CRS and is used
    directly as x/y bounds for ``forcs``; this assumes both share a CRS.
    Confirm.
    """
    output_template = str(output_file)
    if (len(dict) > 1 and "{headwater}" not in output_template
            and "{tailwater}" not in output_template):
        logger.warning(
            f"All {len(dict)} catchments will be written to {output_template}; "
            "only the last one will be kept. Use a {headwater} placeholder."
        )

    for headwater, tailwater in dict.items():
        head_gdf = head_gdf_selection(headwater, gdb)
        # NOTE(review): tail_gdf is not used below yet
        tail_gdf = tail_gdf_selection(headwater, tailwater, gdb)

        head_forcs = clip_dataset_to_bounds(
            forcs, head_gdf.total_bounds, start_time, end_time
        )
        logger.debug(f"head gdf bounds: {head_gdf.total_bounds}")

        forcing_working_dir = Path("test") / f"{headwater}-working-dir"
        # compute_zonal_stats writes its intermediates to <working dir>/temp
        (forcing_working_dir / "temp").mkdir(parents=True, exist_ok=True)

        compute_zonal_stats(head_gdf, head_forcs, forcing_working_dir)

        destination = Path(
            output_template.format(headwater=headwater, tailwater=tailwater)
        )
        shutil.copy(forcing_working_dir / "forcings.nc", destination)
        logger.info(f"Created forcings file: {destination}")
        # remove the working directory
        shutil.rmtree(forcing_working_dir)