In [None]:
import io
import itertools
import math
import os
import pathlib

import holoviews as hv
import hvplot.pandas
import numpy as np
import pandas as pd
import pymap3d
import pyposeidon.mesh as pmesh

# Display setup for the rest of the notebook:
# - numpy prints floats in fixed-point rather than scientific notation;
# - holoviews / hvplot figures are rendered with the bokeh backend.
np.set_printoptions(suppress=True)
hv.extension("bokeh")

In [None]:
def parse_hgrid_nodes(path: os.PathLike[str] | str) -> pd.DataFrame:
    with open(path, "rb") as fd:
        _ = fd.readline()
        _, no_points = map(int, fd.readline().strip().split(b" "))
        content = io.BytesIO(b''.join(next(fd) for _ in range(no_points)))
        nodes = pd.read_csv(
            content,
            engine="pyarrow",
            sep="\t",
            header=None,
            names=["lon", "lat", "depth"],
            index_col=0
        )
    nodes = nodes.reset_index(drop=True)
    return nodes
    
def parse_hgrid_elements3(path: os.PathLike[str] | str) -> pd.DataFrame:
    with open(path, "rb") as fd:
        _ = fd.readline()
        no_elements, no_points = map(int, fd.readline().strip().split(b" "))
        for _ in range(no_points):
            next(fd) 
        content = io.BytesIO(b''.join(next(fd) for _ in range(no_elements)))
        elements = pd.read_csv(
            content,
            engine="pyarrow",
            sep="\t",
            header=None,
            names=["no_nodes", "n1", "n2", "n3"],
            index_col=0
        )
    elements = elements.assign(
        n1=elements.n1 - 1,
        n2=elements.n2 - 1,
        n3=elements.n3 - 1,
    ).reset_index(drop=True)
    return elements

In [None]:
def get_skews_and_base_cfls(lons, lats, depths) -> tuple[np.ndarray, np.ndarray]:
    """Compute per-triangle skewness and a depth-based CFL factor.

    Parameters
    ----------
    lons, lats, depths
        Arrays of shape ``(3, no_triangles)``: row ``i`` holds the longitude,
        latitude (degrees) and depth of vertex ``i`` of every triangle.

    Returns
    -------
    skews : np.ndarray
        Shape ``(no_triangles,)``. Longest edge divided by the equivalent
        radius ``rho = sqrt(area / pi)``.
    base_cfls : np.ndarray
        Shape ``(no_triangles,)``. ``sqrt(9.81 * max(0.1, mean_depth)) / (2 * rho)``,
        i.e. the CFL number for a 1 s time step; multiply by ``dt`` for the
        CFL at time step ``dt``.

    Notes
    -----
    Areas are signed (counter-clockwise vertex order gives a positive area),
    so a clockwise triangle yields NaN ``rho``, skew and CFL.
    """
    # Spherical earth model with radius 6378206.4 m (labelled "schism" here).
    # pymap3d.Ellipsoid.from_name("wgs84") would give the WGS84 ellipsoid.
    ell = pymap3d.Ellipsoid(6378206.4, 6378206.4, "schism", "schism")
    # Project every vertex onto a local east/north tangent plane anchored at
    # the triangle's first vertex. Depth is not an ellipsoidal height (and is
    # positive downwards), so vertices are placed on the surface (h = 0);
    # passing depth as height would distort the projected distances/areas.
    heights = np.zeros_like(lats, dtype=float)
    local_x, local_y, _ = pymap3d.geodetic2enu(
        lats, lons, heights, lats[0], lons[0], heights[0], ell=ell
    )
    # Vertex 0 sits at the origin, so the area is half the 2-D cross product
    # of the edges 0->1 and 0->2.
    areas = (local_x[1] * local_y[2] - local_x[2] * local_y[1]) * 0.5
    rhos = np.sqrt(areas / np.pi)
    max_sides = np.maximum(
        np.sqrt(local_x[1] ** 2 + local_y[1] ** 2),
        np.sqrt(local_x[2] ** 2 + local_y[2] ** 2),
        np.sqrt((local_x[2] - local_x[1]) ** 2 + (local_y[2] - local_y[1]) ** 2),
    )
    skews = max_sides / rhos
    # Shallow-water wave speed over the mean element depth, floored at 0.1 m
    # so dry/land elements (depth <= 0) still get a finite value; divided by
    # the equivalent diameter 2 * rho.
    base_cfls = np.sqrt(9.81 * np.maximum(0.1, depths.mean(axis=0))) / rhos / 2
    return skews, base_cfls

