#### Jupyter notebook example: a hexagonal stencil for the finite difference method

An example of the finite difference method on a hexagonal mesh is given for the simple 2D heat equation for $u(t,x,y)$ defined on the domain $I \times \Omega = \{(t,x,y) : 0 < t,\; 0 < x < 1,\; 0 < y < 1\},$
$$\frac{\partial u}{\partial t} = \mu \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right), $$
with Dirichlet boundary conditions.

The function 
$$u(t,x,y) = \cos(2 \pi x) \cos(2 \pi y) e^{-8\mu t \pi^2},$$ 
which satisfies the partial differential equation, is used as a test problem below: it supplies the initial condition and the Dirichlet boundary values at every time step, so the numerical solutions can be compared directly against it.

This notebook requires numpy, scipy, matplotlib and ipympl (the interactive backend selected by `%matplotlib ipympl`).
If you do not have these already, consider running `%pip install numpy scipy matplotlib ipympl` in a Jupyter code cell. Versions known to work are numpy (2.1.2), scipy (1.14.1), matplotlib (3.9.2).

In [None]:
%matplotlib ipympl

import numpy as np
import matplotlib.pyplot as plt
import scipy, scipy.sparse, scipy.sparse.linalg, matplotlib, IPython
from mpl_toolkits.mplot3d import Axes3D, art3d
from matplotlib.animation import FuncAnimation, PillowWriter

matplotlib.interactive(True) # Often needed for generating animations

In [None]:
N = 24  # number of grid nodes along x and along y (N*N nodes per mesh); must be even and at least 6
assert N%2 == 0 and N > 4, "Life is easier when you pick an even number greater than 4 for N"

In [None]:
# hex grid
# Builds the N x N nodes of a honeycomb (hexagonal) mesh. Node (row j, col k) ends up at
# flat index j*N + k (row-major) in the 1d arrays hex_x / hex_y.
hhat = 1.0 / N  # raw hexagon edge length, before the final normalisation
off_x, short_x, long_x = 0.5*hhat, hhat, 2.0*hhat  # odd-row start offset, short and long x gaps
off_y = hhat*np.sqrt(3.0) / 2.0  # vertical distance between consecutive rows

# initoff[j] = 1 for odd rows (which start off_x to the right), 0 for even rows;
# resized in place to an (N, 1) column so it can be stacked in front of the gap matrix.
jdx = np.arange(N)
initoff = np.where(jdx % 2 == 1, 1, 0)
initoff.resize((N,1))

# shortoff[j, k] = 1 where the x gap between columns k and k+1 of row j is short:
# even rows alternate long, short, long, ...; odd rows alternate short, long, short, ...
idx = np.arange(N-1)
idx, jdx = np.meshgrid(idx,jdx)
shortoff = np.where((idx+1) % 2 == jdx % 2, 1, 0)

# x coordinate of each node = row start offset + running sum of the gaps to its left.
# In these raw units every hexagon is regular with edge hhat (short gaps are hexagon edges,
# long gaps span a hexagon's interior).
hex_x = short_x * shortoff + long_x * (1-shortoff)
hex_x = np.hstack((initoff*off_x, hex_x)).cumsum(axis=1)

# y coordinate depends only on the row index.
idx, jdx = np.arange(N), np.arange(N)
idx, jdx = np.meshgrid(idx,jdx)
hex_y = off_y * jdx

# Flatten both (N, N) coordinate arrays in place to shape (N*N, 1), row-major.
hex_y.resize(N*N, 1)
hex_x.resize(N*N, 1)

# Collapse the (N*N, 1) columns to 1d arrays and scale each coordinate to [0, 1].
# x and y are divided by different maxima (raw x extent (1.5*N - 1)*hhat, raw y extent
# (N - 1)*off_y), so after this step the hexagons are stretched vertically and are
# no longer regular.
hex_x = np.hstack(hex_x) / hex_x.max()
hex_y = np.hstack(hex_y) / hex_y.max()

