# CS555 Project: Solving the Shallow Water Equations with the Finite Volume Method on an Unstructured Mesh

This project solves the shallow water equations with a cell-centered finite volume (FV) method on an unstructured triangular mesh.

A link to the Google Colab notebook with all demo animations:

https://colab.research.google.com/drive/1gbIPeLaul1bmvL2sxENO0wzs1RR6pjeM?usp=sharing

In [None]:
import numpy as np
import numpy.linalg as la

from scipy import sparse

from matplotlib import tri
from matplotlib import rcParams
import matplotlib.pyplot as plt
from  matplotlib import animation
from matplotlib.colors import Normalize
from matplotlib import cm
from tqdm import tqdm

from IPython.display import HTML

## FV on an unstructured mesh

Cell-centered implementation.

### Build the mesh

In [None]:
class TriMesh:
    """Triangular mesh plus the face/cell connectivity used by a cell-centered FV solver.

    Attributes set by ``__init__``:
        t         -- matplotlib ``Triangulation`` of the (optionally refined) mesh.
        cells     -- (ncell, 2) cell centroids (mean of the three vertices).
        areas     -- (ncell,) triangle areas.
        normals   -- (nface, 2) unit face normals oriented away from cell
                     ``fid2cid[f, 0]`` (towards ``fid2cid[f, 1]`` on interior faces,
                     outward on boundary faces).
        fid2cid   -- (nface, 2) ids of the cells sharing each face; a boundary
                     face stores its single cell id twice.
        face_len  -- (nface,) face lengths.
        bndf_mask -- (nface,) True for boundary faces (touched by only one cell).
        ncell, nface -- number of cells and faces.
    """

    def __init__(self, filename, filetype, refine=0):
        """Read a mesh, optionally refine it, and build the FV connectivity.

        Args:
            filename: mesh path without extension (see ``read_mesh``).
            filetype: file type passed to ``read_mesh`` ('.ply2' or '.').
            refine: ``subdiv`` level for matplotlib's ``UniformTriRefiner``;
                0 keeps the mesh as read.
        """
        V, E = self.read_mesh(filename, filetype)
        t = tri.Triangulation(V[:, 0], V[:, 1], E)

        if refine > 0:
            refiner = tri.UniformTriRefiner(t)
            t = refiner.refine_triangulation(subdiv=refine)
            V = np.vstack((t.x, t.y)).T
            E = t.triangles

        ncell = E.shape[0]
        cells = np.mean(V[E], axis=1)
        faces = t.edges
        nface = faces.shape[0]
        fid2cid = np.zeros((nface, 2), dtype=np.int32)
        # number of cells seen so far on each face: 1 -> boundary, 2 -> interior
        face_type = np.zeros(nface, dtype=np.int8)

        # map a sorted vertex pair to its face id
        vid_to_fid = {
            ((v0, v1) if v0 < v1 else (v1, v0)): fid
            for fid, (v0, v1) in enumerate(faces.tolist())
        }

        # link every face with the cell(s) that contain it
        for cid, (a, b, c) in enumerate(t.triangles.tolist()):
            for va, vb in ((a, b), (b, c), (c, a)):
                fid = vid_to_fid[(va, vb) if va < vb else (vb, va)]
                if face_type[fid] == 0:
                    # first owner: store it twice so boundary faces have c0 == c1
                    fid2cid[fid] = cid
                    face_type[fid] = 1
                else:
                    # second owner: the smaller cell id goes to column 0
                    if cid < fid2cid[fid, 0]:
                        fid2cid[fid, 0] = cid
                    else:
                        fid2cid[fid, 1] = cid
                    face_type[fid] = 2

        # triangle areas via the 2D cross product of two edge vectors
        P = V[t.triangles]
        d1 = P[:, 1] - P[:, 0]
        d2 = P[:, 2] - P[:, 0]
        areas = 0.5 * np.abs(d1[:, 0] * d2[:, 1] - d1[:, 1] * d2[:, 0])

        # face lengths and unit normals (tangent rotated by 90 degrees)
        v0 = V[faces[:, 0]]
        v1 = V[faces[:, 1]]
        tan = v0 - v1
        face_len = la.norm(tan, axis=1)
        n = np.column_stack((-tan[:, 1], tan[:, 0])) / face_len[:, None]
        # flip normals so they point from the c0 centroid towards the face midpoint
        to_mid = (v0 + v1) / 2 - cells[fid2cid[:, 0]]
        flip = np.einsum('ij,ij->i', n, to_mid) < 0
        normals = np.where(flip[:, None], -n, n)

        self.t = t
        self.cells = cells
        self.normals = normals
        self.fid2cid = fid2cid
        self.face_len = face_len
        self.areas = areas
        self.bndf_mask = (face_type == 1)
        self.ncell = ncell
        self.nface = nface

    def read_mesh(self, name, type='.'):
        """Read mesh vertices and triangles from disk.

        Args:
            name: path without extension; used only for '.ply2'.
            type: '.ply2' reads ``name + '.ply2'``. The file holds a vertex count,
                a face count, one coordinate line per vertex (only the first two
                columns are kept), and one line per face whose first column is
                dropped (presumably the vertex count, 3). '.' reads 'mesh.v' and
                'mesh.e' from the working directory and ignores ``name``.

        Returns:
            (vertices, triangles): float array (nv, 2) and int array (nt, 3).

        Raises:
            ValueError: if ``type`` is not a supported mesh file type.
        """
        if type == '.ply2':
            with open(name + type) as f:
                vnum = int(next(f))
                enum = int(next(f))
                v = [[float(x) for x in next(f).split()] for _ in range(vnum)]
                e = [[int(x) for x in next(f).split()] for _ in range(enum)]

            vertices = np.array(v)[:, :2]
            triangles = np.array(e, dtype=int)[:, 1:]

        elif type == '.':
            vertices = np.loadtxt('mesh.v')
            triangles = np.loadtxt('mesh.e', dtype=int)

        else:
            # raising a plain string is itself a TypeError in Python 3
            raise ValueError(f"Unknown mesh file type: {type!r}")

        return vertices, triangles

    def plot(self, show_one_edge=False):
        """Draw the mesh; optionally highlight face 0 with its two cells and its normal."""
        fig, ax = plt.subplots()
        ax.set_aspect(1)
        ax.triplot(self.t.x, self.t.y, self.t.triangles, color='gray')
        print("Total cell number in triangle mesh: ", self.ncell)
        if show_one_edge:
            fid = 0
            vtx = np.vstack((self.t.x, self.t.y)).T
            v0 = vtx[self.t.edges[fid, 0]]
            v1 = vtx[self.t.edges[fid, 1]]
            c0 = self.cells[self.fid2cid[fid, 0]]
            c1 = self.cells[self.fid2cid[fid, 1]]
            ax.scatter(c0[0], c0[1], color='green', label='c0')
            ax.scatter(c1[0], c1[1], color='blue', label='c1')
            ax.scatter(v0[0], v0[1], color='red', marker='.')
            ax.scatter(v1[0], v1[1], color='red', marker='.')
            ax.quiver(*((v0 + v1) / 2), self.normals[fid, 0], self.normals[fid, 1], headaxislength=2, headlength=3, width=0.003, label='normal')

        plt.show()


