In [None]:
import io
import math
import os
import pathlib

import holoviews as hv
import hvplot.pandas
import numpy as np
import pandas as pd
import pymap3d
import xarray as xr

# Load the Bokeh rendering backend used by HoloViews / hvPlot figures below.
hv.extension("bokeh")
# Print NumPy floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

In [None]:
def get_skews_and_base_cfls(
    lons,
    lats,
    depths,
    *,
    ell=None,
    gravity: float = 9.81,
    min_depth: float = 0.1,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute per-triangle skewness and the CFL number for a unit time step.

    Each input array must have shape ``(3, n_triangles)``: row ``i`` holds the
    ``i``-th vertex of every triangle.

    Parameters
    ----------
    lons, lats : array-like
        Vertex longitudes / latitudes in degrees.
    depths : array-like
        Vertex depths, positive downwards (e.g. ``-B`` for a SELAFIN mesh).
    ell : pymap3d.Ellipsoid, optional
        Reference ellipsoid for the local ENU projection. Defaults to a sphere
        of radius 6378206.4 m labelled "schism".
    gravity : float
        Gravitational acceleration used for the wave celerity ``sqrt(g * h)``.
    min_depth : float
        Lower bound applied to the mean element depth, so dry / land elements
        still get a finite, positive celerity.

    Returns
    -------
    skews : np.ndarray
        Longest edge divided by the equivalent radius ``rho = sqrt(area / pi)``.
    base_cfls : np.ndarray
        ``sqrt(g * h) / (2 * rho)``; multiply by a time step ``dt`` to obtain
        the element CFL number.
    """
    if ell is None:
        ell = pymap3d.Ellipsoid(6378206.4, 6378206.4, "schism", "schism")
    lons = np.asarray(lons)
    lats = np.asarray(lats)
    depths = np.asarray(depths)

    # Project every triangle onto a local East/North plane centred on its first
    # vertex. The geometry is horizontal, so use zero altitude: depths are
    # positive downwards and are not heights above the ellipsoid.
    zeros = np.zeros_like(lats, dtype=float)
    local_x, local_y, _ = pymap3d.geodetic2enu(
        lats, lons, zeros, lats[0], lons[0], zeros[0], ell=ell
    )

    # Shoelace formula with vertex 0 at the origin. The absolute value makes the
    # area independent of vertex ordering; clockwise triangles would otherwise
    # yield negative areas and NaN radii / skews / CFLs.
    areas = np.abs(local_x[1] * local_y[2] - local_x[2] * local_y[1]) * 0.5
    rhos = np.sqrt(areas / np.pi)  # radius of the circle with the same area

    edge_01 = np.hypot(local_x[1], local_y[1])
    edge_02 = np.hypot(local_x[2], local_y[2])
    edge_12 = np.hypot(local_x[2] - local_x[1], local_y[2] - local_y[1])
    # np.maximum is binary: a third positional argument would be treated as
    # `out`, silently ignoring (and overwriting) that edge. Reduce explicitly.
    max_sides = np.maximum.reduce([edge_01, edge_02, edge_12])

    skews = max_sides / rhos
    mean_depths = np.maximum(min_depth, depths.mean(axis=0))
    base_cfls = np.sqrt(gravity * mean_depths) / (2 * rhos)
    return skews, base_cfls

def get_skews_and_base_cfls_from_path(path: os.PathLike[str] | str) -> tuple[np.ndarray, np.ndarray]:
    """Load a SELAFIN mesh and compute per-triangle skews and base CFL numbers.

    Depth is taken as minus the ``B`` variable at the first time step.

    Parameters
    ----------
    path : path-like or str
        SELAFIN (``.slf``) file readable by xarray's ``selafin`` engine.

    Returns
    -------
    tuple of np.ndarray
        ``(skews, base_cfls)`` as returned by ``get_skews_and_base_cfls``.
    """
    # Read everything needed into memory, then close the file handle.
    with xr.open_dataset(path, engine="selafin") as ds:
        # Connectivity from the file; the -1 converts it to 0-based node
        # indices (the attribute appears to be 1-based).
        tri = ds.attrs["ikle2"] - 1
        node_x = ds.x.values
        node_y = ds.y.values
        bottom = ds.B.isel(time=0).values
    # Fancy indexing gives (n_triangles, 3); transpose to (3, n_triangles).
    return get_skews_and_base_cfls(
        lons=node_x[tri].T,
        lats=node_y[tri].T,
        depths=-bottom[tri].T,
    )

In [None]:
# Mesh to analyse. Set the MESH_DIR environment variable to point elsewhere;
# another mesh previously used here is "v1p2.slf".
MESH_DIR = pathlib.Path(
    os.environ.get("MESH_DIR", "/home/tomsail/Documents/work/models/meshes/slf")
)
file = MESH_DIR / "v2p1.slf"
# `ds` stays open: later cells read node coordinates and connectivity from it.
ds = xr.open_dataset(file, engine='selafin')
skews, base_cfls = get_skews_and_base_cfls_from_path(file)

In [None]:
# Elements whose CFL number falls *below* this threshold are counted as
# violations (presumably the lower-bound CFL requirement of the solver — confirm).
CFL_THRESHOLD = 0.4
# Candidate time steps to evaluate (presumably seconds).
TIME_STEPS = (1, 50, 75, 100, 120, 150, 200, 300, 400, 600, 900, 1200, 1800, 3600)

n_elements = len(base_cfls)
print(f"{'dt':>4} {'violations':>12} {'share':>9}")
for dt in TIME_STEPS:
    violations = int((base_cfls * dt < CFL_THRESHOLD).sum())
    print(f"{dt:>4d} {violations:>12d} {violations / n_elements * 100:>8.2f}%")

In [None]:
# Summary statistics of the element skewness.
pd.DataFrame(skews, columns=["skew"]).describe()

In [None]:
# Element CFL numbers at a fixed time step; inspect the elements below the
# threshold. Only a cell's last expression is auto-displayed, so the summary
# table is shown explicitly.
HIST_DT = 400
df = pd.DataFrame({"cfl": base_cfls * HIST_DT})
low_cfl = df[df.cfl < CFL_THRESHOLD]
display(low_cfl.describe())
low_cfl.hvplot.hist()

In [None]:
# 0-based triangle connectivity, one row per element: shape (n_elements, 3).
tri = ds.attrs['ikle2'] - 1
nodes = pd.DataFrame({
    "lon": ds.x.values,
    "lat": ds.y.values,
    # Same sign convention as get_skews_and_base_cfls_from_path: depth = -B.
    "depth": -ds.B.isel(time=0).values,
})
# Build columns directly so node indices keep an integer dtype (stacking them
# with a float array would silently turn them into floats).
elements = pd.DataFrame({
    "no_nodes": np.full(len(tri), 3),
    "n1": tri[:, 0],
    "n2": tri[:, 1],
    "n3": tri[:, 2],
}).assign(base_cfl=base_cfls)
elements.head()

In [None]:
# For every node, the smallest base CFL among the elements that use it: take
# the per-node minimum for each vertex column, then combine across columns.
per_vertex_minima = [
    elements.groupby(vertex_col)["base_cfl"].min()
    for vertex_col in ("n1", "n2", "n3")
]
min_cfl_per_node = pd.concat(per_vertex_minima, axis=1).min(axis=1)
min_cfl_per_node.head()

In [None]:
dt = 200
node_cfl = min_cfl_per_node * dt
df = nodes.assign(
    cfl=node_cfl,
    # Per-node flag: 4 when the node's CFL is below CFL_THRESHOLD, 1 otherwise.
    # The two values were meant as marker sizes; the plot cell filters on == 4.
    cfl_violation=(node_cfl < CFL_THRESHOLD).map({True: 4, False: 1}),
)
df.head()

In [None]:
# Map of the nodes flagged as CFL violations, coloured by depth on satellite tiles.
violating_nodes = df.loc[df["cfl_violation"] == 4]
plot = (
    violating_nodes
    .hvplot.points("lon", "lat", c="depth", cmap="jet", geo=True, tiles="EsriImagery")
    .options(width=1200, height=900)
)
len(violating_nodes)

In [None]:
# Display the hvplot map of CFL-violating nodes built in the previous cell.
plot