In [None]:
# connectivity (and for plotting)
# hex_poly_tmp: one row per hexagon in the lowest band (rows 0-2), listing its six node indices
# in order around the hexagon: (0,c), (1,c), (2,c), (2,c+1), (1,c+1), (0,c+1) for
# c = 1, 3, ..., N-1, where (j,k) denotes node j*N + k.
hex_poly_tmp = np.vstack((1+np.arange(N-1)[::2], 1+N+np.arange(N-1)[::2], 1+2*N+np.arange(N-1)[::2], 
                     2+2*N+np.arange(N-1)[::2], 2+N+np.arange(N-1)[::2], 2+np.arange(N-1)[::2],    )).T
# Band j contributes two staggered rows of hexagons:
#  - hex_poly_tmp shifted up 2*j rows, without its last entry (for c = N-1, column c+1 = N
#    would spill over into the next row);
#  - hex_poly_tmp shifted up 2*j+1 rows and one column left (+ (2*j+1)*N - 1).
hex_polys = np.vstack([np.vstack(( hex_poly_tmp[:-1] + 2*j*N,  hex_poly_tmp +  (2*j+1)*N - 1 ) ) for j in range( (N>>1) - 1 ) ])

# Dirichlet boundary nodes of the hex mesh.
# Left / right: column 0 and column N-1 of every row.
hex_bcs_left   = np.arange(N)*N
hex_bcs_right  = np.arange(N)*N + N-1
# Bottom: zigzag through rows 0 and 1. After a leading node N+1 = (1,1), each c = 1, 3, ..., N-3
# appends (0,c), (0,c+1), (1,c+1), (1,c+2); the trailing [:-1] drops (1,N-1), which is
# already in hex_bcs_right.
hex_bcs_bottom = np.hstack((np.array([N+1]), np.hstack( np.vstack((np.arange(1,N-1)[::2], 1+np.arange(1,N-1)[::2], 1+N+np.arange(1,N-1)[::2], 2+N+np.arange(1,N-1)[::2])).T)))[:-1]
# Top: the same zigzag through rows N-1 and N-2; each c = 1, 3, ..., N-3 appends
# (N-1,c), (N-2,c), (N-2,c+1), (N-1,c+1).
hex_bcs_top    = np.hstack(np.vstack( (N*(N-1)+ np.arange(1,N-1)[::2], -N + N*(N-1)+ np.arange(1,N-1)[::2], 1 -N + N*(N-1)+ np.arange(1,N-1)[::2], 1 + N*(N-1)+ np.arange(1,N-1)[::2] ) ).T)
# All boundary nodes, concatenated as bottom, right, top (reversed), left (reversed).
hex_bcs_all    = np.hstack((hex_bcs_bottom,hex_bcs_right,hex_bcs_top[::-1], hex_bcs_left[::-1]))

In [None]:
# used to generate the two reflective stencils used in the hexagonal mesh
# Every interior honeycomb node has three neighbours. "r" stencils are for nodes whose
# horizontal neighbour lies to the right, "l" stencils for nodes whose horizontal neighbour
# lies to the left. Each row holds four node indices:
#   hex_stencil_r: [lower-left, CENTRE, right, upper-left]
#   hex_stencil_l: [lower-right, left, CENTRE, upper-right]
hex_tmp_r = np.vstack(( 1+N+np.arange(N-1)[::2], 1+2*N+np.arange(N-1)[::2], 
                     2+2*N+np.arange(N-1)[::2], 3*N+1+np.arange(N-1)[::2],  )).T
hex_stencil_r = np.vstack([np.vstack(( hex_tmp_r[:-1] + 2*j*N,  hex_tmp_r[1:] +  (2*j+1)*N - 1 ) ) for j in range( (N>>1) - 2 ) ])

hex_tmp_l = np.vstack(( N+np.arange(N-1)[::2], -1+2*N+np.arange(N-1)[::2], 
                     2*N+np.arange(N-1)[::2], 3*N+np.arange(N-1)[::2] ,  )).T
hex_stencil_l = np.vstack([np.vstack(( hex_tmp_l[1:] + 2*j*N,  hex_tmp_l[1:] +  (2*j+1)*N - 1 ) ) for j in range( (N>>1) - 2 ) ])