### Solve on the mesh

In [None]:
class MeshSolver:
    """Cell-centered FV solver for the 2D shallow water equations on a triangle mesh.

    The state ``u`` has shape (3, ncell); its rows are h, h*vx and h*vy.
    Each step is a forward-Euler update with a Rusanov (local Lax-Friedrichs)
    face flux. Boundary faces are reflective walls (see ``wall_bc``).
    """

    def __init__(self, mesh: "TriMesh", dt=5e-3, g = 1):
        self.mesh = mesh
        self.u = np.zeros((3, mesh.ncell))
        self.time = 0
        self.dt = dt
        self.g = g

    def init_height(self, func):
        """Set h in each cell to ``func(x, y)`` at the cell centroid; momentum is untouched."""
        cells = self.mesh.cells
        self.u[0] = func(cells[:, 0], cells[:, 1])

    def wall_bc(self):
        """Build the mirrored ghost state on boundary faces.

        The ghost state keeps the boundary cell's h and reflects its velocity
        about the face: v_ghost = v - 2 (v . n) n.

        Returns:
            F_bnd: (3, nbnd) normal flux of the ghost state.
            jmp_bnd: (3, nbnd) ghost state minus the boundary cell state.
        """
        g = self.g
        mask = self.mesh.bndf_mask
        ub = self.u[:, self.mesh.fid2cid[mask, 0]]
        hb = ub[0]
        vxb = ub[1] / hb
        vyb = ub[2] / hb

        normalsb = self.mesh.normals[mask]
        nx, ny = normalsb[:, 0], normalsb[:, 1]
        proj_len = vxb * nx + vyb * ny
        vx_inv = vxb - 2 * proj_len * nx
        vy_inv = vyb - 2 * proj_len * ny
        ub_inv = np.array([hb, hb*vx_inv, hb*vy_inv])
        Fxb_inv = np.array([hb*vx_inv, hb*vx_inv**2 + g/2 * hb**2, hb*vx_inv*vy_inv])
        Fyb_inv = np.array([hb*vy_inv, hb*vx_inv*vy_inv, hb*vy_inv**2 + g/2 * hb**2])

        F_bnd = Fxb_inv * nx + Fyb_inv * ny
        jmp_bnd = ub_inv - ub

        return F_bnd, jmp_bnd

    def solve_step(self):
        """Advance one forward-Euler step of size ``self.dt`` and return the new time."""
        dt = self.dt
        g = self.g
        mesh = self.mesh
        u = self.u
        h = u[0]
        vx = u[1] / h
        vy = u[2] / h

        Fx = np.array([h*vx, h*vx**2 + g/2 * h**2, h*vx*vy])
        Fy = np.array([h*vy, h*vx*vy, h*vy**2 + g/2 * h**2])

        c0 = mesh.fid2cid[:, 0]
        c1 = mesh.fid2cid[:, 1]
        nx = mesh.normals[:, 0]
        ny = mesh.normals[:, 1]
        l = mesh.face_len
        areas = mesh.areas
        mask = mesh.bndf_mask

        # normal fluxes evaluated from each side of every face
        F0 = Fx[:, c0] * nx + Fy[:, c0] * ny
        F1 = Fx[:, c1] * nx + Fy[:, c1] * ny

        # largest wave speed |v . n| + sqrt(g h) over both sides, shape (nface,)
        spd0 = np.abs(vx[c0] * nx + vy[c0] * ny) + np.sqrt(g * h[c0])
        spd1 = np.abs(vx[c1] * nx + vy[c1] * ny) + np.sqrt(g * h[c1])
        spd = np.maximum(spd0, spd1)

        jmp = u[:, c1] - u[:, c0]

        # on boundary faces the "c1" side is the mirrored ghost state
        F_bnd, jmp_bnd = self.wall_bc()
        F1[:, mask] = F_bnd
        jmp[:, mask] = jmp_bnd
        F = (F0 + F1) / 2 - spd * jmp / 2

        # F points from c0 to c1: it leaves c0 and enters c1. Boundary faces have
        # c1 == c0, so they get no c1 contribution.
        flux0 = -dt * F * l / areas[c0]
        flux1 = np.zeros_like(flux0)
        flux1[:, ~mask] = (dt * F * l / areas[c1])[:, ~mask]

        # scatter-add per-face contributions into cells (duplicate COO entries are summed)
        nface = mesh.nface
        row_idx = np.tile(np.repeat(np.arange(3), nface), 2)  # [0..0 1..1 2..2] twice
        col_idx = np.concatenate((np.tile(c0, 3), np.tile(c1, 3)))  # [c0 c0 c0 c1 c1 c1]
        delta = np.concatenate((flux0.ravel(), flux1.ravel()))

        # give the shape explicitly instead of inferring it from the largest cell index
        self.u += sparse.coo_matrix((delta, (row_idx, col_idx)), shape=self.u.shape).toarray()
        self.time += dt
        return self.time