def get_skews_and_base_cfls_from_path(
    path: os.PathLike[str] | str,
) -> tuple[np.ndarray, np.ndarray]:
    """Parse a ``.gr3`` mesh and compute per-triangle skewness and base CFL.

    Parameters
    ----------
    path
        Path to the ``.gr3`` file.

    Returns
    -------
    (skews, base_cfls)
        Two 1-D arrays with one entry per element, in the order the elements
        appear in the file (see ``get_skews_and_base_cfls``).
    """
    nodes = parse_hgrid_nodes(path)
    elements = parse_hgrid_elements3(path)
    # (no_triangles, 3) array of 0-based node positions; gathering node values
    # with it and transposing gives the (3, no_triangles) layout expected by
    # get_skews_and_base_cfls.
    tri = elements[["n1", "n2", "n3"]].to_numpy()
    lons = nodes.lon.to_numpy()[tri].T
    lats = nodes.lat.to_numpy()[tri].T
    depths = nodes.depth.to_numpy()[tri].T
    return get_skews_and_base_cfls(lons=lons, lats=lats, depths=depths)

In [None]:
# Mesh to analyse. Set the HGRID_PATH environment variable to use another
# file; meshes previously inspected with this notebook:
#   /home/panos/Prog/poseidon/seareport_meshes/meshes/global-v0.1.gr3
#   /home/panos/Prog/poseidon/seareport_meshes/meshes/global-v0.gr3
#   /home/panos/Prog/git/schism/src/Utility/Grid_Scripts/hgrid.gr3
path = pathlib.Path(
    os.environ.get(
        "HGRID_PATH",
        "/home/panos/Prog/poseidon/seareport_meshes/meshes/global-v0.2.gr3",
    )
)
skews, base_cfls = get_skews_and_base_cfls_from_path(path)

In [None]:
CFL_THRESHOLD = 0.4
# For each candidate time step (s): the number of elements whose CFL
# (base_cfl * dt) is below CFL_THRESHOLD, and their share of all elements.
n_elements = len(base_cfls)
for dt in (1, 50, 75, 100, 120, 150, 200, 300, 400, 600, 900, 1200, 1800, 3600):
    n_low = np.count_nonzero(base_cfls * dt < CFL_THRESHOLD)
    share = n_low / n_elements * 100
    print(f"{dt:>4d} {n_low:>12d} {share:>8.2f}%")

In [None]:
# Summary statistics of the per-element skewness.
pd.DataFrame(skews, columns=["skew"]).describe()

In [None]:
# Elements whose CFL falls below the threshold at a fixed time step: summary
# statistics plus a histogram. `display` is needed for the summary because
# only a cell's last expression is rendered automatically.
HIST_DT = 400
cfl_at_hist_dt = pd.DataFrame({"cfl": base_cfls * HIST_DT})
low_cfl = cfl_at_hist_dt[cfl_at_hist_dt.cfl < CFL_THRESHOLD]
display(low_cfl.describe())
low_cfl.hvplot.hist()

In [None]:
# Load the mesh tables again and attach the 1-second CFL factor to each
# element (base_cfls was computed from the same element table, in file order).
nodes = parse_hgrid_nodes(path)
elements = parse_hgrid_elements3(path).assign(base_cfl=base_cfls)
elements.head()

In [None]:
# Smallest base CFL among the triangles touching each node: take the minimum
# per node for each vertex column separately, align the three partial results
# on node index, and reduce across them. Nodes that belong to no triangle do
# not appear in the result.
min_cfl_per_node = pd.concat(
    [elements.groupby(vertex).base_cfl.min() for vertex in ("n1", "n2", "n3")],
    axis=1,
).min(axis=1)
min_cfl_per_node.head()

In [None]:
dt = 600
node_cfl = min_cfl_per_node * dt
df = nodes.assign(
    cfl=node_cfl,
    # Used as the marker size in the map: 4 where the node's CFL is below
    # CFL_THRESHOLD, 1 otherwise, so offending nodes are drawn larger.
    cfl_violation=node_cfl.lt(CFL_THRESHOLD).map({False: 1, True: 4}),
)
df.head()

In [None]:
violation_points = df.hvplot.points(
    "lon",
    "lat",
    c="cfl_violation",
    cmap="colorblind",
    geo=True,
    tiles=True,
)
# Marker size is driven by the cfl_violation column, so nodes flagged as
# violating the CFL threshold stand out on the map.
violation_points.options(width=900, height=600).opts(
    hv.opts.Points(size=hv.dim("cfl_violation"))
)