# Centre node of every stencil, in stencil order (all r stencils, then all l stencils).
# The centre sits in column 1 of hex_stencil_r and column 2 of hex_stencil_l. This matches the
# position of the diagonal (-4*mu/h**2 - 1/dt) weight in the assembled system, so the
# backward-Euler right-hand side -u_old/dt uses the old value at the same node.
hex_centr = np.hstack((hex_stencil_r[:,1],hex_stencil_l[:,2]))

# Regression check: a stencil centre is an unknown, never a Dirichlet boundary node.
assert not np.isin(hex_centr, hex_bcs_all).any(), "hex stencil centre lies on the boundary"

In [None]:
# Rectangular grid: N x N nodes on [0, 1]^2, numbered row-major (node = row*N + col,
# with rows running along y and columns along x).
grid_1d = np.linspace(0, 1, N)
sX, sY = np.meshgrid(grid_1d, grid_1d)
rect_x, rect_y = sX.flatten(), sY.flatten()

# Interior nodes (one 5-point stencil each) in row-major order, and their stencils with
# columns [below, left, centre, right, above].
interior = np.arange(1, N - 1)
col_in, row_in = np.meshgrid(interior, interior)
rect_centr = (row_in * N + col_in).flatten()
rect_stencil = rect_centr[:, np.newaxis] + np.array([-N, -1, 0, 1, N])

# Boundary nodes. The left/right sides skip the corners, which bottom/top already include.
rect_bcs_left   = interior * N
rect_bcs_right  = interior * N + (N - 1)
rect_bcs_bottom = np.arange(N)
rect_bcs_top    = N * (N - 1) + np.arange(N)
rect_bcs_all    = np.concatenate((rect_bcs_bottom, rect_bcs_right, rect_bcs_top[::-1], rect_bcs_left[::-1]))

# Quadrilateral cells (plotting only): lower-left, lower-right, upper-right, upper-left
# node of each cell, with the column index as the slow-varying index.
cell = np.arange(N - 1)
cell_i, cell_j = np.meshgrid(cell, cell, indexing='ij')
corner = (cell_j * N + cell_i).flatten()
rect_polys = np.column_stack((corner, corner + 1, corner + N + 1, corner + N))

In [None]:
N2 = N*N
# Node spacing of the rectangular grid: np.linspace(0, 1, N) places N nodes on [0, 1],
# i.e. N-1 intervals of width 1/(N-1). The 5-point Laplacian must use this spacing;
# using 1/N instead overstates the diffusion rate by a factor (N/(N-1))**2.
rect_h = 1.0 / (N - 1)
# Edge length assumed by the hexagonal stencil.
# NOTE(review): the hex node coordinates are normalised to [0, 1] separately in x and y, so the
# hexagons are stretched and their actual edge lengths differ from 1/N. The weights below treat
# the mesh as a regular honeycomb with edge h. TODO: derive the weights from the node coordinates.
h = 1.0 / N

# Simulation parameters
dt = 5e-2   # time step
mu = 1e-2   # diffusion coefficient

# Exact solution: provides the initial condition and the Dirichlet boundary data
#exact_u = lambda t,x,y:  2*np.cos(np.pi*4*x)*np.exp(-mu*16*np.pi**2*t) # example 1
exact_u = lambda t,x,y:  np.cos(np.pi*2*x) *np.cos(np.pi*2*y) * np.exp(-8*mu*t*np.pi**2)  # example 2

# Dirichlet BCs: system row r (r < number of boundary nodes) reads u[bcs_all[r]] = g
rect_rows  = np.arange(len(rect_bcs_all))
rect_I_bc  = scipy.sparse.csr_matrix((np.ones_like(rect_rows), (rect_rows, rect_bcs_all)), shape=(N2,N2))

hex_rows  = np.arange(len(hex_bcs_all))
hex_I_bc  = scipy.sparse.csr_matrix((np.ones_like(hex_rows), (hex_rows, hex_bcs_all)), shape=(N2,N2))