## Reference: FV on a uniform grid

In [None]:
class UniformGrid:
    """Masked uniform Cartesian grid used as the reference FV discretization.

    Args:
        xrange, yrange: ``(lo, hi, n)`` tuples; ``[lo, hi)`` is split into ``n``
            cells and values are stored at cell centers.
        mask_func: callable ``(X, Y) -> bool array``; True marks active cells.

    Neighbor index arrays (``Kxp1``, ``Kxm1``, ``Kyp1``, ``Kym1``) come from
    ``np.roll``, so they wrap around at the grid edges.
    ``x_bnd_idx`` is +1 for an active cell whose right neighbor is inactive and
    -1 for an inactive cell whose right neighbor is active. ``y_bnd_idx`` is the
    same along y.
    """

    def __init__(self, xrange, yrange, mask_func):
        x_lo, x_hi, nx = xrange
        y_lo, y_hi, ny = yrange
        x, hx = np.linspace(x_lo, x_hi, nx, endpoint=False, retstep=True)
        y, hy = np.linspace(y_lo, y_hi, ny, endpoint=False, retstep=True)
        # shift by half a cell so values live at the cell centers
        x += hx / 2
        y += hy / 2

        Kx = np.arange(nx)
        Ky = np.arange(ny)
        Kxm1 = np.roll(Kx, 1)
        Kxp1 = np.roll(Kx, -1)
        Kym1 = np.roll(Ky, 1)
        Kyp1 = np.roll(Ky, -1)

        xx, yy = np.meshgrid(x, y)
        mask = mask_func(xx, yy)

        self.xx = xx
        self.yy = yy
        self.hx = hx
        self.hy = hy
        self.nx = nx
        self.ny = ny
        self.Kxp1 = Kxp1
        self.Kxm1 = Kxm1
        self.Kyp1 = Kyp1
        self.Kym1 = Kym1
        self.mask = mask
        self.x_bnd_idx = mask.astype(np.int32) - mask[:, Kxp1].astype(np.int32)
        self.y_bnd_idx = mask.astype(np.int32) - mask[Kyp1].astype(np.int32)

    def plot(self):
        """Report the number of active cells (nothing is drawn)."""
        print("Total active cell number in uniform grid: ", self.mask.sum())


