# This notebook was used as a sandbox to create the functionality within Marquette

In [1]:
from datetime import datetime
import json
import logging
from pathlib import Path
import re
import time
from typing import List, Tuple
from tempfile import NamedTemporaryFile

import dask.array as da
from dask.diagnostics import ProgressBar
from dask.distributed import Client, as_completed
import dask_geopandas as dgd 
import hydra
import geopandas as gpd
import numpy as np
from omegaconf import DictConfig, OmegaConf
import pandas as pd
from pyproj import CRS
from tqdm.notebook import tqdm
import xarray as xr
import zarr

# Module-level logger for this notebook session.
log = logging.getLogger(__name__)

# Start a Dask client (no scheduler address is given) and serve its
# dashboard on port 8989; the bare `client` below renders its summary.
DASHBOARD_ADDRESS = ':8989'
client = Client(dashboard_address=DASHBOARD_ADDRESS)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8989/status,

0,1
Dashboard: http://127.0.0.1:8989/status,Workers: 8
Total threads: 32,Total memory: 251.53 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:45099,Workers: 8
Dashboard: http://127.0.0.1:8989/status,Total threads: 32
Started: Just now,Total memory: 251.53 GiB

0,1
Comm: tcp://127.0.0.1:38905,Total threads: 4
Dashboard: http://127.0.0.1:43409/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:45879,
Local directory: /tmp/dask-scratch-space/worker-4j4r2yxb,Local directory: /tmp/dask-scratch-space/worker-4j4r2yxb

0,1
Comm: tcp://127.0.0.1:45105,Total threads: 4
Dashboard: http://127.0.0.1:36271/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:43247,
Local directory: /tmp/dask-scratch-space/worker-9dky2gki,Local directory: /tmp/dask-scratch-space/worker-9dky2gki

0,1
Comm: tcp://127.0.0.1:37215,Total threads: 4
Dashboard: http://127.0.0.1:40601/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:39091,
Local directory: /tmp/dask-scratch-space/worker-xjhlmhda,Local directory: /tmp/dask-scratch-space/worker-xjhlmhda

0,1
Comm: tcp://127.0.0.1:44673,Total threads: 4
Dashboard: http://127.0.0.1:45845/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:44413,
Local directory: /tmp/dask-scratch-space/worker-v3pjszrh,Local directory: /tmp/dask-scratch-space/worker-v3pjszrh

0,1
Comm: tcp://127.0.0.1:46113,Total threads: 4
Dashboard: http://127.0.0.1:41179/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:34123,
Local directory: /tmp/dask-scratch-space/worker-5tolqnrq,Local directory: /tmp/dask-scratch-space/worker-5tolqnrq

0,1
Comm: tcp://127.0.0.1:36629,Total threads: 4
Dashboard: http://127.0.0.1:39967/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:40419,
Local directory: /tmp/dask-scratch-space/worker-o4qa37th,Local directory: /tmp/dask-scratch-space/worker-o4qa37th

0,1
Comm: tcp://127.0.0.1:45377,Total threads: 4
Dashboard: http://127.0.0.1:43483/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:35425,
Local directory: /tmp/dask-scratch-space/worker-sk526mu2,Local directory: /tmp/dask-scratch-space/worker-sk526mu2

0,1
Comm: tcp://127.0.0.1:44747,Total threads: 4
Dashboard: http://127.0.0.1:33645/status,Memory: 31.44 GiB
Nanny: tcp://127.0.0.1:37965,
Local directory: /tmp/dask-scratch-space/worker-ws1fiu1j,Local directory: /tmp/dask-scratch-space/worker-ws1fiu1j


In [3]:
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
CONFIG_PATH = "/data/tkb5476/projects/marquette/marquette/conf/config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)

### Separate HUCs into smaller pieces (ONLY RUN ONCE) 

In [76]:
# gdf = gpd.read_file(Path(cfg.save_paths.huc10))
# grouped = gdf.groupby('HUC10')

In [77]:
# for huc10_value, group in tqdm(grouped, desc='Processing shp files'):
#     out_path = f'{cfg.save_paths.singular_huc10}/{huc10_value}.shp'
#     group.to_file(out_path)