# Effective mass + stiffness matrices for backward Euler. Each interior row, placed after the
# boundary rows, reads   mu * Laplacian_h(u_new) - u_new/dt = -u_old/dt
rect_rows_bulk_start_idx = 1 + rect_rows[-1]
rect_rows = rect_rows_bulk_start_idx + np.repeat(np.arange(rect_centr.shape[0]), 5)
rect_cols = np.hstack(rect_stencil)
# Weights for the stencil columns [below, left, centre, right, above]
rect_elmt = np.tile(np.array([mu/rect_h**2, mu/rect_h**2, -1/dt -mu*4/rect_h**2, mu/rect_h**2, mu/rect_h**2]), (N-2)*(N-2))
rect_A = scipy.sparse.csr_matrix( (rect_elmt, (rect_rows, rect_cols)), shape=(N2,N2) )

# Honeycomb Laplacian with three neighbours at distance h:
#   Laplacian(u) ~ 4/(3 h**2) * (sum of neighbours - 3 u_centre)
hex_rows_bulk_start_idx = 1 + hex_rows[-1]
hex_rows = hex_rows_bulk_start_idx + np.repeat(np.arange(hex_stencil_r.shape[0]), 4)
hex_cols_r = np.hstack(hex_stencil_r)
# "r" stencils: the centre node is column 1
hex_elmt_r = np.tile(np.array([mu*4/(3*h**2), -mu*4/(h**2) -1/dt, mu*4/(3*h**2),  mu*4/(3*h**2)]), hex_stencil_r.shape[0])
hex_A_r = scipy.sparse.csr_matrix( (hex_elmt_r, (hex_rows, hex_cols_r)), shape=(N2,N2) )

# "l" stencil rows follow the "r" rows; their centre node is column 2
hex_rows = 1 + hex_rows[-1] + np.repeat(np.arange(hex_stencil_l.shape[0]), 4)
hex_cols_l = np.hstack(hex_stencil_l)
hex_elmt_l = np.tile(np.array([mu*4/(3*h**2), mu*4/(3*h**2), -mu*4/(h**2) -1/dt,  mu*4/(3*h**2)]), hex_stencil_l.shape[0])
hex_A_l = scipy.sparse.csr_matrix( (hex_elmt_l, (hex_rows, hex_cols_l)), shape=(N2,N2) )

# Sanity check: boundary rows + stencil rows must fill the square N2 x N2 system exactly,
# otherwise the matrix has empty rows and is singular.
assert rect_rows_bulk_start_idx + rect_centr.shape[0] == N2
assert hex_rows_bulk_start_idx + hex_centr.shape[0] == N2

# Right hand side: boundary rows hold Dirichlet values, interior rows hold -u_old/dt
rect_rhs  = np.zeros(N2)
rect_rhs[np.arange(len(rect_bcs_all))] = exact_u(0, rect_x[rect_bcs_all], rect_y[rect_bcs_all])
rect_bulk_rhs_rows  = rect_rows_bulk_start_idx + np.arange(rect_centr.shape[0])
#rect_rhs[rect_bulk_rhs_rows] = -rect_u[ rect_centr ]/dt -- here for reference but not used until stepping+plotting

hex_rhs = np.zeros(N2)
hex_rhs[np.arange(len(hex_bcs_all))] = exact_u(0, hex_x[hex_bcs_all], hex_y[hex_bcs_all])
hex_bc_rows  = np.arange(len(hex_bcs_all))
hex_bulk_rhs_rows  = hex_rows_bulk_start_idx + np.arange(hex_centr.shape[0])
#hex_rhs[hex_bulk_rhs_rows] = -hex_u[ hex_centr ]/dt -- here for reference but not used until stepping+plotting

# Precomputed full left hand system
rect_M = rect_I_bc + rect_A
hex_M  = hex_I_bc + hex_A_r + hex_A_l


### The hard part is over; what follows is plotting code

In [None]:
# Dark mode: black axes background with white title and axis-label text; titles sit just
# below the axes.
plt.rcParams.update({
    'axes.facecolor': '#000',
    'axes.titley': -0.05,
    'axes.labelcolor': '#fff',
    'axes.titlecolor': '#fff',
})

In [None]:
time = 0