In [None]:
class GridSolver:
    """Reference FV solver for the shallow water equations on a masked uniform grid.

    The state ``u`` has shape (3, ny, nx); its rows are h, h*vx and h*vy.
    Each step is a forward-Euler update with a Rusanov face flux. Only active
    (``grid.mask``) cells are updated, and faces between active and inactive
    cells act as reflective walls.
    """

    def __init__(self, grid: "UniformGrid", dt = 5e-3, g = 1):
        self.dt = dt
        self.grid = grid
        self.time = 0
        self.g = g
        self.u = np.zeros((3, grid.ny, grid.nx))

    def init_height(self, func):
        """Set h in every grid cell (active or not) to ``func(xx, yy)``."""
        self.u[0] = func(self.grid.xx, self.grid.yy)

    def solve_step(self):
        """Advance one forward-Euler step of size ``self.dt`` and return the new time.

        x-face (j, i) lies between cell (j, i) and its right neighbor
        (j, Kxp1[i]); y-faces are analogous with (Kyp1[j], i). Where
        ``x_bnd_idx == 1`` the face is the right wall of the active cell (j, i).
        Where it is -1 the face is the left wall of the active neighbor. The
        flux there uses a ghost state that mirrors the wet cell's normal velocity.
        """
        u = self.u
        g = self.g
        grid = self.grid

        Kxp1, Kxm1 = grid.Kxp1, grid.Kxm1
        Kyp1, Kym1 = grid.Kyp1, grid.Kym1
        mask = grid.mask

        wall_x_hi = grid.x_bnd_idx == 1   # wet cell on the left of the face
        wall_x_lo = grid.x_bnd_idx == -1  # wet cell on the right of the face
        wall_y_hi = grid.y_bnd_idx == 1
        wall_y_lo = grid.y_bnd_idx == -1

        h, mx, my = u[0], u[1], u[2]
        vx = mx / h
        vy = my / h

        # physical fluxes, and fluxes of the ghost state with mirrored normal velocity
        Fx = np.array([h*vx, h*vx**2 + g/2 * h**2, h*vx*vy])
        Fy = np.array([h*vy, h*vx*vy, h*vy**2 + g/2 * h**2])
        Fx_inv = np.array([-h*vx, h*vx**2 + g/2 * h**2, -h*vx*vy])
        Fy_inv = np.array([-h*vy, -h*vx*vy, h*vy**2 + g/2 * h**2])

        # central part of the numerical flux
        F_avgx = (Fx + Fx[:, :, Kxp1]) / 2
        F_avgy = (Fy + Fy[:, Kyp1, :]) / 2
        F_wallx = (Fx + Fx_inv) / 2
        F_wally = (Fy + Fy_inv) / 2
        F_avgx[:, wall_x_hi] = F_wallx[:, wall_x_hi]
        F_avgx[:, wall_x_lo] = F_wallx[:, :, Kxp1][:, wall_x_lo]
        F_avgy[:, wall_y_hi] = F_wally[:, wall_y_hi]
        F_avgy[:, wall_y_lo] = F_wally[:, Kyp1, :][:, wall_y_lo]

        # Rusanov wave speed; at walls only the wet cell contributes
        max_eigval_x = np.abs(vx) + np.sqrt(g * h)
        max_eigval_y = np.abs(vy) + np.sqrt(g * h)
        spd_x = np.maximum(max_eigval_x, max_eigval_x[:, Kxp1])
        spd_y = np.maximum(max_eigval_y, max_eigval_y[Kyp1, :])
        spd_x[wall_x_hi] = max_eigval_x[wall_x_hi]
        spd_x[wall_x_lo] = max_eigval_x[:, Kxp1][wall_x_lo]
        spd_y[wall_y_hi] = max_eigval_y[wall_y_hi]
        spd_y[wall_y_lo] = max_eigval_y[Kyp1, :][wall_y_lo]

        jmp_x = u[:, :, Kxp1] - u
        jmp_y = u[:, Kyp1, :] - u

        # Across a wall only the normal momentum jumps (it flips sign). Write exact
        # zeros for the other components so values stored in the inactive cell never
        # leak into the flux (e.g. the old ``h - h`` gave NaN for non-finite data).
        jmp_x[:, wall_x_hi] = 0
        jmp_x[1, wall_x_hi] = -2 * mx[wall_x_hi]
        jmp_x[:, wall_x_lo] = 0
        jmp_x[1, wall_x_lo] = 2 * mx[:, Kxp1][wall_x_lo]
        jmp_y[:, wall_y_hi] = 0
        jmp_y[2, wall_y_hi] = -2 * my[wall_y_hi]
        jmp_y[:, wall_y_lo] = 0
        jmp_y[2, wall_y_lo] = 2 * my[Kyp1, :][wall_y_lo]

        flux_x = F_avgx - spd_x * jmp_x / 2
        flux_y = F_avgy - spd_y * jmp_y / 2

        # conservative update of the active cells only
        u[:, mask] -= (self.dt/grid.hx * (flux_x - flux_x[:, :, Kxm1]) + self.dt/grid.hy * (flux_y - flux_y[:, Kym1]))[:, mask]
        self.time += self.dt
        return self.time