### Create HUC -> MERIT TM

#### TODO: debugging

In [6]:
def join_geospatial_data(cfg: DictConfig) -> gpd.GeoDataFrame:
    """
    Attach HUC10 attributes to every basin whose centroid intersects a HUC10 polygon.

    Reads the HUC10 layer from ``cfg.save_paths.huc10`` (reprojected to EPSG:4326)
    and the basin layer from ``cfg.save_paths.basins``, then left-joins the basin
    centroids against the HUC10 polygons.

    Args:
    cfg (DictConfig): Configuration providing ``save_paths.huc10`` and
        ``save_paths.basins`` (paths readable by ``gpd.read_file``).

    Returns:
    gpd.GeoDataFrame: The left-joined frame with the basin ``geometry`` column
        restored as the active geometry. The helper ``centroid`` column is kept.
        Because the join is a left join, basins without an intersecting HUC10
        are retained with missing HUC10 attributes.
    """
    huc10_gdf = gpd.read_file(Path(cfg.save_paths.huc10)).to_crs(epsg=4326)
    basins_gdf = gpd.read_file(Path(cfg.save_paths.basins))

    # A spatial join is only meaningful when both layers share a CRS; align the
    # basins to the HUC10 layer when their CRS is known and differs.
    if basins_gdf.crs is not None and basins_gdf.crs != huc10_gdf.crs:
        basins_gdf = basins_gdf.to_crs(huc10_gdf.crs)

    # NOTE(review): the centroid is computed in the layer's own CRS; geopandas
    # warns when that CRS is geographic -- confirm the accuracy is acceptable.
    basins_gdf['centroid'] = basins_gdf.geometry.centroid

    # `predicate` replaces the deprecated `op` keyword of gpd.sjoin.
    joined_gdf = gpd.sjoin(
        basins_gdf.set_geometry('centroid'),
        huc10_gdf,
        how='left',
        predicate='intersects',
    )
    # Switch the active geometry back from the centroids to the basin polygons.
    return joined_gdf.set_geometry('geometry')


In [7]:
# Open the edges Zarr group named in the config. `edges` is not referenced
# again in the cells shown here (later cells open their own `edges_group`).
edges = zarr.open_group(cfg.zarr.edges)
# Joined basin/HUC10 table produced by join_geospatial_data above.
gdf = join_geospatial_data(cfg)


  basins_gdf['centroid'] = basins_gdf.geometry.centroid
  exec(code_obj, self.user_global_ns, self.user_ns)


In [84]:
def create_TM(cfg: DictConfig, gdf: gpd.GeoDataFrame):
    """
    Build the HUC10 -> MERIT transition matrix (TM) and save it as a Zarr store.

    Rows are the sorted unique ``HUC10`` values and columns the sorted unique
    ``COMID`` values of ``gdf``. The entry for a (HUC10, COMID) pair is the
    basin's ``unitarea`` divided by the ``area_new`` value of the first row
    belonging to that HUC10; all other entries are zero.

    Args:
        cfg (DictConfig): Hydra configuration object; ``cfg.zarr.HUC_TM`` is the
            output store path (overwritten if it exists).
        gdf (GeoDataFrame): Table with ``HUC10``, ``COMID``, ``unitarea`` and
            ``area_new`` columns. Rows with a missing ``HUC10`` are ignored.

    Returns:
        The Zarr group at ``cfg.zarr.HUC_TM``, re-opened read-only. It holds the
        ``TM`` array plus the ``HUC10`` and ``COMID`` coordinate arrays.
    """
    gdf = gdf.dropna(subset=['HUC10'])
    huc10_ids = gdf["HUC10"].unique()
    merit_ids = gdf["COMID"].unique()
    huc10_ids.sort()
    merit_ids.sort()

    # Fill a plain NumPy matrix and wrap it afterwards: one vectorised write per
    # HUC10 instead of one label-based xarray write per basin.
    tm = np.zeros((len(huc10_ids), len(merit_ids)))
    for row, huc_id in enumerate(tqdm(huc10_ids, desc="creating TM")):
        # Compare against the id itself (not str(id)) so the match also works
        # when the HUC10 column is not string-typed.
        merit_basins = gdf[gdf['HUC10'] == huc_id]
        total_area = merit_basins["area_new"].iloc[0]
        # merit_ids is sorted and contains every COMID in gdf, so searchsorted
        # returns each basin's exact column position.
        cols = np.searchsorted(merit_ids, merit_basins["COMID"].to_numpy())
        tm[row, cols] = merit_basins["unitarea"].to_numpy() / total_area

    data_array = xr.DataArray(tm,
                              dims=["HUC10", "COMID"],
                              coords={"HUC10": huc10_ids, "COMID": merit_ids})
    xr_dataset = xr.Dataset(
        data_vars={"TM": data_array},
        coords={"HUC10": huc10_ids, "COMID": merit_ids},
        attrs={"description": "HUC10 -> MERIT Transition Matrix"}
    )
    print("Saving Zarr Data")
    zarr_path = Path(cfg.zarr.HUC_TM)
    xr_dataset.to_zarr(zarr_path, mode='w')
    # Re-open read-only so callers get a handle that cannot modify the store.
    zarr_hierarchy = zarr.open_group(zarr_path, mode='r')
    log.info(f"TM saved to Zarr file at {zarr_path}")
    return zarr_hierarchy

