# Mesh renumbering 

## Why renumber?


Renumbering the mesh nodes can be important to optimize the post-processing of space/time datasets.

Once the nodes are numbered region by region, slicing the Dataset along the node index gives a chunk of data that is:
 * spatially coherent (it covers a given region),
 * much smaller than the whole dataset,
 * convenient for multiprocessing.


You need both `xarray-selafin` and `thalassa` for this notebook to work.

In [None]:
%pip install xarray-selafin

In [None]:
import xarray as xr 
import thalassa
from thalassa import api
import holoviews as hv
# Select Bokeh as the HoloViews plotting backend for the plots below.
hv.extension("bokeh")

We load a 7 km mesh of the global ocean:

In [None]:
# Open the global .slf mesh with thalassa; the bare `ds` displays it in the notebook.
mesh_file = "meshes/global-v1.2.slf"
ds = thalassa.open_dataset(mesh_file)
ds

In [None]:
import pandas as pd
import hvplot.pandas
import numpy as np
def plot_mesh(x, y):
    """Scatter-plot mesh nodes coloured by their index.

    Colouring by position in the arrays makes it easy to see whether the
    node numbering is spatially coherent.

    Args:
        x: Node x-coordinates.
        y: Node y-coordinates (same length as ``x``).

    Returns:
        An hvplot points plot (columns ``x``, ``y`` and ``id``, marker size 3).
    """
    nodes = pd.DataFrame({'x': x, 'y': y, 'id': np.arange(len(x))})
    return nodes.hvplot.points(x='x', y='y', c='id', s=3)

# Plot the original numbering: colours show each node's index.
x = ds.lon.values
y = ds.lat.values
plot_mesh(x, y).opts(width=1200, height=600, cmap='tab20c')

We see that the mesh node numbering is not spatially coherent.

To reorder/renumber the nodes, we first load a simplified version of the [world maritime borders](https://www.naturalearthdata.com/downloads/10m-physical-vectors/10m-physical-labels/), available at this [gist](https://gist.github.com/tomsail/2fc6c0d9544f6354f9822576fb58b4f7):

In [None]:
import geopandas as gpd
# Simplified world maritime borders (oceans only), hosted on a gist.
WORLD_OCEANS_URL = (
    "https://gist.githubusercontent.com/tomsail/2fc6c0d9544f6354f9822576fb58b4f7/raw/"
    "5864569a2f410b621ee07e92b782f21a8fbe4e6c/world_oceans.json"
)
world_oceans = gpd.read_file(WORLD_OCEANS_URL)
world_oceans.hvplot(color='name', width=1500, height=900, cmap='tab20c')

The problem with this dataset is that it only covers the oceans, not the continents.
So, depending on your mesh, some points might fall on land.

We'll remedy this problem using this [methodology](https://gis.stackexchange.com/questions/175599/buffer-neighbouring-polygons-without-overlap-using-qgis):
1) Extract vertices from polygons, ensuring a unique field is kept as an attribute
2) Create a voronoi from these points
3) Buffer the original polygons by the required amount -- **not needed here**
4) Subtract the buffer polygons from the voronoi
5) Recombine/dissolve the remaining voronoi polygons on the unique attribute field
    

In [None]:
import geopandas as gpd
from shapely.ops import unary_union
import shapely
from scipy.spatial import Voronoi