In [None]:
def simulate_and_plot(m_solver: MeshSolver, g_solver: GridSolver, target_time, dt, fps, speed_up, norm_h, norm_v):
    """Step the solver(s) to ``target_time`` and return an animation of the results.

    Args:
        m_solver: mesh solver, stepped once per iteration.
        g_solver: optional uniform-grid solver stepped in lockstep; pass ``None``
            to animate only the mesh solver.
        target_time: simulated end time; ``target_time / dt`` steps are taken.
        dt: step size used only to count steps. Each solver advances by its own
            ``self.dt``, so this presumably should match it.
        fps: animation frames per second.
        speed_up: simulated seconds shown per second of animation. A frame is
            captured each time the mesh solver's clock passes the next multiple
            of ``speed_up / fps``.
        norm_h, norm_v: color normalizations for water height and velocity magnitude.

    Returns:
        ``matplotlib.animation.ArtistAnimation``.
    """
    frames = []
    frame_timer = 0
    # small tolerance so round-off (e.g. 5 / 5e-3 -> 999.999...) cannot drop a step
    total_steps = int(target_time / dt + 1e-9)
    with_grid = g_solver is not None

    if with_grid:
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        ax_h_m, ax_v_m, ax_h_g, ax_v_g = axes.ravel()
        panels = [
            (ax_h_m, "Water height (mesh)", norm_h),
            (ax_v_m, "velocity magnitude (mesh)", norm_v),
            (ax_h_g, "Water height (uniform grid)", norm_h),
            (ax_v_g, "velocity magnitude (uniform grid)", norm_v),
        ]
    else:
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        ax_h_m, ax_v_m = axes
        panels = [
            (ax_h_m, "Water height (mesh)", norm_h),
            (ax_v_m, "velocity magnitude (mesh)", norm_v),
        ]

    for ax, title, norm in panels:
        ax.set_aspect(1)
        ax.set_title(title)
        fig.colorbar(cm.ScalarMappable(norm=norm), ax=ax, location='bottom')

    for _ in tqdm(range(total_steps)):
        if m_solver.time >= frame_timer:
            frame_timer += speed_up / fps
            vx = m_solver.u[1] / m_solver.u[0]
            vy = m_solver.u[2] / m_solver.u[0]
            frame = [
                ax_h_m.tripcolor(m_solver.mesh.t, m_solver.u[0], animated=True, norm=norm_h),
                ax_v_m.tripcolor(m_solver.mesh.t, np.sqrt(vx**2 + vy**2), animated=True, norm=norm_v),
            ]
            if with_grid:
                grid = g_solver.grid
                vx_g = g_solver.u[1] / g_solver.u[0]
                vy_g = g_solver.u[2] / g_solver.u[0]
                # inactive cells are drawn blank
                frame.append(ax_h_g.pcolormesh(grid.xx, grid.yy, np.where(grid.mask, g_solver.u[0], np.nan), animated=True, norm=norm_h))
                frame.append(ax_v_g.pcolormesh(grid.xx, grid.yy, np.where(grid.mask, np.sqrt(vx_g**2 + vy_g**2), np.nan), animated=True, norm=norm_v))
            frames.append(frame)

        m_solver.solve_step()
        if with_grid:
            g_solver.solve_step()

    print("start creating animation...")
    rcParams['animation.embed_limit'] = 64
    # close this figure so the notebook does not also render a static copy
    plt.close(fig)
    return animation.ArtistAnimation(fig, frames, interval=1e3/fps, blit=True)