# np.random.seed(0)
# temperature = 15 + 8 * np.random.randn(2, 2, 3)
# precipitation = 10 * np.random.rand(2, 2, 3)
# lon = [[-99.83, -99.32], [-99.79, -99.23]]
# lat = [[42.25, 42.21], [42.63, 42.59]]
# time = pd.date_range("2014-09-06", periods=3)
# reference_time = pd.Timestamp("2014-09-05")
# ds = xr.Dataset(
#     data_vars=dict(
#         temperature=(["x", "y", "time"], temperature),
#         precipitation=(["x", "y", "time"], precipitation),
#     ),
#     coords=dict(
#         lon=(["x", "y"], lon),
#         lat=(["x", "y"], lat),
#         time=time,
#         reference_time=reference_time,
#     ),
#     attrs=dict(description="Weather related data."),
# )
# <xarray.Dataset>
# Dimensions:         (x: 2, y: 2, time: 3)
# Coordinates:
#     lon             (x, y) float64 -99.83 -99.32 -99.79 -99.23
#     lat             (x, y) float64 42.25 42.21 42.63 42.59
#   * time            (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08
#     reference_time  datetime64[ns] 2014-09-05
# Dimensions without coordinates: x, y
# Data variables:
#     temperature     (x, y, time) float64 29.11 18.2 22.83 ... 18.28 16.15 26.63
#     precipitation   (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805
# Attributes:
#     description:  Weather related data.

In [81]:
def plot_histogram(df: pd.DataFrame, num_bins: int = 100) -> None:
    """
    Plot a histogram of the row sums of ``df`` and mark their summary statistics.

    Args:
    df (pd.DataFrame): Frame whose per-row sums (``df.sum(axis=1)``) are plotted.
    num_bins (int, optional): Number of histogram bins. Defaults to 100.

    Returns:
    None. The figure is shown with ``plt.show()``.

    The minimum, median, mean and maximum of the row sums are drawn as dashed
    vertical lines and listed in the legend.
    """
    # Imported here because the notebook's import cell does not bring in
    # matplotlib, so a module-level `plt` is not guaranteed to exist.
    import matplotlib.pyplot as plt

    series = df.sum(axis=1)

    fig, ax = plt.subplots(figsize=(10, 6))
    series.hist(bins=num_bins, ax=ax)
    ax.set_xlabel(r'Ratio of  $\sum$ MERIT basin area to HUC10 basin areas')
    ax.set_ylabel('Number of HUC10s')
    ax.set_title(r'Distribution of $\sum$ MERIT area / HUC10 basin area')

    # (legend label, value, line colour) for each summary statistic.
    markers = (
        ('Min', series.min(), 'grey'),
        ('Median', series.median(), 'blue'),
        ('Mean', series.mean(), 'red'),
        ('Max', series.max(), 'green'),
    )
    for label, value, color in markers:
        ax.axvline(value, color=color, linestyle='dashed', linewidth=2,
                   label=f'{label}: {value:.3f}')

    ax.legend()
    plt.show()