In [None]:
def get_land_poly(gdf: gpd.GeoDataFrame):
    """Extend every ocean polygon over the neighbouring land.

    Follows the Voronoi methodology described above: each exterior vertex of
    each ocean polygon seeds a Voronoi cell tagged with that polygon's ``id``.
    The ocean area is then subtracted from the cells, and the remaining cells
    are dissolved per ``id``.

    Args:
        gdf: Ocean polygons (``Polygon`` or ``MultiPolygon`` geometries) in
            EPSG:4326, with a unique ``id`` column.

    Returns:
        A GeoDataFrame indexed by ``id`` holding the land area attributed to
        each ocean polygon, clipped to the WGS84 extent (-180, -90, 180, 90).
    """
    # 1. Extract vertices from polygons, keeping the unique `id` as an attribute
    coords = []
    attributes = []
    for _, row in gdf.iterrows():
        geometry = row['geometry']
        # A plain Polygon has no `.geoms`: treat it as a single-part geometry
        parts = [geometry] if geometry.geom_type == "Polygon" else geometry.geoms
        for part in parts:
            for coord in part.exterior.coords:
                coords.append(coord)
                attributes.append(row['id'])

    points = [shapely.Point(coord) for coord in coords]
    points_gdf = gpd.GeoDataFrame(data={'id': attributes}, geometry=points, crs="EPSG:4326")

    # 2. Create a Voronoi diagram from these points
    vor = Voronoi(np.array(coords))
    # 2.bis Keep only bounded, non-empty regions (index -1 marks a vertex outside the diagram)
    regions = [r for r in vor.regions if -1 not in r and r != []]
    voronoi_polys = [shapely.Polygon([vor.vertices[i] for i in region]) for region in regions]
    # 2.ter Create a GeoDataFrame from the Voronoi polygons
    voronoi_gdf = gpd.GeoDataFrame(geometry=voronoi_polys, crs="EPSG:4326")

    # 3. Buffering the original polygons is not needed here.
    # 4. Subtract the ocean polygons from the Voronoi cells
    result = gpd.overlay(voronoi_gdf, gdf, how='difference')
    # 5. Tag each remaining cell with the `id` of the seed points it intersects,
    #    then dissolve on `id`. `predicate` replaces the `op` keyword, which
    #    geopandas deprecated and later removed.
    result_with_attr = gpd.sjoin(result, points_gdf, how='left', predicate='intersects')
    dissolved = result_with_attr.dissolve(by='id')

    # 6. Clip the dissolved polygons to the WGS84 bounding box
    bbox = (-180, -90, 180, 90)
    land = gpd.GeoDataFrame(data=dissolved, geometry=dissolved.clip_by_rect(*bbox), crs="EPSG:4326")
    return land

In [None]:
# Build the land extension of every ocean sector (Voronoi-based, see above).
land = get_land_poly(world_oceans)

In [None]:
# Plot the land polygons, coloured by their index (the ocean `id` they were dissolved on).
land_plot = land.hvplot(color = 'index').opts(cmap = 'tab20c', width = 1500, height = 900)
land_plot

In [None]:
# Pair each land polygon with the ocean polygon sharing its `id`, then fuse the two.
df = land.merge(world_oceans, on='id')
df['geometry'] = [
    unary_union([land_geom, ocean_geom])
    for land_geom, ocean_geom in zip(df['geometry_x'], df['geometry_y'])
]
whole = gpd.GeoDataFrame(data=df, geometry=df['geometry'], crs="EPSG:4326")

Drop the unnecessary columns:

In [None]:
# Keep only the fused geometry, then write the result to disk.
whole = whole.drop(columns=['geometry_x', 'geometry_y'])
whole.to_file('assets/world_oceans_land.json')

In [None]:
# Overlay the fused ocean+land sectors (coloured by ocean name) with the land-only polygons.
whole.hvplot(color='name').opts(cmap = 'tab20c', width = 1500, height = 900) * land_plot