## Numerical Examples

(Generating the output animations may be extremely slow.)

### Example 1: Circular domain

In [None]:
# !wget https://githubraw.com/ryan42210/FiniteVolumeOnMesh/main/assets/circle.ply2

In [None]:
# NOTE(review): the !wget cell above downloads into the working directory,
# while the meshes are read from dir_path; make the two agree before running.
dir_path = '../assets/'

# circle
def circle(X, Y):
    """Active-cell mask for the unit disk."""
    return X**2 + Y**2 < 1**2

# color ranges for water height and velocity magnitude
norm_h = Normalize(1, 1.3)
norm_v = Normalize(0, 0.3)
# height 3 inside radius 0.3, height 1 elsewhere
ic = lambda X, Y: np.where(X**2 + Y**2 < 0.3**2, 3, 1)
target_time = 5
dt = 5e-3
g = 1
fps = 25
speed_up = 1

tri_mesh = TriMesh(dir_path + 'circle', '.ply2')
tri_mesh.plot()
mesh_solver = MeshSolver(tri_mesh, dt, g)
mesh_solver.init_height(ic)

grid = UniformGrid((-1.05, 1.05, 35), (-1.05, 1.05, 35), circle)
grid.plot()
grid_solver = GridSolver(grid, dt, g)
grid_solver.init_height(ic)

