In [None]:
import ast
from datetime import datetime
import json
import logging
from pathlib import Path
import re
import time
from typing import List, Tuple
from tempfile import NamedTemporaryFile

import dask.array as da
from dask.diagnostics import ProgressBar
from dask.distributed import Client, as_completed
import dask_geopandas as dgd 
import hydra
import geopandas as gpd
import numpy as np
from omegaconf import DictConfig, OmegaConf
import pandas as pd
from pyproj import CRS
from tqdm.notebook import tqdm
import torch
import xarray as xr
import zarr

# Notebook-wide logger, named after the current module via __name__.
log = logging.getLogger(__name__)

In [None]:
# Notebook configuration. Values written as "${key}" are OmegaConf
# interpolations: they refer to other keys of this config and are resolved
# when the value is accessed, not when the dict is created.
cfg = {
  "name": "MERIT",
  # Root directories that the interpolated paths below are built from.
  "routing_data_path": "/data/tkb5476/projects/gridded_routing/data/",
  "marquette_data_path": "/data/tkb5476/projects/marquette/data/",
  "dx": 2000,
  "buffer": 0.3334,
  "units": "mm/day",
  # NOTE(review): `data_path` is not a key of this config, so this value
  # presumably cannot be resolved when accessed -- confirm whether
  # `routing_data_path` or `marquette_data_path` was intended.
  "date_codes": "${data_path}/date_codes.json",
  "is_streamflow_split": True,
  "start_date": "01-01-1980",
  "end_date": "12-31-2020",
  "num_cores": 20,
  "continent": 7,
  "area": 3,
  "num_partitions": 64,
  # Concatenation of `continent` and `area` ("73" with the values above).
  "save_name": "${continent}${area}",
  "gage_locations" : "${routing_data_path}/observations/gage_example.csv",
  # Input files and directories.
  "save_paths": {
    "attributes": "${marquette_data_path}/streamflow/attributes_dpl_v3.csv",
    "flow_lines": "${marquette_data_path}/${name}/raw/flowlines",
    "basins": "${marquette_data_path}/${name}/raw/basins/cat_pfaf_${continent}${area}_MERIT_Hydro_v07_Basins_v01_bugfix1.shp",
    "huc10": "${marquette_data_path}/HUC/huc_10_CONUS.shp",
    "streamflow_files": "${marquette_data_path}/streamflow/predictions/dpl_v3/",
    "usgs_flowline_mapping": "${routing_data_path}/HUC/usgs_flowline_mapping.json"
  },
  # Locations of the zarr stores; `csr_matrix` is where this notebook writes
  # the connectivity matrix.
  "zarr": {
    "edges": "${marquette_data_path}/${name}/zarr/graph/${name}_edges/",
    "HUC_TM": "${marquette_data_path}/${name}/zarr/TMs/PFAF_${continent}${area}",
    "MERIT_TM": "${marquette_data_path}/${name}/zarr/TMs/MERIT_FLOWLINES_${continent}${area}",
    "streamflow": "${marquette_data_path}/streamflow/zarr/dpl_v3/${save_name}",
    "csr_matrix": "${marquette_data_path}/${name}/zarr/graph/csr_network_matrix/"
  }
}
# Wrap the plain dict in a DictConfig so attribute access and interpolation work.
cfg = OmegaConf.create(cfg)

In [None]:
def downstream_map(id_index, merit_flowlines, rows, cols, data, id_to_index, visited):
    """
    Walk downstream from a starting node and record every edge along the way.

    The walk is iterative rather than recursive: a flow path can contain more
    reaches than Python's recursion limit allows, and a recursive walk would
    raise ``RecursionError`` on such a path.

    Parameters
    ----------
    id_index : int
        The index of the starting node ID for the mapping.
    merit_flowlines : zarr.core.Array
        Object whose ``ds`` attribute, indexed by node index, gives the ID of
        that node's downstream neighbour.
    rows : list
        List to store row indices (upstream node) for the COO matrix.
        Appended to in place.
    cols : list
        List to store column indices (downstream node) for the COO matrix.
        Appended to in place.
    data : list
        List to store data values (connectivity) for the COO matrix; a ``1``
        is appended for every recorded edge.
    id_to_index : dict
        Mapping from string IDs to numerical indices.
    visited : set
        A set of visited node indices. Updated in place; the walk stops as
        soon as it reaches a node that is already in this set, so shared
        downstream segments are only recorded once across repeated calls.

    Returns
    -------
    None
        Results are communicated through ``rows``, ``cols``, ``data`` and
        ``visited``.
    """
    current_index = id_index
    while current_index not in visited:
        visited.add(current_index)

        ds_id = merit_flowlines.ds[current_index]
        if ds_id == '0_0':
            return  # Stop mapping when the river ends

        if ds_id not in id_to_index:
            # Downstream ID is not part of this network: nothing to connect to.
            return

        ds_index = id_to_index[ds_id]
        rows.append(current_index)
        cols.append(ds_index)
        data.append(1)  # Connectivity

        # Continue the walk from the downstream node.
        current_index = ds_index

