Simulation of a 2-D spring lattice (drumhead)  
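Vertical component of the spring force between neighbouring masses $(i,j)$ and $(k,l)$: $F_{(i,j),(k,l)} = \frac{k\,\Delta z\,(r - r_0)}{r}$, with $\Delta z = z_{(i,j)} - z_{(k,l)}$ and $r = \sqrt{\Delta z^2 + d^2}$ (spring constant $k$, rest length $r_0$, in-plane spacing $d$).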

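$m_{(i,j)} \ddot z_{(i,j)} = F_{(i-1,j),(i,j)} + F_{(i,j-1),(i,j)} - F_{(i,j),(i+1,j)} - F_{(i,j),(i,j+1)}$ (each stretched spring pulls mass $(i,j)$ toward its neighbour)  
Fixed boundary on all four edges: $z_{(0,:)} = z_{(N-1,:)} = z_{(:,0)} = z_{(:,N-1)} = 0$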

### To-do:
1. Explore other boundary conditions
2. FFT to determine harmonics of many-mass system
3. State-space of trajectories
4. For many more masses, heat-map rather than manim animation to visualize wave propagation
5. Widget for parameter variation (for heatmap, animations take too long)
6. README
7. Chaos??
8. Examine linear approximations (pre-tensioned versus no tension)

In [7]:
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from manim import *
from dataclasses import dataclass

In [51]:
@dataclass
class Params():
    """Physical and numerical parameters for the drumhead simulation."""

    N: int = 7          # Masses per side; the lattice is N x N and its edge masses are fixed
    m: float = 0.2      # Mass of each mass (kg)
    k: float = 10.     # Spring constant (N/m)
    r_0: float = 0.1    # Spring equilibrium (rest) length (m)
    d: float = 0.5      # In-plane distance between adjacent masses (m)
    t_max: float = 10.0     # Total time to simulate (s)
    dt: float = 0.01    # Output sampling interval (s); solve_ivp chooses its own internal steps

class Drumhead():
    """N x N lattice of point masses joined to their four nearest neighbours by springs.

    Each mass moves only vertically (z). Neighbours are a distance d apart in the
    plane, so the spring between them has length r = sqrt(d**2 + dz**2) and exerts
    a vertical force k * (r - r_0) * dz / r, where dz = z_neighbour - z_self.

    Attributes (copied from Params):
        N   -- masses per side (the state vector has 2 * N**2 entries)
        m   -- mass of each mass
        k   -- spring constant
        r_0 -- spring equilibrium (rest) length
        d   -- in-plane distance between adjacent masses
    """

    def __init__(self, p: Params):
        self.N = p.N
        self.m = p.m
        self.k = p.k
        self.r_0 = p.r_0
        self.d = p.d

    def dy_dt(self, t, z_arr):
        """Right-hand side of the first-order ODE system, in solve_ivp's signature.

        Args:
            t: Time. Unused (the system is autonomous) but required by solve_ivp.
            z_arr: State vector [z, z_dot] of length 2 * N**2; each half is the
                row-major flattening of an (N, N) grid.

        Returns:
            The derivative [z_dot, z_ddot] in the same layout. Edge masses get zero
            acceleration (fixed boundary).
        """
        n = self.N
        z = z_arr[:n**2].reshape((n, n))
        z_dot = z_arr[n**2:]

        # Fixed boundary: the edge rows/columns keep z_ddot = 0.
        z_ddot = np.zeros((n, n))

        # Vectorised over all interior masses at once: each slice below is the
        # grid of one neighbour aligned with the interior block z[1:-1, 1:-1].
        centre = z[1:-1, 1:-1]
        neighbours = (
            z[:-2, 1:-1],  # up    (i-1, j)
            z[2:, 1:-1],   # down  (i+1, j)
            z[1:-1, :-2],  # left  (i, j-1)
            z[1:-1, 2:],   # right (i, j+1)
        )
        force = np.zeros_like(centre)
        for neighbour in neighbours:
            dz = neighbour - centre
            r = np.sqrt(self.d**2 + dz**2)
            force = force + self.k * (r - self.r_0) * dz / r
        z_ddot[1:-1, 1:-1] = force / self.m

        return np.concatenate((z_dot, z_ddot.ravel()))

    def solve_ode(self, t_pts, z_0, z_dot_0,
                  abserr=1.0e-10, relerr=1.0e-10):
        """Integrate the lattice from the given initial conditions with solve_ivp.

        Args:
            t_pts: Increasing 1-D array of output times. Integration runs from
                t_pts[0] to t_pts[-1], and the solution is sampled at t_pts.
            z_0: Initial displacements, an (N, N) grid (anything with N**2 elements).
            z_dot_0: Initial velocities, same layout as z_0.
            abserr, relerr: Passed to solve_ivp as atol / rtol. Smaller values give
                more precision at higher cost.

        Returns:
            Array of shape (2 * N**2, len(t_pts)). The first N**2 rows are the
            flattened displacements and the remaining rows the flattened velocities.

        Raises:
            RuntimeError: if solve_ivp reports failure. Otherwise its output would
                silently stop short of t_pts[-1].
        """
        z_arr = np.concatenate((np.ravel(z_0), np.ravel(z_dot_0)))
        solution = solve_ivp(self.dy_dt, (t_pts[0], t_pts[-1]),
                             z_arr, t_eval=t_pts,
                             atol=abserr, rtol=relerr)
        if not solution.success:
            raise RuntimeError(f"solve_ivp failed: {solution.message}")
        return solution.y

In [52]:
# Create the default parameter set and solve the drumhead problem
params = Params()

def solve_system(params):
    """Simulate a drumhead plucked at its centre.

    All masses start at rest with zero displacement, except the mass at grid
    index (N // 2, N // 2), which starts displaced by 1 (m).

    Args:
        params: Params instance (lattice size, physical constants, t_max, dt).

    Returns:
        (t_pts, solution): t_pts = np.arange(0, t_max, dt). solution is the
        (2 * N**2, len(t_pts)) array returned by Drumhead.solve_ode.

    Raises:
        ValueError: if N < 3. There are then no interior (free) masses, so the
            displaced "centre" mass would sit on the fixed boundary.
    """
    n = params.N
    if n < 3:
        raise ValueError(f"Need N >= 3 so the centre mass is not on the fixed boundary, got N={n}")
    drumhead = Drumhead(params)

    t_pts = np.arange(0, params.t_max, params.dt)
    z_0 = np.zeros((n, n))
    z_dot_0 = np.zeros((n, n))
    z_0[n // 2, n // 2] = 1  # Initial displacement at the center mass
    solution = drumhead.solve_ode(t_pts, z_0, z_dot_0)
    return t_pts, solution

t_pts, sol = solve_system(params)
# Displacements only (first N**2 rows), reshaped so z[i, j, k] is mass (i, j) at t_pts[k]
z = sol[:params.N**2, :]
z = z.reshape((params.N, params.N, -1))

# Plot results
# NOTE: this plotting code is disabled. It is wrapped in a bare string literal,
# so it is never executed. If re-enabled, it draws one stacked subplot per mass
# (N**2 panels in total).
# NOTE(review): the title uses only i ("Mass {i+1}"), so every mass in row i
# gets the same title. Also, for N == 1, plt.subplots(1, 1) returns a single
# Axes rather than an array, so ax[ind] would fail -- confirm before re-enabling.
'''
fig, ax = plt.subplots(params.N**2, 1, figsize=(10, 2*params.N**2))

for i in range(params.N):
    for j in range(params.N):
        ind = i * params.N + j
        ax[ind].plot(t_pts, z[i][j][:])
        ax[ind].set_title(f'Mass {i+1} Displacement Over Time')
        ax[ind].set_xlabel('Time')
        ax[ind].set_ylabel('Displacement')
'''

"\nfig, ax = plt.subplots(params.N**2, 1, figsize=(10, 2*params.N**2))\n\nfor i in range(params.N):\n    for j in range(params.N):\n        ind = i * params.N + j\n        ax[ind].plot(t_pts, z[i][j][:])\n        ax[ind].set_title(f'Mass {i+1} Displacement Over Time')\n        ax[ind].set_xlabel('Time')\n        ax[ind].set_ylabel('Displacement')\n"

In [53]:
# -------------------------------
# Simple spring helper (zig-zag)
# -------------------------------
def _coil_offset_direction(xhat):
    """Return a unit vector perpendicular to unit vector ``xhat`` for the zig-zag offsets."""
    up = np.array([0.0, 1.0, 0.0])
    right = np.array([1.0, 0.0, 0.0])

    # Work with the spring's projection onto the x-y plane so the zig-zag stays
    # flat in that plane however much the endpoints differ in z. A vector with
    # zero z-component that is perpendicular to this projection is also
    # perpendicular to xhat itself.
    horiz = np.array([xhat[0], xhat[1], 0.0])
    h_norm = np.linalg.norm(horiz)
    # A (nearly) vertical spring has no usable projection: use xhat directly.
    ref = horiz / h_norm if h_norm >= 1e-8 else xhat

    # Gram-Schmidt: prefer offsets toward +y; fall back to +x when ref is along y.
    yhat = up - np.dot(up, ref) * ref
    ny = np.linalg.norm(yhat)
    if ny < 1e-8:
        yhat = right - np.dot(right, ref) * ref
        ny = np.linalg.norm(yhat)
    return yhat / ny


def spring_polyline(start, end, coils=6, amplitude=0.25, inset=0.35):
    """
    Returns a VMobject shaped like a planar zig-zag spring from start -> end.

    The spring has a straight lead of length ``inset`` at each end. Between them
    is a body of ``2 * coils + 1`` vertices offset alternately by +amplitude and
    -amplitude from the centre line; the first and last body vertices lie on
    the line.

    The zig-zag lies flat in the scene's x-y plane (offsets toward +y, or +x for
    springs running along y), even when the endpoints differ in z. Springs in a
    3-D scene therefore keep their orientation as their ends move. A spring that
    is almost exactly vertical falls back to offsetting toward +y.

    Uses set_points_as_corners for a crisp zig-zag. Compatible with manim v0.19.
    Springs shorter than 1e-6 are drawn as a plain Line.
    """
    start = np.array(start, dtype=float)
    end = np.array(end, dtype=float)
    vec = end - start
    L = np.linalg.norm(vec)
    if L < 1e-6:
        return Line(start, end, stroke_width=6)

    # Local frame: xhat along the spring, yhat the in-plane offset direction
    xhat = vec / L
    yhat = _coil_offset_direction(xhat)

    # Straight end segments + zig-zag body
    Lz = max(L - 2 * inset, 0.0)
    n_verts = 2 * coils + 1
    xs = np.linspace(inset, inset + Lz, n_verts)

    ys = np.zeros_like(xs)
    ys[1::2] = amplitude
    ys[2::2] = -amplitude
    # End the body on the centre line so the trailing lead is straight
    ys[-1] = 0.0

    body = start + np.outer(xs, xhat) + np.outer(ys, yhat)
    pts = np.vstack([start, start + xhat * inset, body, end - xhat * inset, end])

    spring = VMobject()
    spring.set_points_as_corners(pts)
    spring.set_stroke(width=6)
    spring.set_fill(opacity=0)
    return spring

In [54]:
class Drumhead3D(ThreeDScene):
    """Animate the drumhead solution as spheres joined by zig-zag springs.

    Lattice index i is laid out along the scene's x axis and j along its y axis,
    over a 6-unit square. Each mass's displacement z_(i,j)(t) is drawn along the
    scene's z axis. A ValueTracker sweeps time from 0 to params.t_max in real
    time, and updaters drive every sphere and spring from it.
    """

    def construct(self):
        # Re-solve with the notebook-level `params` rather than reusing a
        # previously computed `sol`, so the animation always matches `params`.
        params_local = params
        n = params_local.N
        t_pts, sol = solve_system(params_local)
        z = sol[: n**2, :].reshape((n, n, -1))

        def z_of(t, i, j):
            # Linear interpolation of the sampled trajectory. np.interp clamps
            # t beyond the last sample to the last value.
            return np.interp(t, t_pts, z[i, j, :])

        # Grid index i -> scene x, j -> scene y; vertical displacement is scene z
        span = 6.0
        x_range = np.linspace(-span / 2, span / 2, n)
        y_range = np.linspace(-span / 2, span / 2, n)

        # Sphere size shrinks for dense lattices; spring geometry is scaled from it.
        # The same settings are used for the initial spring and every update.
        radius = min(0.12, span / (3.0 * n))
        spring_style = dict(coils=4, amplitude=radius * 0.9, inset=radius * 0.5)

        def make_spring(a, b):
            # Zig-zag spring between the current centres of spheres a and b
            return spring_polyline(a.get_center(), b.get_center(), **spring_style).set_color(WHITE)

        # Time tracker drives the animation
        t_tracker = ValueTracker(0.0)

        # Camera for 3D view
        self.set_camera_orientation(phi=65 * DEGREES, theta=-45 * DEGREES)

        # One sphere per lattice site; its z position follows the solution
        spheres = [[None] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                x, y = x_range[i], y_range[j]
                ball = (
                    Sphere(radius=radius)
                    .move_to([x, y, z_of(0.0, i, j)])
                    .set_color(BLUE)
                    .set_shade_in_3d(True)
                )

                def sph_updater(mob, i=i, j=j, x=x, y=y):
                    # Default arguments bind this lattice site (avoids late binding)
                    mob.move_to([x, y, z_of(t_tracker.get_value(), i, j)])

                ball.add_updater(sph_updater)
                spheres[i][j] = ball
                self.add(ball)

        # Springs to the (i, j+1) and (i+1, j) neighbours, so each pair is drawn once
        springs = []

        def add_spring(a, b):
            spr = make_spring(a, b)

            def spr_update(mob):
                mob.become(make_spring(a, b))

            spr.add_updater(spr_update)
            springs.append(spr)
            self.add(spr)

        for i in range(n):
            for j in range(n):
                if j + 1 < n:
                    add_spring(spheres[i][j], spheres[i][j + 1])
                if i + 1 < n:
                    add_spring(spheres[i][j], spheres[i + 1][j])

        # Optional time readout
        # NOTE(review): in a ThreeDScene, mobjects added with self.add are drawn
        # through the 3-D camera. A corner HUD is usually registered with
        # self.add_fixed_in_frame_mobjects -- confirm how this renders.
        time_readout = DecimalNumber(number=0.0, num_decimal_places=2, include_sign=False).set_font_size(24).to_corner(UR).shift(LEFT * 1.1 + DOWN * 1.2)
        time_label = Text("t (s) =", font_size=24).next_to(time_readout, LEFT, buff=0.2)

        def time_updater(mob):
            mob.set_value(t_tracker.get_value())

        time_readout.add_updater(time_updater)
        self.add(time_label, time_readout)

        # Animate: advance the tracker from 0 -> t_max in real time
        self.play(t_tracker.animate.set_value(params_local.t_max), run_time=params_local.t_max, rate_func=linear)
        self.wait(0.5)

In [55]:
%manim -pql Drumhead3D

  time_readout = DecimalNumber(number=0.0, num_decimal_places=2, include_sign=False).set_font_size(24).to_corner(UR).shift(LEFT * 1.1 + DOWN * 1.2)
                                                                                              