There are still some imperfections (i.e. overlaps between the polygons). A simplified version has been made with QGIS and is available at this [gist](https://gist.github.com/tomsail/2fa52d9667312b586e7d3baee123b57b):

In [None]:
# QGIS-simplified maritime sectors (the overlaps noted above cleaned up), hosted on a gist.
final = gpd.read_file('https://gist.githubusercontent.com/tomsail/2fa52d9667312b586e7d3baee123b57b/raw/dcda4d7adfc422481cdaf2b74a9dee53e0a505c0/world_maritime_sectors.json')

In [None]:
# NOTE(review): `gpd.datasets` is deprecated (removed in geopandas 1.0); with a
# newer geopandas, read the Natural Earth low-res countries from another source.
countries = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
# Dedicated name for the country outlines: `map_` is reused later for the node mapping.
countries_plot = countries.hvplot().opts(alpha=0.1, color='white', line_alpha=0.9)
final.hvplot(color='ocean').opts(cmap='tab20c', width=1500, height=900) * countries_plot

To check whether points lie inside a polygon, we need to install `numba`:

In [None]:
%pip install numba

In [None]:
from numba import jit, njit
import numba
import numpy as np 

@jit(nopython=True)
def pointinpolygon(x, y, poly):
    """Return True when the point (x, y) lies inside ``poly`` (ray casting).

    A horizontal ray is cast from the point towards +x, and ``inside`` is
    toggled each time the ray crosses a polygon edge (even-odd rule).

    Args:
        x: x-coordinate of the query point.
        y: y-coordinate of the query point.
        poly: Vertices of the polygon, indexable as ``poly[i] -> (px, py)``.
            Edges join consecutive vertices; the index wraps back to vertex 0.

    Returns:
        True if the point is inside the polygon, False otherwise.
    """
    n = len(poly)
    inside = False
    # Explicit float initialisations give numba stable variable types.
    p2x = 0.0
    p2y = 0.0
    xints = 0.0
    p1x, p1y = poly[0]
    # Plain `range`, not `prange`: each iteration reuses the previous vertex
    # (p1x, p1y) and toggles `inside`, so the iterations are NOT independent.
    for i in range(n + 1):
        p2x, p2y = poly[i % n]
        if y > min(p1y, p2y):
            if y <= max(p1y, p2y):
                if x <= max(p1x, p2x):
                    # Horizontal edges cannot reach here (y > min and y <= max
                    # cannot both hold), but guard the division anyway.
                    if p1y != p2y:
                        xints = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if p1x == p2x or x <= xints:
                        inside = not inside
        p1x, p1y = p2x, p2y

    return inside


@njit(parallel=True)
def parallelpointinpolygon(points, polygon):
    """Test every point against ``polygon``, distributing the points over threads.

    Args:
        points: Array whose rows are ``(x, y)`` query points.
        polygon: Polygon vertices, as accepted by ``pointinpolygon``.

    Returns:
        Boolean array with one entry per row of ``points``.
    """
    n_points = len(points)
    inside = np.empty(n_points, dtype=numba.boolean)
    # Each point is tested independently, so a parallel `prange` is safe here.
    for k in numba.prange(n_points):
        inside[k] = pointinpolygon(points[k, 0], points[k, 1], polygon)
    return inside

The three main functions below do the following:

 * `reorder_nodes`: reorders the nodes within a given region based on the computed weights.
 * `remap_connectivity`: remaps the connectivity of the triangles to reflect the new node ordering.
 * `reorder_mesh`: main function, which translates the input mesh into an "ordered" mesh.

In [None]:
from typing import Tuple, List, Iterable
import numpy_indexed as npi
import geopandas as gpd
from inpoly import inpoly2 ## to compare with inpoly2

def remap_connectivity(
        tri: np.ndarray,
        mapping: np.ndarray
    ) -> np.ndarray:
    """Express triangle connectivity in terms of the new node numbering.

    ``mapping[new] == old``: the node at new position ``new`` was node ``old``
    in the original mesh. Each old index in ``tri`` is therefore replaced by
    its position in ``mapping``. Indices absent from ``mapping`` are left
    unchanged by ``npi.remap``.

    Args:
        tri: Original connectivity; its first three columns are the triangle vertices.
        mapping: New-to-old node index array.

    Returns:
        An ``(n_triangles, 3)`` array of remapped vertex indices.
    """
    new_ids = np.arange(len(mapping))
    columns = [npi.remap(tri[:, k], mapping, new_ids) for k in range(3)]
    return np.column_stack(columns)

## from oceanmesh
def get_poly_edges(poly):
    """Return the ``(start, end)`` vertex-index pairs of consecutive vertices.

    Edge ``i`` joins vertex ``i`` to vertex ``i + 1``. No edge is added from
    the last vertex back to the first, so ``poly`` is expected to be a closed
    ring whose last vertex repeats the first.

    Args:
        poly: Sequence of polygon vertices.

    Returns:
        Integer array of shape ``(len(poly) - 1, 2)``, or ``(0, 2)`` when
        ``poly`` has fewer than two vertices.
    """
    starts = np.arange(len(poly) - 1)
    return np.column_stack((starts, starts + 1))


def reorder_nodes(
        x: np.ndarray,
        y: np.ndarray,
        region_polygon: "shapely.geometry.base.BaseGeometry",
        order_wgts: np.ndarray,
        method: str = "inpoly"
    ) -> np.ndarray:
    """Return the indices of the nodes inside a region, sorted by weight.

    Args:
        x: The x-coordinates of the nodes.
        y: The y-coordinates of the nodes.
        region_polygon: Shapely ``Polygon`` or ``MultiPolygon`` of the region
            (e.g. one row's ``geometry`` of a GeoDataFrame).
        order_wgts: One ordering weight per node; nodes are sorted ascending.
        method: How to test membership:
            ``"inpoly"`` (``inpoly2``, exact), ``"numba"``
            (``parallelpointinpolygon``, exact) or ``"bbox"`` (region bounding
            box only: faster but coarser).

    Returns:
        Indices (into ``x``/``y``) of the nodes inside the region, in ascending
        order of ``order_wgts``.

    Raises:
        ValueError: If ``method`` is not one of the values above.

    Note:
        Only polygon exteriors are tested; holes are ignored, so nodes inside
        a hole count as inside the region.
    """
    if method not in ("inpoly", "numba", "bbox"):
        raise ValueError(
            f"unknown method {method!r}; expected 'inpoly', 'numba' or 'bbox'"
        )

    if method == "bbox":  # bounding-box test (faster, coarser)
        minx, miny, maxx, maxy = region_polygon.bounds
        points_in_region_final = (y >= miny) & (y <= maxy) & (x >= minx) & (x <= maxx)
    else:  # exact point-in-polygon test, one polygon part at a time
        points = np.vstack((x, y)).T  # same (N, 2) array for every part
        if region_polygon.geom_type == "Polygon":
            parts = [region_polygon]
        else:
            parts = region_polygon.geoms
        points_in_region_final = np.zeros(len(x), dtype=np.bool_)
        for part in parts:
            # Always pass an ndarray (not a list of tuples) to the testers.
            polygon = np.array(part.exterior.coords)
            if method == "inpoly":
                inside, _ = inpoly2(points, polygon, get_poly_edges(polygon))
            else:
                inside = parallelpointinpolygon(points, polygon)
            points_in_region_final = np.logical_or(points_in_region_final, inside)

    indices_in_region = np.where(points_in_region_final)[0]
    idx_sort = np.argsort(order_wgts[indices_in_region])
    return indices_in_region[idx_sort]


def reorder_mesh(
        x: np.ndarray,
        y: np.ndarray,
        tri: np.ndarray,
        regions: gpd.GeoDataFrame,
        method: str = "numba",
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]:
    """Renumber mesh nodes region by region and remap the connectivity.

    Nodes are grouped by region, in the row order of ``regions``. Within each
    region they are sorted by a weight that sweeps from north to south, with
    longitude ordering nodes within each sweep band.

    Args:
        x: Node x-coordinates (longitudes).
        y: Node y-coordinates (latitudes).
        tri: Triangle connectivity in the original numbering.
        regions: GeoDataFrame whose ``geometry`` column holds the region polygons.
        method: Point-in-region test passed to ``reorder_nodes``.

    Returns:
        A tuple ``(x_new, y_new, tri_new, order)``:
        - ``x_new`` and ``y_new`` are the reordered coordinates;
        - ``tri_new`` is the remapped connectivity;
        - ``order`` is the new-to-old index list (``x_new == x[order]``).

        ``order`` is always a permutation of ``range(len(x))``:
        - a node claimed by several overlapping regions keeps its first
          assignment;
        - nodes outside every region are appended at the end, sorted by weight.
    """
    # 1. Shift coordinates so that both start at 0
    normalized_lon = x - np.min(x)
    normalized_lat = y - np.min(y)
    # 2. Ordering weight: each degree of latitude is worth 360 units and a
    #    larger latitude gives a smaller weight, so sorting sweeps north to south.
    order_wgts = normalized_lon + (180 - normalized_lat) * 360
    # 3-4. Collect the weight-sorted node indices of each region
    chunks = [
        np.asarray(
            reorder_nodes(x, y, region['geometry'], order_wgts, method=method),
            dtype=np.int64,
        )
        for _, region in regions.iterrows()
    ]
    order = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)
    # Overlapping regions would duplicate nodes: keep each first occurrence.
    _, first_pos = np.unique(order, return_index=True)
    order = order[np.sort(first_pos)]
    # Gaps between regions would drop nodes: append the leftovers by weight.
    leftover = np.setdiff1d(np.arange(len(x)), order, assume_unique=True)
    leftover = leftover[np.argsort(order_wgts[leftover], kind="stable")]
    order = np.concatenate([order, leftover])
    # 5. Remap the connectivity (valid because `order` is a full permutation)
    tri_out = remap_connectivity(tri, order)
    return x[order], y[order], tri_out, order.tolist()