In [None]:
# NOTE(review): `merit_flowlines` is not defined in any cell shown above;
# presumably it is opened from the edges zarr store -- confirm before running
# this notebook on a fresh kernel.

# Lookup table from flowline ID to its position in `merit_flowlines.id`.
id_to_index = dict(
    (id_val, idx) for idx, id_val in enumerate(merit_flowlines.id[:])
)

# Accumulators shared by every downstream walk: COO triplets + visited nodes.
rows = []
cols = []
data = []
visited = set()

# Start a walk from every node; already-visited nodes return immediately.
for id_index in tqdm(range(len(merit_flowlines.id)), desc="Mapping Downstream"):
    downstream_map(
        id_index,
        merit_flowlines,
        rows,
        cols,
        data,
        id_to_index,
        visited,
    )

In [None]:
# Pack the edge lists into tensors: int64 for indices, float32 for values.
rows_tensor = torch.tensor(rows, dtype=torch.int64)
cols_tensor = torch.tensor(cols, dtype=torch.int64)
data_tensor = torch.tensor(data, dtype=torch.float32)

# Square matrix with one row/column per flowline; entry (i, j) is set when
# node j was recorded as the downstream neighbour of node i.
coo_matrix = torch.sparse_coo_tensor(
    indices=torch.stack((rows_tensor, cols_tensor), dim=0),
    values=data_tensor,
    size=(len(merit_flowlines.id),) * 2,
)

# Convert to CSR format
csr_matrix = coo_matrix.to_sparse_csr()
print(f"csr_matrix: {csr_matrix}")

In [None]:
# Persist the connectivity as COO triplets (rows, cols, data) in the zarr
# store configured under cfg.zarr.csr_matrix.
zarr_storage_path = Path(cfg.zarr.csr_matrix)

# Open the store in append mode so groups written for other zones are kept.
root = zarr.open_group(zarr_storage_path, mode="a")

# The config defines no `zone` key; `save_name` ("${continent}${area}") is the
# zone identifier used by the other per-zone paths. `require_group` returns
# the existing group on a re-run instead of raising like `create_group`.
csr_data = root.require_group(cfg.save_name)

# Write each triplet component as a 1-D array. `overwrite=True` keeps the cell
# idempotent: re-running replaces the arrays rather than failing because they
# already exist.
csr_data.create_dataset('rows', data=rows_tensor.numpy(), chunks=(10000,), dtype='i8', overwrite=True)
csr_data.create_dataset('cols', data=cols_tensor.numpy(), chunks=(10000,), dtype='i8', overwrite=True)
csr_data.create_dataset('data', data=data_tensor.numpy(), chunks=(10000,), dtype='f4', overwrite=True)
print(root.tree())

### Testing

In [None]:
# Spot check: trace the downstream path of a single node.
# NOTE: this cell rebinds id_to_index, rows, cols, data, visited and the
# tensor/matrix names, replacing the values produced by the full-network run.
start_node_index = 0  # You can change this to any valid index

# Initialize data structures
id_to_index = dict(
    (id_val, idx) for idx, id_val in enumerate(merit_flowlines.id[:])
)
rows = []
cols = []
data = []
visited = set()

# Run the mapping for the selected node
downstream_map(
    start_node_index,
    merit_flowlines,
    rows,
    cols,
    data,
    id_to_index,
    visited,
)

# Pack the traced edges into tensors: int64 indices, float32 values.
rows_tensor = torch.tensor(rows, dtype=torch.int64)
cols_tensor = torch.tensor(cols, dtype=torch.int64)
data_tensor = torch.tensor(data, dtype=torch.float32)

# Square matrix sized to the whole network, populated only along the traced path.
coo_matrix = torch.sparse_coo_tensor(
    indices=torch.stack((rows_tensor, cols_tensor), dim=0),
    values=data_tensor,
    size=(len(merit_flowlines.id),) * 2,
)

# Convert to CSR format
csr_matrix = coo_matrix.to_sparse_csr()
print(f"csr_matrix: {csr_matrix}")

In [None]:
# List every recorded edge as "flowline id -> downstream id", in the order
# the walk recorded them.
for idx in rows:
    flowline_id = merit_flowlines.id[idx]
    downstream_id = merit_flowlines.ds[idx]
    print(f"id: {flowline_id} -> ds: {downstream_id}")