In [1]:
# Import packages
import numpy as np
import xarray as xr
from typing import Tuple
import logging
import geopandas as gpd
import json
import time
import multiprocessing
import pandas as pd
from exactextract import exact_extract
from exactextract.raster import NumPyRasterSource
from pathlib import Path
from rich.progress import (
    Progress,
    BarColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
import psutil
from math import ceil
from multiprocessing import shared_memory
from functools import partial
import warnings
from dask.distributed import Client, LocalCluster
import shutil
import os
from shapely.geometry import Polygon
import datetime

In [2]:
logger = logging.getLogger(__name__)
warnings.filterwarnings(
    "ignore", 
    message="'DataFrame.swapaxes' is deprecated", 
    category=FutureWarning
)
warnings.filterwarnings(
    "ignore", 
    message="'GeoDataFrame.swapaxes' is deprecated", 
    category=FutureWarning
)

In [3]:
def validate_time_range(dataset: xr.Dataset, 
                        start_time: str, 
                        end_time: str) -> Tuple[str, str]:
    '''
    Ensure that all selected times are in the passed dataset.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset with a time coordinate.
    start_time : str
        Desired start time in YYYY/MM/DD HH:MM:SS format.
    end_time : str
        Desired end time in YYYY/MM/DD HH:MM:SS format.

    Returns
    -------
    str
        start_time, or if not available, earliest available timestep in dataset.
    str
        end_time, or if not available, latest available timestep in dataset.
    '''
    end_time_in_dataset = dataset.time.isel(time=-1).values
    start_time_in_dataset = dataset.time.isel(time=0).values
    if np.datetime64(start_time) < start_time_in_dataset:
        logger.warning(
            f"provided start {start_time} is before the start of the dataset "\
                "{start_time_in_dataset}, selecting from "\
                "{start_time_in_dataset}"
        )
        start_time = start_time_in_dataset
    if np.datetime64(end_time) > end_time_in_dataset:
        logger.warning(
            f"provided end {end_time} is after the end of the dataset "\
                "{end_time_in_dataset}, selecting until {end_time_in_dataset}"
        )
        end_time = end_time_in_dataset
    return start_time, end_time

def clip_dataset_to_bounds(dataset: xr.Dataset, 
                           bounds: Tuple[float, float, float, float], 
                           start_time: str, 
                           end_time: str) -> xr.Dataset:
    """
    Clip the dataset to specified geographical bounds.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset to be clipped.
    bounds : tuple[float, float, float, float]
        Corners of bounding box. bounds[0] is x_min, bounds[1] is y_min, 
        bounds[2] is x_max, bounds[3] is y_max.
    start_time : str
        Desired start time in YYYY/MM/DD HH:MM:SS format.
    end_time : str
        Desired end time in YYYY/MM/DD HH:MM:SS format.
    
    Returns
    -------
    xr.Dataset
        Clipped dataset.
    """
    # check time range here in case just this function is imported and not the 
    # whole module
    start_time, end_time = validate_time_range(dataset, start_time, end_time)
    dataset = dataset.sel(
        x=slice(bounds[0]-1000, bounds[2]+1000),
        y=slice(bounds[1]-1000, bounds[3]+1000),
        time=slice(start_time, end_time),
    )
    logger.info("Selected time range and clipped to bounds")
    return dataset

In [4]:
def head_gdf_selection(headwater: str, 
                       gdb: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    '''  
    Select headwater row from GeoDataFrame containing all basins in study area.

    Parameters
    ----------
    headwater : str
        NWM 3.0 reach ID of headwater basin
    gdb : gpd.GeoDataFrame
        GeoDataFrame that contains geometry information about all basins in
        study area.

    Returns
    -------
    head_gdf : gpd.GeoDataFrame
        The row in gdb that corresponds to the headwater basin.
    '''
    head_gdf = gdb.loc[gdb['ID'] == int(headwater)]
    #print(head_gdf['geometry'])
    return head_gdf

def tail_gdf_selection(headwater: str, 
                       tailwater: int, 
                       gdb: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    '''  
    Select tailwater row from GeoDataFrame containing all basins in study area.
    The returned geometry is the union of the polygons for the headwater basin
    and the tailwater basin.

    Parameters
    ----------
    headwater : str
        NWM 3.0 reach ID of headwater basin
    tailwater : str
        NWM 3.0 reach ID of tailwater basin
    gdb : gpd.GeoDataFrame
        GeoDataFrame that contains geometry information about all basins in
        study area.

    Returns
    -------
    tail_gdf : gpd.GeoDataFrame
        The row in gdb that corresponds to the tailwater basin with a merged
        geometry that encompasses both the headwater and tailwater polygons.
    '''
    tail_geom = gpd.GeoSeries([
        gdb.loc[gdb['ID'] == int(headwater), 'geometry'].values[0],
        gdb.loc[gdb['ID'] == int(tailwater), 'geometry'].values[0]
    ]).union_all() 

    d = gdb.loc[gdb['ID'] == int(tailwater)]
    d.loc[:,'geometry'] = tail_geom
    tail_gdf = gpd.GeoDataFrame(d)
    #print(tail_geom)
    return tail_gdf

In [5]:
def remove_terminals(olddict: dict) -> dict:
    '''
    Remove entries in olddict whose values are negative numbers, aka the 
    tailwater basin is a terminal basin (e.g. ocean, lake, reservoir).

    Parameters
    ----------
    olddict : dict
        Dictionary of basin pairs

    Returns
    -------
    newdict : dict
        Dictionary of basin pairs that contains no terminal basins as tailwaters
    '''
    newdict = {k: v for k, v in list(olddict.items()) if v > 0} 
    return newdict

In [6]:
def get_cell_weights(raster: xr.Dataset, 
                     gdf: gpd.GeoDataFrame, 
                     wkt: str) -> pd.DataFrame:
    '''
    Get the cell weights (coverage) for each cell in a divide. Coverage is 
    defined as the fraction (a float in [0,1]) of a raster cell that overlaps 
    with the polygon in the passed gdf.

    Parameters
    ----------
    raster : xr.Dataset
        One timestep of a gridded forcings dataset.
    gdf : gpd.GeoDataFrame
        A GeoDataFrame with a polygon feature.
    wkt : str
        Well-known text (WKT) representation of gdf's coordinate reference
        system (CRS)

    Returns
    -------
    pd.DataFrame
        DataFrame indexed by divide_id that contains information about coverage
        for each raster cell in gridded forcing file.
    '''
    #print(raster)
    xmin = raster.x[0]
    xmax = raster.x[-1]
    ymin = raster.y[0]
    ymax = raster.y[-1]
    rastersource = NumPyRasterSource(
        raster["RAINRATE"], 
        srs_wkt=wkt, 
        xmin=xmin, 
        xmax=xmax, 
        ymin=ymin, 
        ymax=ymax
    )
    output = exact_extract(
        rastersource,
        gdf,
        ["cell_id", "coverage"],
        include_cols=["ID"],
        output="pandas",
    )
    return output.set_index("ID")

    # if gdf.crs != wkt:
    #     gdf = gdf.to_crs(wkt)
    
    # # Check and adjust extents
    # raster_bounds = (raster.x[0], raster.x[-1], raster.y[0], raster.y[-1])
    # gdf_bounds = gdf.total_bounds  # (minx, miny, maxx, maxy)

    # if (
    #     gdf_bounds[2] < raster_bounds[0] or  # gdf maxX < raster minX
    #     gdf_bounds[0] > raster_bounds[2] or  # gdf minX > raster maxX
    #     gdf_bounds[3] < raster_bounds[1] or  # gdf maxY < raster minY
    #     gdf_bounds[1] > raster_bounds[3]     # gdf minY > raster maxY
    # ):
    #     print("Raster extent:", raster.x[0], raster.x[-1], raster.y[0], raster.y[-1])
    #     print("Geometry extent:", gdf.total_bounds)  # minx, miny, maxx, maxy

    #     raise ValueError("Geometry extent does not overlap with raster extent!")

    # # Check if geometries are valid
    # gdf = gdf[gdf.is_valid]

    # # Proceed with extraction
    # output = exact_extract(
    #     raster["RAINRATE"],
    #     gdf,
    #     ["cell_id", "coverage"],
    #     include_cols=["ID"],
    #     output="pandas",
    # )

    # return output.set_index("ID")


In [7]:
def get_cell_weights_parallel(gdf: gpd.GeoDataFrame,
                              input_forcings: xr.Dataset,
                              num_partitions: int) -> pd.DataFrame:
    '''
    Execute get_cell_weights with multiprocessing, with chunking for the passed
    GeoDataFrame to conserve memory usage.

    Parameters
    ----------
    gdf : gpd.GeoDataFrame
        A GeoDataFrame with a polygon feature.
    input_forcings : xr.Dataset
        A gridded forcings file.
    num_partitions : int
        Number of chunks to split gdf into.

    Returns
    -------
    pd.DataFrame
        DataFrame indexed by divide_id that contains information about coverage
        for each raster cell and each timestep in gridded forcing file.
    '''
    gdf_chunks = np.array_split(gdf, num_partitions)
    wkt = gdf.crs.to_wkt()
    one_timestep = input_forcings.isel(time=0).compute()
    with multiprocessing.Pool() as pool:
        args = [(one_timestep, gdf_chunk, wkt) for gdf_chunk in gdf_chunks]
        catchments = pool.starmap(get_cell_weights, args)
    return pd.concat(catchments)

In [8]:
def get_units(dataset: xr.Dataset) -> dict:
    '''
    Return dictionary of units for each variable in dataset.
    
    Parameters
    ----------
    dataset : xr.Dataset
        Dataset with variables and units.
    
    Returns
    -------
    dict 
        Each key is a variable in dataset, and its value is the unit associated
        with the variable.
    '''
    units = {}
    for var in dataset.data_vars:
        if "units" in dataset[var].attrs:
            units[var] = dataset[var].attrs["units"]
        else:
            units[var] = "unknown"
    return units

In [9]:
def get_index_chunks(data: xr.DataArray) -> list[tuple[int, int]]:
    '''  
    Take a DataArray and calculate the start and end index for each chunk based
    on the available memory.

    Parameters
    ----------
    data : xr.DataArray
        Large DataArray that can't be loaded into memory all at once.

    Returns
    -------
    list[Tuple[int, int]]
        Each element in the list represents a chunk of data. The tuple within
        the chunk indicates the start index and end index of the chunk.
    '''
    array_memory_usage = data.nbytes
    free_memory = psutil.virtual_memory().available * 0.8  # 80% of available 
                                                            # memory
    # limit the chunk to 20gb, makes things more stable
    free_memory = min(free_memory, 20 * 1024 * 1024 * 1024)
    num_chunks = ceil(array_memory_usage / free_memory)
    max_index = data.shape[0]
    stride = max_index // num_chunks
    chunk_start = range(0, max_index, stride)
    index_chunks = [(start, start + stride) for start in chunk_start]
    return index_chunks

In [10]:
def create_shared_memory(lazy_array: xr.Dataset) -> Tuple[
    shared_memory.SharedMemory,
    np.dtype,
    np.dtype
]:
    '''
    Create a shared memory object so that multiple processes can access loaded 
    data.
    
    Parameters
    ----------
    lazy_array : xr.Dataset
        A chunk of gridded forcing variable data.

    Returns
    -------
    shared_memory.SharedMemory
        A specific block of memory allocated by the OS of the size of 
        lazy_array.
    np.dtype.shape
        A shape object with dimensions (# timesteps, # of raster cells) in
        reference to lazy_array.
    np.dtype
        Data type of objects in lazy_array.
    '''
    logger.debug(f"Creating shared memory size {lazy_array.nbytes/ 10**6} Mb.")
    shm = shared_memory.SharedMemory(create=True, size=lazy_array.nbytes)
    shared_array = np.ndarray(lazy_array.shape, 
                              dtype=np.float32, 
                              buffer=shm.buf)
    # if your data is not float32, xarray will do an automatic conversion here
    # which consumes a lot more memory, forcings downloaded with this tool will 
    # work
    for start, end in get_index_chunks(lazy_array):
        # copy data from lazy to shared memory one chunk at a time
        shared_array[start:end] = lazy_array[start:end]

    time, x, y = shared_array.shape
    shared_array = shared_array.reshape(time, -1)

    return shm, shared_array.shape, shared_array.dtype

In [11]:
def weighted_sum_of_cells(flat_raster: np.ndarray, 
                          cell_ids: np.ndarray, 
                          factors: np.ndarray) -> np.ndarray:
    '''  
    Take an average of each forcing variable in a catchment. Create an output
    array initialized with zeros. Sum up the forcing variable and divide by the 
    sum of the cell weights to get an averaged forcing variable for the entire 
    catchment.

    Parameters
    ----------
    flat_raster : np.ndarray
        An array of dimensions (time, x*y) containing forcing variable values
        in each cell. Each element in the array corresponds to a cell ID.
    cell_ids : np.ndarray
        A list of the raster cell IDs that intersect the study catchment.
    factors : np.ndarray
        A list of the weights (coverages) of each cell in cell_ids.

    Returns
    -------
    np.ndarray
        An one-dimensional array, where each element corresponds to a timestep.
        Each element contains the averaged forcing value for the whole catchment
        over one timestep.
    '''
    result = np.zeros(flat_raster.shape[0])
    result = np.sum(flat_raster[:, cell_ids] * factors, axis=1)
    sum_of_weights = np.sum(factors)
    result /= sum_of_weights
    return result

In [12]:
def process_chunk_shared(variable: str,
                         times: np.ndarray,
                         shm_name: str,
                         shape: np.dtype.shape,
                         dtype: np.dtype, 
                         chunk: gpd.GeoDataFrame) -> xr.DataArray:
    '''  
    Process the gridded forcings chunk loaded into a SharedMemory block. 

    Parameters
    ----------
    variable : str
        Name of forcing variable to be processed.
    times : np.ndarray
        Timesteps in gridded forcings chunk.
    shm_name : str
        Unique name that identifies the SharedMemory block.
    shape : np.dtype.shape
        A shape object with dimensions (# timesteps, # of raster cells) in
        reference to the gridded forcings chunk.
    dtype : np.dtype
        Data type of objects in the gridded forcings chunk.
    chunk : gpd.GeoDataFrame
        A chunk of gridded forcings data.

    Returns
    -------
    xr.DataArray
        Averaged forcings data for each timestep for each catchment.
    '''
    existing_shm = shared_memory.SharedMemory(name=shm_name)
    raster = np.ndarray(shape, dtype=dtype, buffer=existing_shm.buf)
    results = []

    for catchment in chunk.index.unique():
        cell_ids = chunk.loc[catchment]["cell_id"]
        weights = chunk.loc[catchment]["coverage"]
        mean_at_timesteps = weighted_sum_of_cells(raster, cell_ids, weights)
        temp_da = xr.DataArray(
            mean_at_timesteps,
            dims=["time"],
            coords={"time": times},
            name=f"{variable}_{catchment}",
        )
        temp_da = temp_da.assign_coords(catchment=catchment)
        results.append(temp_da)
    existing_shm.close()
    return xr.concat(results, dim="catchment")

In [13]:
def add_APCP_SURFACE_to_dataset(dataset: xr.Dataset) -> xr.Dataset:
    '''Convert precipitation value to correct units.'''
    # precip_rate is mm/s
    # cfe says input atmosphere_water__liquid_equivalent_precipitation_rate is 
    # mm/h
    # nom says prcpnonc input is mm/s
    # technically should be kg/m^2/s at 1kg = 1l it equates to mm/s
    # nom says qinsur output is m/s, hopefully qinsur is converted to mm/h by 
    # ngen
    dataset["APCP_surface"] = dataset["precip_rate"] * 3600
    dataset["APCP_surface"].attrs["units"] = "mm h^-1" # ^-1 notation copied 
                                                        # from source data
    dataset["APCP_surface"].attrs["source_note"] = "This is just the "\
        "precip_rate variable converted to mm/h by multiplying by 3600"
    return dataset

In [14]:
def write_outputs(forcings_dir: Path,
                  variables: dict,
                  units: dict) -> None:
    '''  
    Write outputs to disk in the form of a NetCDF file, using dask clusters to
    facilitate parallel computing.

    Parameters
    ----------
    forcings_dir : Path
        Path to directory where outputs are to be stored.
    variables : dict
        Preset dictionary where the keys are forcing variable names and the 
        values are units.
    units : dict
        Dictionary where the keys are forcing variable names and the values are 
        units. Differs from variables, as this dictionary depends on the gridded
        forcing dataset.
    '''
    # start a dask cluster if there isn't one already running
    try:
        client = Client.current()
    except ValueError:
        cluster = LocalCluster()
        client = Client(cluster)
    temp_forcings_dir = forcings_dir / "temp"
    # Combine all variables into a single dataset using dask
    results = [xr.open_dataset(file, chunks="auto") 
               for file in temp_forcings_dir.glob("*.nc")]
    final_ds = xr.merge(results)
    for var in final_ds.data_vars:
        if var in units:
            final_ds[var].attrs["units"] = units[var]
        else:
            logger.warning(f"Variable {var} has no units")

    rename_dict = {}
    for key, value in variables.items():
        if key in final_ds:
            rename_dict[key] = value

    final_ds = final_ds.rename_vars(rename_dict)
    final_ds = add_APCP_SURFACE_to_dataset(final_ds)

    # this step halves the storage size of the forcings
    for var in final_ds.data_vars:
        final_ds[var] = final_ds[var].astype(np.float32)

    logger.info("Saving to disk")
    # The format for the netcdf is to support a legacy format
    # which is why it's a little "unorthodox"
    # There are no coordinates, just dimensions, catchment ids are stored in a 
    # 1d data var and time is stored in a 2d data var with the same time array 
    # for every catchment
    # time is stored as unix timestamps, units have to be set

    # add the catchment ids as a 1d data var
    final_ds["ids"] = final_ds["catchment"].astype(str)
    # time needs to be a 2d array of the same time array as unix timestamps for 
    # every catchment
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        time_array = (
            final_ds.time.astype("datetime64[s]").astype(np.int64).values//10**9
        )  ## convert from ns to s
    time_array = time_array.astype(np.int32) ## convert to int32 to save space
    final_ds = final_ds.drop_vars(["catchment", "time"]) ## drop the original 
                                                    # time and catchment vars
    final_ds = final_ds.rename_dims({"catchment": "catchment-id"}) # rename the 
                                                        # catchment dimension
    # add the time as a 2d data var, yes this is wasting disk space.
    final_ds["Time"] = (("catchment-id", "time"), 
                        [time_array for _ in range(len(final_ds["ids"]))])
    # set the time unit
    final_ds["Time"].attrs["units"] = "s"
    final_ds["Time"].attrs["epoch_start"] = "01/01/1970 00:00:00" # not needed 
                                            # but suppresses the ngen warning

    final_ds.to_netcdf(forcings_dir / "forcings.nc", engine="netcdf4")
    # close the datasets
    _ = [result.close() for result in results]
    final_ds.close()

    # clean up the temp files
    for file in temp_forcings_dir.glob("*.*"):
        file.unlink()
    temp_forcings_dir.rmdir()

In [None]:
def compute_zonal_stats(
    gdf: gpd.GeoDataFrame, merged_data: xr.Dataset, forcings_dir: Path
) -> None:
    '''  
    Compute zonal statistics in parallel for all timesteps over all desired 
    catchments. Create chunks of catchments and within those, chunks of 
    timesteps for memory management.

    Parameters
    ----------
    gdf : gpd.GeoDataFrame
        Contains identity and geometry information on desired catchments.
    merged_data : xr.Dataset
        Gridded forcing data that intersects with desired catchments.
    forcings_dir : Path
        Path to directory where outputs are to be stored.
    '''
    logger.info("Computing zonal stats in parallel for all timesteps")
    timer_start = time.time()
    num_partitions = multiprocessing.cpu_count() - 1
    if num_partitions > len(gdf):
        num_partitions = len(gdf)

    catchments = get_cell_weights_parallel(gdf, merged_data, num_partitions)

    units = get_units(merged_data)

    variables = {
                "LWDOWN": "DLWRF_surface",
                "PSFC": "PRES_surface",
                "Q2D": "SPFH_2maboveground",
                "RAINRATE": "precip_rate",
                "SWDOWN": "DSWRF_surface",
                "T2D": "TMP_2maboveground",
                "U2D": "UGRD_10maboveground",
                "V2D": "VGRD_10maboveground",
            }

    cat_chunks = np.array_split(catchments, num_partitions)

    # progress = Progress(
    #     TextColumn("[progress.description]{task.description}"),
    #     BarColumn(),
    #     "[progress.percentage]{task.percentage:>3.0f}%",
    #     TextColumn("{task.completed}/{task.total}"),
    #     "•",
    #     TextColumn(" Elapsed Time:"),
    #     TimeElapsedColumn(),
    #     TextColumn(" Remaining Time:"),
    #     TimeRemainingColumn(),
    # )

    timer = time.perf_counter()
    # variable_task = progress.add_task(
    #     "[cyan]Processing variables...", total=len(variables), elapsed=0
    # )
    # progress.start()
    for variable in variables.keys():
        # progress.update(variable_task, advance=1)
        # progress.update(variable_task, description=f"Processing {variable}")

        if variable not in merged_data.data_vars:
            logger.warning(f"Variable {variable} not in forcings, skipping")
            continue

        # to make sure this fits in memory, we need to chunk the data
        time_chunks = get_index_chunks(merged_data[variable])
        # chunk_task = progress.add_task("[purple] processing chunks", 
        #                                total=len(time_chunks))
        for i, times in enumerate(time_chunks):
            # progress.update(chunk_task, advance=1)
            start, end = times
            # select the chunk of time we want to process
            data_chunk = merged_data[variable].isel(time=slice(start,end))
            # put it in shared memory
            shm, shape, dtype = create_shared_memory(data_chunk)
            times = data_chunk.time.values
            print(times)
            # create a partial function to pass to the multiprocessing pool
            partial_process_chunk = partial(process_chunk_shared,
                                            variable,
                                            times,
                                            shm.name,
                                            shape,
                                            dtype)

            logger.debug(f"Processing variable: {variable}")
            # process the chunks of catchments in parallel
            with multiprocessing.Pool(num_partitions) as pool:
                variable_data = pool.map(partial_process_chunk, cat_chunks)
            del partial_process_chunk
            # clean up the shared memory
            shm.close()
            shm.unlink()
            logger.debug(f"Processed variable: {variable}")
            concatenated_da = xr.concat(variable_data, dim="catchment")
            # delete the data to free up memory
            del variable_data
            logger.debug(f"Concatenated variable: {variable}")
            # write this to disk now to save memory
            # xarray will monitor memory usage, but it doesn't account for the 
            # shared memory used to store the raster
            # This reduces memory usage by about 60%
            concatenated_da.to_dataset(name=variable).to_netcdf(forcings_dir/ 
                                                                "temp" / 
                                                                f"{variable}_{i}.nc")
        # Merge the chunks back together
        datasets = [xr.open_dataset(forcings_dir 
                                    / "temp" 
                                    / f"{variable}_{i}.nc") 
                                    for i in range(len(time_chunks))]
        result = xr.concat(datasets, dim="time")
        result.to_netcdf(forcings_dir / "temp" / f"{variable}.nc")
        # close the datasets
        result.close()
        _ = [dataset.close() for dataset in datasets]

        for file in forcings_dir.glob("temp/*_*.nc"):
            file.unlink()
    #     progress.remove_task(chunk_task)
    # progress.update(
    #     variable_task,
    #     description=f"Forcings processed in {time.perf_counter() - timer:2f}"\
    #     " seconds",
    # )
    # progress.stop()
    logger.info(
        f"Forcing generation complete! Zonal stats computed in "\
            f"{time.time() - timer_start:2f} seconds"
    )
    write_outputs(forcings_dir, variables, units)



In [16]:
# Set path to an example FORCINGS NetCDF downloaded from NWM 3.0 retrospective 
# S3 bucket
forc_path = "/media/volume/Clone_Imp_Data/FORCING/2008/"\
    "200801010000.LDASIN_DOMAIN1"

# Get CRS from the example NetCDF
projection = xr.open_dataset(forc_path, engine="h5netcdf").crs.esri_pe_string
logging.debug("Got projection from grid file")

In [17]:
# Read hydrofabric geodatabase
gdb_path = "../NWM_v3_hydrofabric.gdb/"
gdb = gpd.read_file(gdb_path, driver="FileGDB", layer="nwm_catchments_conus")
start_time = "2008-01-01 00:00:00"
end_time = "2008-01-07 23:00:00"

# gdb.set_geometry('geometry', inplace=True)
# gdb.to_crs(projection, inplace=True)

  return ogr_read(
  return ogr_read(


In [18]:
# # Set bounding box of desired state (in our case, AL)
# al_bounding_box = {'geometry': [Polygon([
#     (-88.4745951503515,30.222501133601334),
#     (-84.89247974539745,35.008322669916694),
#     (-88.4745951503515,35.008322669916694),
#     (-84.89247974539745,30.222501133601334)
#     ])]}
# al_box = gpd.GeoDataFrame(al_bounding_box, crs="EPSG:4326")
# al_box.bounds

# # Transform coordinates to the CRS of the forcing files
# al_box.to_crs(projection, inplace=True)

In [19]:
# print(al_box.bounds.loc[0])

In [20]:
# Clip forcing files to bounds of Alabama
# al_forcs = clip_dataset_to_bounds(
#     xr.open_mfdataset("/media/volume/Clone_Imp_Data/FORCING/2008/"\
#                       "200801*.LDASIN_DOMAIN1"),
#                       gdb.bounds.loc[0],
#                        "2008-01-01 00:00:00",
#                        "2008-01-07 23:00:00"
# )

In [21]:
t1 = datetime.datetime.now()

exp_dirname = '../runs/experiment_'+datetime.date.today().isoformat()+'/'
if not os.path.exists(exp_dirname):
    os.makedirs(exp_dirname)

output_name = exp_dirname + 'output.txt'

with open(output_name, 'w') as file:
    toprint = ('NDP run started at ' + 
                t1.strftime("%m/%d/%Y %H:%M:%S") + '\n')
    file.write(toprint)

In [22]:
forcs = xr.open_mfdataset("/media/volume/Clone_Imp_Data/FORCING/2008/2008010*.LDASIN_DOMAIN1",
                           engine='h5netcdf')

In [23]:
# A small dictionary for testing.
# headwater: tailwater
small_dict = {"445308": 445314,
              "445322": 445326,
              "445328": 445336}

In [24]:
logger.setLevel(logging.INFO)

In [25]:
def process_dict(studydict: dict,
                 gdb: gpd.GeoDataFrame,
                 forcs: xr.Dataset,
                 start_time: str, 
                 end_time: str) -> None:
    '''  
    Generate forcing files given a dictionary of desired catchment pairs.

    Parameters
    ----------
    studydict : dict
        Dictionary of desired catchment pairs. The keys are headwaters, and the
        values are the corresponding tailwaters.
    gdb : gpd.GeoDataFrame
        NWM 3.0 retrospective hydrofabric.
    forcs : xr.Dataset
        Gridded forcings file.
    start_time : str
        Desired start time in YYYY/MM/DD HH:MM:SS format.
    end_time : str
        Desired end time in YYYY/MM/DD HH:MM:SS format.
    '''
    for k, v in list(studydict.items()):

        output_file_k = Path(f"./test/outputs/{k}-forcs.nc")
        head_gdf = head_gdf_selection(k, gdb).to_crs(projection)

        head_forcs = clip_dataset_to_bounds(
            forcs, head_gdf.total_bounds, start_time, end_time
        )
        #logging.debug(f"head gdf bounds: {head_gdf.total_bounds}")
        #print(f"head gdf bounds: {head_gdf.total_bounds}")

        forcing_working_dir_k = Path(f"./test/{k}-working-dir")
        if not forcing_working_dir_k.exists():
            forcing_working_dir_k.mkdir(parents=True, exist_ok=True)

        temp_dir = forcing_working_dir_k / "temp"
        if not temp_dir.exists():
            temp_dir.mkdir(parents=True, exist_ok=True)
        
        compute_zonal_stats(head_gdf, head_forcs, forcing_working_dir_k)
            
        shutil.copy(forcing_working_dir_k / "forcings.nc", output_file_k)
        logging.info(f"Created forcings file: {output_file_k}")
        # remove the working directory
        shutil.rmtree(forcing_working_dir_k)

        # Process tailwater data
        output_file_v = Path(f"./test/outputs/{v}-forcs.nc")
        tail_gdf = tail_gdf_selection(k, v, gdb).to_crs(projection)
        #print(f"tail gdf bounds: {tail_gdf.total_bounds}")
        tail_forcs = clip_dataset_to_bounds(
            forcs, tail_gdf.total_bounds, start_time, end_time
        )
        logging.debug(f"tail gdf bounds: {tail_gdf.total_bounds}")

        forcing_working_dir_v = Path(f"./test/{v}-working-dir")
        if not forcing_working_dir_v.exists():
            forcing_working_dir_v.mkdir(parents=True, exist_ok=True)

        temp_dir = forcing_working_dir_v / "temp" 
        if not temp_dir.exists():
            temp_dir.mkdir(parents=True, exist_ok=True)

        compute_zonal_stats(tail_gdf, tail_forcs, forcing_working_dir_v)

        shutil.copy(forcing_working_dir_v / "forcings.nc", output_file_v)
        logging.info(f"Created forcings file: {output_file_v}")
        # remove the working directory
        shutil.rmtree(forcing_working_dir_v)

In [26]:
process_dict(small_dict, gdb, forcs, start_time, end_time)

INFO:__main__:Selected time range and clipped to bounds
INFO:__main__:Computing zonal stats in parallel for all timesteps


['2008-01-01T00:00:00.000000000' '2008-01-01T01:00:00.000000000'
 '2008-01-01T02:00:00.000000000' '2008-01-01T03:00:00.000000000'
 '2008-01-01T04:00:00.000000000' '2008-01-01T05:00:00.000000000'
 '2008-01-01T06:00:00.000000000' '2008-01-01T07:00:00.000000000'
 '2008-01-01T08:00:00.000000000' '2008-01-01T09:00:00.000000000'
 '2008-01-01T10:00:00.000000000' '2008-01-01T11:00:00.000000000'
 '2008-01-01T12:00:00.000000000' '2008-01-01T13:00:00.000000000'
 '2008-01-01T14:00:00.000000000' '2008-01-01T15:00:00.000000000'
 '2008-01-01T16:00:00.000000000' '2008-01-01T17:00:00.000000000'
 '2008-01-01T18:00:00.000000000' '2008-01-01T19:00:00.000000000'
 '2008-01-01T20:00:00.000000000' '2008-01-01T21:00:00.000000000'
 '2008-01-01T22:00:00.000000000' '2008-01-01T23:00:00.000000000'
 '2008-01-02T00:00:00.000000000' '2008-01-02T01:00:00.000000000'
 '2008-01-02T02:00:00.000000000' '2008-01-02T03:00:00.000000000'
 '2008-01-02T04:00:00.000000000' '2008-01-02T05:00:00.000000000'
 '2008-01-02T06:00:00.000

INFO:__main__:Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds
INFO:__main__:Saving to disk
INFO:__main__:Selected time range and clipped to bounds
INFO:__main__:Computing zonal stats in parallel for all timesteps


['2008-01-01T00:00:00.000000000' '2008-01-01T01:00:00.000000000'
 '2008-01-01T02:00:00.000000000' '2008-01-01T03:00:00.000000000'
 '2008-01-01T04:00:00.000000000' '2008-01-01T05:00:00.000000000'
 '2008-01-01T06:00:00.000000000' '2008-01-01T07:00:00.000000000'
 '2008-01-01T08:00:00.000000000' '2008-01-01T09:00:00.000000000'
 '2008-01-01T10:00:00.000000000' '2008-01-01T11:00:00.000000000'
 '2008-01-01T12:00:00.000000000' '2008-01-01T13:00:00.000000000'
 '2008-01-01T14:00:00.000000000' '2008-01-01T15:00:00.000000000'
 '2008-01-01T16:00:00.000000000' '2008-01-01T17:00:00.000000000'
 '2008-01-01T18:00:00.000000000' '2008-01-01T19:00:00.000000000'
 '2008-01-01T20:00:00.000000000' '2008-01-01T21:00:00.000000000'
 '2008-01-01T22:00:00.000000000' '2008-01-01T23:00:00.000000000'
 '2008-01-02T00:00:00.000000000' '2008-01-02T01:00:00.000000000'
 '2008-01-02T02:00:00.000000000' '2008-01-02T03:00:00.000000000'
 '2008-01-02T04:00:00.000000000' '2008-01-02T05:00:00.000000000'
 '2008-01-02T06:00:00.000

INFO:__main__:Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds
INFO:__main__:Saving to disk
INFO:__main__:Selected time range and clipped to bounds
INFO:__main__:Computing zonal stats in parallel for all timesteps


['2008-01-01T00:00:00.000000000' '2008-01-01T01:00:00.000000000'
 '2008-01-01T02:00:00.000000000' '2008-01-01T03:00:00.000000000'
 '2008-01-01T04:00:00.000000000' '2008-01-01T05:00:00.000000000'
 '2008-01-01T06:00:00.000000000' '2008-01-01T07:00:00.000000000'
 '2008-01-01T08:00:00.000000000' '2008-01-01T09:00:00.000000000'
 '2008-01-01T10:00:00.000000000' '2008-01-01T11:00:00.000000000'
 '2008-01-01T12:00:00.000000000' '2008-01-01T13:00:00.000000000'
 '2008-01-01T14:00:00.000000000' '2008-01-01T15:00:00.000000000'
 '2008-01-01T16:00:00.000000000' '2008-01-01T17:00:00.000000000'
 '2008-01-01T18:00:00.000000000' '2008-01-01T19:00:00.000000000'
 '2008-01-01T20:00:00.000000000' '2008-01-01T21:00:00.000000000'
 '2008-01-01T22:00:00.000000000' '2008-01-01T23:00:00.000000000'
 '2008-01-02T00:00:00.000000000' '2008-01-02T01:00:00.000000000'
 '2008-01-02T02:00:00.000000000' '2008-01-02T03:00:00.000000000'
 '2008-01-02T04:00:00.000000000' '2008-01-02T05:00:00.000000000'
 '2008-01-02T06:00:00.000

INFO:__main__:Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds
INFO:__main__:Saving to disk
INFO:__main__:Selected time range and clipped to bounds
INFO:__main__:Computing zonal stats in parallel for all timesteps


['2008-01-01T00:00:00.000000000' '2008-01-01T01:00:00.000000000'
 '2008-01-01T02:00:00.000000000' '2008-01-01T03:00:00.000000000'
 '2008-01-01T04:00:00.000000000' '2008-01-01T05:00:00.000000000'
 '2008-01-01T06:00:00.000000000' '2008-01-01T07:00:00.000000000'
 '2008-01-01T08:00:00.000000000' '2008-01-01T09:00:00.000000000'
 '2008-01-01T10:00:00.000000000' '2008-01-01T11:00:00.000000000'
 '2008-01-01T12:00:00.000000000' '2008-01-01T13:00:00.000000000'
 '2008-01-01T14:00:00.000000000' '2008-01-01T15:00:00.000000000'
 '2008-01-01T16:00:00.000000000' '2008-01-01T17:00:00.000000000'
 '2008-01-01T18:00:00.000000000' '2008-01-01T19:00:00.000000000'
 '2008-01-01T20:00:00.000000000' '2008-01-01T21:00:00.000000000'
 '2008-01-01T22:00:00.000000000' '2008-01-01T23:00:00.000000000'
 '2008-01-02T00:00:00.000000000' '2008-01-02T01:00:00.000000000'
 '2008-01-02T02:00:00.000000000' '2008-01-02T03:00:00.000000000'
 '2008-01-02T04:00:00.000000000' '2008-01-02T05:00:00.000000000'
 '2008-01-02T06:00:00.000

INFO:__main__:Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds
INFO:__main__:Saving to disk
INFO:__main__:Selected time range and clipped to bounds
INFO:__main__:Computing zonal stats in parallel for all timesteps


['2008-01-01T00:00:00.000000000' '2008-01-01T01:00:00.000000000'
 '2008-01-01T02:00:00.000000000' '2008-01-01T03:00:00.000000000'
 '2008-01-01T04:00:00.000000000' '2008-01-01T05:00:00.000000000'
 '2008-01-01T06:00:00.000000000' '2008-01-01T07:00:00.000000000'
 '2008-01-01T08:00:00.000000000' '2008-01-01T09:00:00.000000000'
 '2008-01-01T10:00:00.000000000' '2008-01-01T11:00:00.000000000'
 '2008-01-01T12:00:00.000000000' '2008-01-01T13:00:00.000000000'
 '2008-01-01T14:00:00.000000000' '2008-01-01T15:00:00.000000000'
 '2008-01-01T16:00:00.000000000' '2008-01-01T17:00:00.000000000'
 '2008-01-01T18:00:00.000000000' '2008-01-01T19:00:00.000000000'
 '2008-01-01T20:00:00.000000000' '2008-01-01T21:00:00.000000000'
 '2008-01-01T22:00:00.000000000' '2008-01-01T23:00:00.000000000'
 '2008-01-02T00:00:00.000000000' '2008-01-02T01:00:00.000000000'
 '2008-01-02T02:00:00.000000000' '2008-01-02T03:00:00.000000000'
 '2008-01-02T04:00:00.000000000' '2008-01-02T05:00:00.000000000'
 '2008-01-02T06:00:00.000

INFO:__main__:Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds
INFO:__main__:Saving to disk
INFO:__main__:Selected time range and clipped to bounds
INFO:__main__:Computing zonal stats in parallel for all timesteps


['2008-01-01T00:00:00.000000000' '2008-01-01T01:00:00.000000000'
 '2008-01-01T02:00:00.000000000' '2008-01-01T03:00:00.000000000'
 '2008-01-01T04:00:00.000000000' '2008-01-01T05:00:00.000000000'
 '2008-01-01T06:00:00.000000000' '2008-01-01T07:00:00.000000000'
 '2008-01-01T08:00:00.000000000' '2008-01-01T09:00:00.000000000'
 '2008-01-01T10:00:00.000000000' '2008-01-01T11:00:00.000000000'
 '2008-01-01T12:00:00.000000000' '2008-01-01T13:00:00.000000000'
 '2008-01-01T14:00:00.000000000' '2008-01-01T15:00:00.000000000'
 '2008-01-01T16:00:00.000000000' '2008-01-01T17:00:00.000000000'
 '2008-01-01T18:00:00.000000000' '2008-01-01T19:00:00.000000000'
 '2008-01-01T20:00:00.000000000' '2008-01-01T21:00:00.000000000'
 '2008-01-01T22:00:00.000000000' '2008-01-01T23:00:00.000000000'
 '2008-01-02T00:00:00.000000000' '2008-01-02T01:00:00.000000000'
 '2008-01-02T02:00:00.000000000' '2008-01-02T03:00:00.000000000'
 '2008-01-02T04:00:00.000000000' '2008-01-02T05:00:00.000000000'
 '2008-01-02T06:00:00.000

INFO:__main__:Forcing generation complete! Zonal stats computed in {time.time() - timer_start:2f} seconds
INFO:__main__:Saving to disk


In [27]:
t2 = datetime.datetime.now()
with open(output_name, 'a') as file:
    toprint = ('NDP run elapsed time in s: ' 
                + str((t2-t1).total_seconds()) + '\n')
    file.write(toprint)