In [85]:
start = time.perf_counter()
# `gdf` is the joined basin/HUC10 table produced by join_geospatial_data(cfg)
# in an earlier cell; reuse it rather than a name that is never assigned.
zarr_data = create_TM(cfg, gdf)
end = time.perf_counter()
print(f"This took: {(end - start):.6f} seconds")

creating TM:   0%|          | 0/1573 [00:00<?, ?it/s]

Saving Zarr Data
This took: 10.727237 seconds


In [86]:
# Summary of the Zarr group returned by create_TM.
zarr_data.info

0,1
Name,/
Type,zarr.hierarchy.Group
Read-only,True
Store type,zarr.storage.DirectoryStore
No. members,3
No. arrays,3
No. groups,0
Arrays,"COMID, HUC10, TM"


In [87]:
from scipy.sparse import csr_matrix

# Sparse view of the stored TM; its repr reports the shape and the number of
# stored elements.
csr_matrix(zarr_data["TM"][:])

<1573x16933 sparse matrix of type '<class 'numpy.float64'>'
	with 16933 stored elements in Compressed Sparse Row format>

In [88]:
# Show the arrays contained in the TM Zarr group.
zarr_data.tree()

Tree(nodes=(Node(disabled=True, name='/', nodes=(Node(disabled=True, icon='table', name='COMID (16933,) int64'…

# CREATE MERIT -> MERIT FLOWLINES TM

In [5]:
# Both stores are only read in the cells below, so open them read-only. Without
# an explicit mode zarr.open uses 'a', which is writable and creates the store
# when the path does not exist.
edges_group = zarr.open(Path(cfg.zarr.edges), mode='r')
huc_to_merit_TM = zarr.open(Path(cfg.zarr.HUC_TM), mode='r')
edges_group.tree()

Tree(nodes=(Node(disabled=True, name='/', nodes=(Node(disabled=True, icon='table', name='coords (108117,) <U30…

In [4]:
# Show the arrays contained in the HUC10 -> MERIT TM store.
huc_to_merit_TM.tree()

Tree(nodes=(Node(disabled=True, name='/', nodes=(Node(disabled=True, icon='table', name='COMID (16933,) int64'…

In [11]:
# condition = (edges_group.merit_basin == 78006758).compute()
# filtered_edges = edges_group.where(condition, drop=True)
# filtered_edges_df = filtered_edges.to_dataframe()
# total_length = filtered_edges_df['len'].sum()

# # Calculate the proportion for each reach
# proportions = filtered_edges_df['len'] / total_length

# filtered_edges_df['len']

In [12]:
# huc_to_merit_xr = xr.open_zarr(Path(cfg.zarr.HUC_TM))
# huc_to_merit_ddf = huc_to_merit_xr.to_dask_dataframe()
# huc_to_merit_ddf = huc_to_merit_ddf.repartition(npartitions=96)
# edges_group_future = client.scatter(edges_group, broadcast=True)

In [13]:
# def process_partition(partition, edges_group_future):
#     COMIDs = partition.COMID.unique()
#     results_df = pd.DataFrame(0, index=COMIDs, columns=edges_group_future.id.values)

#     for basin_id in COMIDs:
#         # Apply condition and filtering
#         condition = edges_group_future.merit_basin == basin_id
#         filtered_edges = edges_group_future.where(condition, drop=True)

#         # Calculate total length and proportions
#         total_length = filtered_edges.len.sum()
#         proportions = filtered_edges.len / total_length

#         # Assign the proportions to the correct 'River_Graph_ID' in the DataFrame
#         for reach_id, proportion in zip(filtered_edges.id.values, proportions.values):
#             results_df.loc[basin_id, reach_id] = proportion

#     return results_df


In [14]:
# river_graph_ids = edges_group.id.values
# meta = pd.DataFrame({rg_id: pd.Series(dtype=float) for rg_id in river_graph_ids})
# results = edges_ddf.map_partitions(process_partition, edges_group=edges_group, COMIDs=COMIDs, meta=meta)

In [14]:
# # Use Dask's ProgressBar to monitor the progress
# with ProgressBar():
#     # Compute the results across all partitions
#     computed_results = results.compute()
# computed_results
# Spot-check: positions at which edges_group.merit_basin equals one sample id.
SAMPLE_BASIN_ID = 78023389
merit_basin = edges_group.merit_basin[:]
indices = np.nonzero(merit_basin == SAMPLE_BASIN_ID)[0]
indices

(array([0]),)

In [36]:
# Load the arrays needed to build the COMID x edge proportion matrix.
COMIDs = huc_to_merit_TM.COMID[:]
river_graph_ids = edges_group.id[:]
merit_basin = edges_group.merit_basin[:]
river_graph_len = edges_group.len[:]

# Column of each edge = first position at which its id occurs in
# river_graph_ids. Computing this once replaces a full scan of
# river_graph_ids for every single edge.
_, first_occurrence, id_group = np.unique(
    river_graph_ids, return_index=True, return_inverse=True
)
edge_columns = first_occurrence[id_group]

# NOTE(review): this is a dense float64 matrix of shape
# (len(COMIDs), len(river_graph_ids)); a sparse layout may be worth considering.
proportion_array = np.zeros((len(COMIDs), len(river_graph_ids)))
for i, basin_id in enumerate(tqdm(COMIDs, desc="Processing River flowlines")):
    indices = np.where(merit_basin == basin_id)[0]
    basin_lengths = river_graph_len[indices]

    total_length = np.sum(basin_lengths)
    if total_length == 0:
        # No edges matched this basin id (or their lengths sum to zero).
        print("Basin not found:", basin_id)
        continue
    # Each edge's share of the total edge length within this basin.
    proportion_array[i, edge_columns[indices]] = basin_lengths / total_length

Processing River flowlines:   0%|          | 0/16933 [00:00<?, ?it/s]

In [37]:
# data = [(comid, rg_id, prop) for (comid, rg_id), prop in results.items()]
# df = pd.DataFrame(data, columns=['COMID', 'River_Graph_ID', 'Proportion'])
# pivot_df = df.pivot(index='COMID', columns='River_Graph_ID', values='Proportion')
# pivot_df = pivot_df.fillna(0)
# pivot_df.shape

In [46]:
# Write the dense proportion matrix to a gzip-compressed CSV and time the write.
start = time.perf_counter()
df = pd.DataFrame(data=proportion_array, index=COMIDs, columns=river_graph_ids)
csv_path = Path(cfg.csv.MERIT_TM)
df.to_csv(csv_path, compression="gzip")
end = time.perf_counter()
print(f"This took: {(end - start):.6f} seconds")

This took: 3107.829549 seconds


In [45]:
start = time.perf_counter()
zarr_group = zarr.open_group(Path(cfg.zarr.MERIT_TM), mode='w')
zarr_group.create_dataset('TM', data=proportion_array)
zarr_group.create_dataset('COMIDs', data=COMIDs)
zarr_group.create_dataset('EDGEIDs', data=river_graph_ids)
end = time.perf_counter()
print(f"This took: {(end - start):.6f} seconds")

This took: 8.394733 seconds


In [23]:
# Load the previously generated SRB TM for comparison.
# NOTE(review): absolute, machine-specific path.
SRB_TM_CSV = "/data/tkb5476/projects/marquette/data/MERIT/streamflow/TMs/merit_to_srb_river_graph_TM.csv.gz"
df = pd.read_csv(SRB_TM_CSV, compression="gzip")

In [24]:
# Preview the first rows of the SRB TM loaded in the previous cell.
df.head()

Unnamed: 0,73001927_0,73001927_1,73001936_0,73001936_1,73001936_2,73001936_3,73001936_4,73001936_5,73001936_6,73001936_7,...,73006599_2,73006599_3,73006599_4,73006599_5,73006599_6,73006600_0,73006601_0,73006601_1,73006602_0,73006603_0
0,0.5,0.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
