In [None]:
import numpy as np
import pandas as pd
from pytides2.tide import Tide
import xarray as xr

In [None]:
from dask.distributed import Client, LocalCluster

# Local Dask cluster settings: 128 workers, one thread each, with a 2GB
# memory limit per worker.
# NOTE(review): these numbers look sized for one specific host — confirm
# before running on a different machine.
cluster_options = {
    "n_workers": 128,          # one worker per thread
    "threads_per_worker": 1,   # HDF5/xarray likes this
    "memory_limit": "2GB",     # or ~ (503/256) GiB
}
cluster = LocalCluster(**cluster_options)
client = Client(cluster)
client

In [None]:
def pytides_to_df(pytides_tide: Tide) -> pd.DataFrame:
    """Turn a fitted pytides model into a table indexed by constituent name.

    Parameters
    ----------
    pytides_tide : Tide
        Object exposing a ``model`` record array that has a ``"constituent"``
        field whose entries carry a ``name`` attribute.

    Returns
    -------
    pd.DataFrame
        One row per constituent, indexed by the upper-cased constituent name,
        holding every field of ``model`` except ``"constituent"``.
    """
    model = pytides_tide.model
    labels = []
    for constituent in model["constituent"]:
        labels.append(constituent.name.upper())
    table = pd.DataFrame(model, index=labels)
    # The constituent objects are now represented by the index labels.
    return table.drop(columns="constituent")

def pytide_get_coefs(ts: pd.Series, resample: int = None) -> dict:
    """Fit tidal constituents to a time series with ``Tide.decompose``.

    Parameters
    ----------
    ts : pd.Series
        Values indexed by a ``DatetimeIndex``.
    resample : int, optional
        Bin width in minutes. When given, the series is averaged into bins of
        that width and the bin labels are moved forward by half a bin.
        When ``None`` the series is used as-is.

    Returns
    -------
    The first element of whatever ``Tide.decompose`` returns.
    NOTE(review): the signature says ``dict``, but ``analyze_node`` passes this
    value to ``pytides_to_df``, which reads ``.model`` — confirm the real type.
    """
    series = ts
    if resample is not None:
        binned = series.resample(f"{resample}min").mean()
        # Center the resampled points
        series = binned.shift(freq=f"{resample / 2}min")
    valid = series.dropna()
    decomposition = Tide.decompose(valid.values, valid.index.to_pydatetime())
    return decomposition[0]

def reduce_coef_to_fes(df: pd.DataFrame, cnst: list, verbose: bool = False) -> pd.DataFrame:
    """Project a constituent table onto a fixed list of constituents.

    Parameters
    ----------
    df : pd.DataFrame
        Table indexed by constituent name.
    cnst : list
        Constituent names wanted in the output, in output order.
    verbose : bool, default False
        When true, print which constituents of ``df`` were dropped (with their
        ``"amplitude"`` values rounded to 3 decimals) and which entries of
        ``cnst`` were absent from ``df``.

    Returns
    -------
    pd.DataFrame
        Frame indexed by ``cnst`` with the columns of ``df``. Rows present in
        ``df`` are copied over; rows missing from ``df`` are filled with 0.0.

    Raises
    ------
    KeyError
        Only when ``verbose`` is true and ``df`` has no ``"amplitude"`` column.
    """
    res = pd.DataFrame(0.0, index=cnst, columns=df.columns)
    common_constituents = df.index.intersection(cnst)
    res.loc[common_constituents] = df.loc[common_constituents]

    # Diagnostics are built only on request: this function is called once per
    # mesh node, and the "amplitude" lookup must not fail a silent call.
    if verbose:
        not_in_fes_df = df[~df.index.isin(cnst)]
        not_in_fes = not_in_fes_df.index.tolist()
        not_in_fes_amps = not_in_fes_df["amplitude"].round(3).tolist()
        missing_fes = set(cnst) - set(df.index)

        print(f"Constituents found but not in FES: {not_in_fes}")
        print(f"Their amplitudes: {not_in_fes_amps}")
        if missing_fes:
            print(
                f"FES constituents missing from analysis (set to 0): {sorted(missing_fes)}",
            )

    return res

In [None]:
from natsort import natsorted
import glob

# Collect every out2d_*.nc file matching the run4 output pattern and put the
# paths in natural sort order; the last expression previews the final five.
out2d_pattern = "/project/home/p200764/schism_runs/validation/schism3d/run4/20220*/outputs/out2d_*.nc"
files = natsorted(glob.glob(out2d_pattern))
files[-5:]

In [None]:
# Open all output files as a single dataset, joined along "time" in the order
# given by `files` (combine="nested"), opening the files in parallel.
open_options = dict(combine="nested", concat_dim="time", parallel=True)
ds = xr.open_mfdataset(files, **open_options)
# ds = ds.chunk({"time": -1, "nSCHISM_hgrid_node": 5000})  # now safe for Dask
ds

In [None]:
# Rechunk elevation so each chunk holds the whole time axis for 500 nodes;
# the per-block analysis further down needs complete time series per node.
elev_chunking = {"time": -1, "nSCHISM_hgrid_node": 500}
elev = ds.elevation.chunk(elev_chunking)
elev