ani = simulate_and_plot(mesh_solver, grid_solver, target_time, dt, fps, speed_up, norm_h, norm_v)
print("output animation...")
HTML(ani.to_jshtml())


### Example 2: Rectangular domain, axis-aligned or tilted

First we set up a rectangle with all four edges aligned with the coordinate axes.

Both solvers use a similar number of cells; all other parameters are the same.

In [None]:
# !wget https://githubraw.com/ryan42210/FiniteVolumeOnMesh/main/assets/horizontal.ply2
# !wget https://githubraw.com/ryan42210/FiniteVolumeOnMesh/main/assets/tilt.ply2

In [None]:
# horizontal
def horizontal(X, Y):
    """Active-cell mask for the rectangle (0, 5*sqrt(2)) x (0, sqrt(2))."""
    in_y = (Y > 0) & (Y < np.sqrt(2))
    in_x = (X < 5*np.sqrt(2)) & (X > 0)
    return in_y & in_x

norm_h = Normalize(1, 1.5)
norm_v = Normalize(0, 0.5)
# height 4 for x < sqrt(2)/2, height 1 elsewhere
ic = lambda X, Y: np.where(X < np.sqrt(2)/2, 4, 1)
target_time = 10
dt = 5e-3
g = 10
fps = 25
speed_up = 1

tri_mesh = TriMesh(dir_path + 'horizontal', '.ply2')
tri_mesh.plot()
mesh_solver = MeshSolver(tri_mesh, dt, g)
mesh_solver.init_height(ic)

grid = UniformGrid((-0.05, 5*np.sqrt(2)+0.05, 75), (0, np.sqrt(2), 15), horizontal)
grid.plot()
grid_solver = GridSolver(grid, dt, g)
grid_solver.init_height(ic)

ani = simulate_and_plot(mesh_solver, grid_solver, target_time, dt, fps, speed_up, norm_h, norm_v)
print("output animation...")
HTML(ani.to_jshtml())

Then we rotate the rectangle by 45 degrees and compare the two solvers again.

In [None]:
# tilt
def tilt(X, Y):
    """Active-cell mask for the rectangle |x + y| < 1, |x - y| < 5 (rotated 45 degrees)."""
    across = (X + Y > -1) & (X + Y < 1)
    along = (X - Y < 5) & (X - Y > -5)
    return across & along

norm_h = Normalize(1, 1.5)
norm_v = Normalize(0, 0.5)
# height 4 where x - y < -4, height 1 elsewhere
ic = lambda X, Y: np.where(X - Y < -4, 4, 1)
target_time = 10
dt = 5e-3
g = 10
fps = 25
speed_up = 1