In [None]:
# Node coordinates and triangle connectivity of the original mesh.
x, y, tri = ds.lon.values, ds.lat.values, ds.triface_nodes
# Renumber sector by sector; `map_[i]` is the original index of new node i.
x_, y_, tri_, map_  = reorder_mesh(x, y, tri, final)

Finalise and save the new mesh dataset:
 * using thalassa's `GENERIC` format:

In [None]:
# Variable `B` at the first time step (used here as depth), in the new node order.
depth = ds.B.isel(time=0).values[map_]
mesh_out = xr.Dataset(
    {
        'lon': (['node'], x_),
        'lat': (['node'], y_),
        'triface_nodes': (['triface', 'three'], tri_),
        'depth': (['node'], depth),
    }
)
mesh_out

In [None]:
# Same plot as before with the renumbered nodes, to inspect the new index layout.
plot_mesh(x_,y_).opts(width = 1200, height = 600,cmap='tab20c')

Check the depth assignment:

In [None]:
import matplotlib.pyplot as plt
def is_overlapping(tris, meshx, threshold=180):
    """Flag triangles whose vertices are spread more than ``threshold`` apart in x.

    With longitudes in degrees, such triangles typically wrap across the
    antimeridian and should be excluded from plotting.

    Args:
        tris: ``(n_triangles, 3)`` vertex indices into ``meshx``.
        meshx: Node x-coordinates.
        threshold: Maximum allowed x-distance between two vertices of a
            triangle (default 180).

    Returns:
        Boolean array, True for each triangle where any pair of its vertices
        is more than ``threshold`` apart.
    """
    x1, x2, x3 = np.asarray(meshx)[np.asarray(tris)].T
    # Combine with `|`: np.logical_or(a, b, c) would treat `c` as the `out`
    # buffer and silently ignore the third pair.
    return (
        (np.abs(x2 - x1) > threshold)
        | (np.abs(x3 - x1) > threshold)
        | (np.abs(x3 - x2) > threshold)
    )

m = is_overlapping(tri_, x_)
keep = ~m  # drop the triangles flagged by is_overlapping
fig, ax = plt.subplots(1, 1, figsize=(16, 9))
ax.tricontourf(
    mesh_out.lon.values,
    mesh_out.lat.values,
    mesh_out.triface_nodes[keep],
    mesh_out.depth,
)
plt.show()