In [None]:
import dask.array as da
# Constituents kept in the output, grouped by tidal band.
_SEMI_DIURNAL = [  # Semi-diurnal (twice daily)
    "M2", "S2", "N2", "K2", "2N2", "L2", "T2", "R2", "NU2", "MU2", "EPS2", "LAMBDA2",
]
_DIURNAL = ["K1", "O1", "P1", "Q1", "J1", "S1"]  # Diurnal (once daily)
_LONG_PERIOD = [  # Long period (fortnightly to annual)
    "MF", "MM", "MSF", "SA", "SSA", "MSQM", "MTM",
]
_SHORT_PERIOD = [  # Short period (higher harmonics)
    "M4", "MS4", "M6", "MN4", "N4", "S4", "M8", "M3", "MKS2",
]
FULL = _SEMI_DIURNAL + _DIURNAL + _LONG_PERIOD + _SHORT_PERIOD

# Quantities stored for each constituent; labels the last axis of the result.
metrics = ["amplitude", "phase"]

def analyze_node(ts_np: np.ndarray, time_index: pd.DatetimeIndex, resample: int = 60) -> np.ndarray:
    """Run the harmonic analysis for one node's time series.

    Parameters
    ----------
    ts_np : np.ndarray
        Values of a single node, one per entry of ``time_index``.
    time_index : pd.DatetimeIndex
        Timestamps of ``ts_np``.
    resample : int, default 60
        Bin width in minutes forwarded to ``pytide_get_coefs``.

    Returns
    -------
    np.ndarray
        Float array of shape ``(len(FULL), len(metrics))``; rows follow
        ``FULL`` and columns follow ``metrics``.

    Raises
    ------
    KeyError
        If the fitted table lacks one of the columns named in ``metrics``.
    """
    ts = pd.Series(ts_np, index=time_index, name="elev")
    df = pytides_to_df(pytide_get_coefs(ts, resample))
    df = reduce_coef_to_fes(df, cnst=FULL)
    # Select by label so the array layout is guaranteed to match the
    # "constituent"/"metric" coordinates attached to the result later,
    # instead of depending on the column order the fit happens to produce.
    return df.loc[FULL, metrics].to_numpy(dtype=float)

def analyze_block(block: np.ndarray, time_block: np.ndarray) -> np.ndarray:
    """Apply ``analyze_node`` to every column of a 2-D block.

    Parameters
    ----------
    block : np.ndarray
        Array indexed as ``block[:, i]`` to obtain the series of node ``i``.
    time_block : np.ndarray
        Timestamps for the first axis of ``block``; converted to a
        ``pd.DatetimeIndex`` once and shared by all nodes.

    Returns
    -------
    np.ndarray
        Per-node results stacked along a new leading axis.
    """
    timestamps = pd.DatetimeIndex(time_block)
    per_node = []
    for node in range(block.shape[1]):
        per_node.append(analyze_node(block[:, node], timestamps))
    return np.stack(per_node, axis=0)

# Sizes of the two axes that analyze_block adds for every node.
nconst = len(FULL)
nmetrics = len(metrics)

# Keep the existing node chunking; each output block is
# (nodes in chunk, nconst, nmetrics).
node_chunks = elev.chunks[1]
output_chunks = (node_chunks, nconst, nmetrics)

results = da.map_blocks(
    analyze_block,
    elev.data,            # (Nt, Nnodes)
    elev["time"],
    dtype=float,
    drop_axis=0,          # drop time axis
    new_axis=[1, 2],      # add constituent, metric axes
    chunks=output_chunks,
)
results

In [None]:
# Label the lazy result: one axis per node, constituent and metric, plus the
# node x/y values attached as "lon"/"lat".
# NOTE(review): assumes SCHISM_hgrid_node_x / _y are reachable from `elev`
# and vary only along the node dimension — confirm against the dataset.
coef_dims = ("nSCHISM_hgrid_node", "constituent", "metric")
coef_coords = {
    "nSCHISM_hgrid_node": elev.nSCHISM_hgrid_node,
    "constituent": FULL,
    "metric": metrics,
    "lon": elev.SCHISM_hgrid_node_x,
    "lat": elev.SCHISM_hgrid_node_y,
}
coef_da = xr.DataArray(results, dims=coef_dims, coords=coef_coords, name="tidal_coefs")
coef_da

In [None]:
from dask.diagnostics import ProgressBar  # not used below; import kept for any other cell that relies on it
from dask.distributed import progress
import warnings

# NOTE(review): this silences every warning in this kernel from here on.
warnings.filterwarnings("ignore")

# dask.diagnostics.ProgressBar only reports for the single-machine schedulers.
# This notebook runs on the distributed Client created earlier, so the work is
# started on the cluster with persist() and tracked with the distributed
# progress bar instead.
coef_persisted = coef_da.persist()
progress(coef_persisted, notebook=False)  # text bar; blocks until the tasks finish
coef_result = coef_persisted.compute()    # gather the finished result locally

In [None]:
# Split the "metric" dimension into separate data variables, one per entry
# of `metrics` ("amplitude" and "phase").
coef_ds = coef_result.to_dataset(dim="metric")

In [None]:
# Save the coefficients to a NetCDF file; the path is relative, so the file
# lands in the kernel's current working directory.
coef_ds.to_netcdf("run4_tides.nc")

In [None]:
import hvplot.xarray

In [None]:
# Pick out the M2 constituent and display it for a quick inspection.
m2 = coef_ds.sel(constituent = "M2")
m2

In [None]:
# m2.hvplot.scatter(x='lon', y="lat",c="amplitude" )