tri_mesh = TriMesh(dir_path + 'tilt', '.ply2')
tri_mesh.plot()
mesh_solver = MeshSolver(tri_mesh, dt, g)
mesh_solver.init_height(ic)

grid = UniformGrid((-3.1, 3.1, 65), (-3.1, 3.1, 65), tilt)
grid.plot()
grid_solver = GridSolver(grid, dt, g)
grid_solver.init_height(ic)

ani = simulate_and_plot(mesh_solver, grid_solver, target_time, dt, fps, speed_up, norm_h, norm_v)
print("output animation...")
HTML(ani.to_jshtml())

Increasing the resolution: the mesh is refined once and the number of grid cells per direction is doubled.

In [None]:
def tilt(X, Y):
    """Active-cell mask for the rectangle |x + y| < 1, |x - y| < 5 (rotated 45 degrees)."""
    across = (X + Y > -1) & (X + Y < 1)
    along = (X - Y < 5) & (X - Y > -5)
    return across & along

norm_h = Normalize(1, 1.5)
norm_v = Normalize(0, 0.5)
# height 4 where x - y < -4, height 1 elsewhere
ic = lambda X, Y: np.where(X - Y < -4, 4, 1)
target_time = 8
dt = 1e-3  # smaller step for the finer discretizations
g = 10
fps = 25
speed_up = 1

tri_mesh = TriMesh(dir_path + 'tilt', '.ply2', refine=1)
tri_mesh.plot()
mesh_solver = MeshSolver(tri_mesh, dt, g)
mesh_solver.init_height(ic)

grid = UniformGrid((-3.1, 3.1, 130), (-3.1, 3.1, 130), tilt)
grid.plot()
grid_solver = GridSolver(grid, dt, g)
grid_solver.init_height(ic)

ani = simulate_and_plot(mesh_solver, grid_solver, target_time, dt, fps, speed_up, norm_h, norm_v)
print("output animation...")
HTML(ani.to_jshtml())

### Example 3: Narrowing channel

In [None]:
# !wget https://githubraw.com/ryan42210/FiniteVolumeOnMesh/main/assets/narrowing_channel.ply2

In [None]:
# narrowing channel
# Step initial condition: height 5 for x < -4.5, height 1 elsewhere (mesh solver only).
ic = lambda X, Y: np.where(X < -4.5, 5, 1)
norm_h = Normalize(1, 3)
norm_v = Normalize(0, 1)

target_time, dt, g = 15, 5e-3, 2
fps, speed_up = 25, 1

tri_mesh = TriMesh(dir_path + 'narrowing_channel', '.ply2')
tri_mesh.plot()

tri_solver = MeshSolver(tri_mesh, dt=dt, g=g)
tri_solver.init_height(ic)

ani = simulate_and_plot(
    tri_solver, None, target_time, dt, fps, speed_up, norm_h, norm_v
)
print("output animation...")
HTML(ani.to_jshtml())

### Example 4: Breakwater

In [None]:
# !wget https://githubraw.com/ryan42210/FiniteVolumeOnMesh/main/assets/breakwater.ply2

In [None]:
# breakwater
# Step initial condition: height 5 for x < -3.5, height 1 elsewhere (mesh solver only).
ic = lambda X, Y: np.where(X < -3.5, 5, 1)
norm_h = Normalize(1, 2)
norm_v = Normalize(0, 1)

target_time, dt, g = 10, 5e-3, 10
fps, speed_up = 25, 1

tri_mesh = TriMesh(dir_path + 'breakwater', '.ply2')
tri_mesh.plot()

tri_solver = MeshSolver(tri_mesh, dt=dt, g=g)
tri_solver.init_height(ic)

ani = simulate_and_plot(
    tri_solver, None, target_time, dt, fps, speed_up, norm_h, norm_v
)
print("output animation...")
HTML(ani.to_jshtml())