def _draw_mesh_surface(ax, x, y, z, polys, title):
    """Draw the surface z(x, y) on a 3D axis as a white-edged wireframe of mesh cells.

    Parameters
    ----------
    ax : 3D matplotlib axis to draw on.
    x, y, z : 1d arrays of node coordinates / values, indexed by node number.
    polys : integer array with one row of node numbers per mesh cell.
    title : axis title.
    """
    verts = np.column_stack((x, y, z))
    ax.add_collection3d(art3d.Poly3DCollection(verts[polys], facecolors='#333', alpha=1.0, edgecolors='white', linewidth=0.5))
    ax.set_title(title)
    ax.view_init(65, 200)
    ax.axis('off')
    # Hide the pane backgrounds of all three axes of *this* subplot.
    ax.xaxis.pane.fill = ax.yaxis.pane.fill = ax.zaxis.pane.fill = False
    ax.set_zlim([-1, 1])
    ax.grid(False)


def drawstep(ax1, ax2, ax3):
    """Draw the current state, then advance both solvers by one backward-Euler step.

    ax1 shows the rectangular-mesh solution, ax2 the hexagonal-mesh solution and ax3 the
    exact solution at the current `time`, sampled on the hexagonal mesh. Afterwards the
    global `rect_u`, `hex_u` and `time` are advanced by `dt`.
    """
    global hex_u, rect_u, time

    plt.tight_layout(pad=-1)
    _draw_mesh_surface(ax1, rect_x, rect_y, rect_u, rect_polys, 'Rectangular mesh')
    _draw_mesh_surface(ax2, hex_x, hex_y, hex_u, hex_polys, 'Hexagonal mesh')
    _draw_mesh_surface(ax3, hex_x, hex_y, exact_u(time, hex_x, hex_y), hex_polys, 'Exact solution')

    # Backward Euler solves for u at time + dt, so the Dirichlet data must be taken at that
    # time as well (not at the old time).
    t_new = time + dt
    rect_rhs[:len(rect_bcs_all)] = exact_u(t_new, rect_x[rect_bcs_all], rect_y[rect_bcs_all])
    hex_rhs[:len(hex_bcs_all)] = exact_u(t_new, hex_x[hex_bcs_all], hex_y[hex_bcs_all])

    # Interior rows: -u_old/dt at each stencil's centre node.
    rect_rhs[rect_bulk_rhs_rows] = -rect_u[rect_centr] / dt
    hex_rhs[hex_bulk_rhs_rows] = -hex_u[hex_centr] / dt

    rect_u = scipy.sparse.linalg.spsolve(rect_M, rect_rhs)
    hex_u = scipy.sparse.linalg.spsolve(hex_M, hex_rhs)

    time = t_new


In [None]:
# Animation (may take a while to generate)
# Reset the solver state so the animation always starts from the initial condition.
time = 0
hex_u  = exact_u(time, hex_x, hex_y)
rect_u = exact_u(time, rect_x, rect_y)

fig = plt.figure(figsize=(10,5), dpi=150)

plt.rcParams['axes.facecolor'] = '#000'
fig.patch.set_facecolor('#000')
ax1 = fig.add_subplot(131, projection='3d')
ax2 = fig.add_subplot(132, projection='3d')
ax3 = fig.add_subplot(133, projection='3d')


def init_frame():
    """No-op init function for FuncAnimation.

    Without an init_func, FuncAnimation draws its initial frame by calling animate() with
    the first frame. Because drawstep() advances the solver, that would step the solution
    once before frame 0 is recorded.
    """
    return []


def animate(frame):
    """Redraw all three panels for one frame; drawstep() also advances the solution by dt."""
    for ax in (ax1, ax2, ax3):
        ax.clear()
    drawstep(ax1, ax2, ax3)
    return [ax1, ax2, ax3]


gifname='hex_fdm_example.gif'
ani = FuncAnimation(fig, animate, frames=np.arange(50), init_func=init_frame, blit=False, interval=100, repeat=False)
ani.save(gifname, writer=PillowWriter(fps=8))

Your browser may cache the GIF file, so you may need to save it under a new name (change `gifname`) to see a new rendering right away.

In [3]:
# Display the saved animation. The file is referenced by name, so the browser may show a cached copy.
IPython.display.HTML(f'<img src="{gifname}" width="100%" />')