bouncing_ball_penalty.ipynb
------------------------
Bouncing ball in 1D (vertical) using a penalty method for contact with the ground
and a central difference (explicit) time integrator.

- Ball: mass m, radius r
- Ground: horizontal plane at u = 0
- Contact: penalty spring (stiffness k_n) and optional dashpot (c_n)
- Gravity acts downward (negative u direction)

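The scheme is the classical central difference method:
    u_{n+1} = 2 u_n - u_{n-1} + dt^2 * a_n
with a_n computed from forces at step n. Velocities are recovered as
    v_{n} ≈ (u_{n+1} - u_{n-1}) / (2 dt).
The code below implements it in the one-step predictor-corrector form
    u_{n+1} = u_n + dt * v_n + dt^2 / 2 * a_n,
    v_{n+1} = v_n + dt / 2 * (a_n + a_{n+1}),
which yields the same displacements when the forces depend on position only (c_n = 0).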

In [None]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict
import numpy as np
import plotly.graph_objects as go

In [None]:
@dataclass
class Ball:
    """Rigid ball restricted to vertical (1D) motion; upward is positive.

    Construction raises ``ValueError`` when the mass is not strictly
    positive or when the radius is negative. A radius of zero is accepted.
    """
    # Inertia and geometry
    m: float           # mass
    r: float           # radius
    # Initial state of the ball center
    u0: float          # initial center height
    v0: float = 0.0    # initial vertical velocity (up is +)
    a0: float = 0.0    # initial vertical acceleration (stored only; `run` does not read it)

    def __post_init__(self):
        # The mass is validated first, so an invalid mass is reported
        # even when the radius is invalid as well.
        if self.m <= 0:
            problem = "Mass must be positive."
        elif self.r < 0:
            problem = "Radius must be positive or zero."
        else:
            return
        raise ValueError(problem)


@dataclass
class PenaltyContact:
    """Penalty (spring + optional dashpot) contact against the ground plane.

    The ground is the horizontal plane at height ``u_ground``.
    Construction raises ``ValueError`` for a negative stiffness or a
    negative damping coefficient; zero is accepted for both.
    """
    k_n: float                 # normal contact stiffness [N/m]
    c_n: float = 0.0           # contact damping [N·s/m] (dashpot), optional
    u_ground: float = 0.0      # ground height

    def __post_init__(self):
        # Stiffness is validated before damping, so a bad stiffness wins
        # when both values are invalid.
        if self.k_n < 0:
            problem = "Penalty stiffness k_n must be positive or zero."
        elif self.c_n < 0:
            problem = "Damping c_n must be nonnegative."
        else:
            return
        raise ValueError(problem)


@dataclass
class SimParams:
    """Run controls for the time integration.

    Construction raises ``ValueError`` when ``dt`` is not strictly positive
    or when ``t_end`` is negative, mirroring the validation done by the
    other parameter dataclasses in this file. Without these checks a bad
    value only surfaces later, as a division error or an invalid array
    size while the time grid is built.
    """
    g: float                   # gravitational acceleration [m/s^2] (positive)
    dt: float                  # time step [s]
    t_end: float               # total time [s]
    record_energy: bool = True # store energies

    def __post_init__(self):
        # `not x > 0` (rather than `x <= 0`) also rejects NaN.
        if not self.dt > 0:
            raise ValueError("Time step dt must be positive.")
        if not self.t_end >= 0:
            raise ValueError("End time t_end must be nonnegative.")

In [None]:
def get_penetration(u: float, r: float, y_ground: float) -> float:
    """Return how far the ball surface lies below the ground (never negative).

    The lowest point of the ball is at ``u - r``; the overlap with the
    ground plane at ``y_ground`` is therefore ``y_ground + r - u``.
    A non-positive overlap means no contact and is reported as ``0.0``.
    """
    overlap = y_ground + r - u
    if overlap > 0.0:
        return overlap
    return 0.0

def contact_force(penetration: float, v_n: float, k_n: float, c_n: float) -> float:
    """
    Evaluate the penalty contact force (positive = pushing the ball upward).

    While the ball penetrates the ground the force is the sum of a linear
    spring and a linear dashpot:
        F_c = k_n * penetration - c_n * v_n
    Without penetration (penetration <= 0) the force is exactly zero.

    Note that only the no-penetration case is clamped: during rebound
    (v_n > 0) a large enough dashpot term makes the returned value negative.
    """
    spring_part = k_n * penetration
    dashpot_part = c_n * v_n
    return 0.0 if penetration <= 0.0 else spring_part - dashpot_part

def get_timestep(ball: Ball, contact: PenaltyContact, safety: float = 0.2) -> float:
    """
    Estimate a stable time step from the contact stiffness.

    The ball resting on the penalty spring is an oscillator with
        omega = sqrt(k_n / m)  =>  dt_crit ≈ 2 / omega
    and the returned step is
        dt = safety * dt_crit
    This is conservative during contact; out of contact the system is free fall.

    Parameters
    ----------
    ball : object providing the mass ``m``.
    contact : object providing the penalty stiffness ``k_n``.
    safety : factor applied to the critical step (default 0.2).

    Raises
    ------
    ValueError
        If ``k_n`` or ``safety`` is not strictly positive. A zero stiffness
        gives omega = 0 and therefore no finite critical step; previously
        this silently returned an infinite time step.
    """
    if not contact.k_n > 0:
        raise ValueError(
            "Penalty stiffness k_n must be positive to estimate a critical time step."
        )
    if not safety > 0:
        raise ValueError("Safety factor must be positive.")

    omega = np.sqrt(contact.k_n / ball.m)
    dt_crit = 2.0 / omega
    return safety * dt_crit


def run(ball: Ball, contact: PenaltyContact, sim: SimParams) -> Dict[str, np.ndarray]:
    """
    Simulate the bouncing ball using the penalty method and central differences.

    Every step predicts position and velocity from the current acceleration,
    evaluates the contact force at the predicted state, and then corrects
    the velocity with the newly computed acceleration.

    Parameters
    ----------
    ball : provides ``m``, ``r``, ``u0`` and ``v0``. ``ball.a0`` is not read;
        the initial acceleration follows from gravity and the contact force
        of the initial configuration.
    contact : provides ``k_n``, ``c_n`` and ``u_ground``.
    sim : provides ``g``, ``dt``, ``t_end`` and ``record_energy``.

    Returns
    -------
    dict of equally long arrays with the keys ``"t"``, ``"u"``, ``"v"``,
    ``"a"``, ``"contact_force"`` and ``"penetration"``. When
    ``sim.record_energy`` is true it additionally holds ``"K"``, ``"Ug"``,
    ``"Uc"``, ``"D_damp"`` and the sums ``"E_mech"``, ``"E_algo"``,
    ``"E_total_mech"`` and ``"E_total_algo"``.

    Raises
    ------
    ValueError
        If ``sim.dt`` is not strictly positive or ``sim.t_end`` is negative.
    """
    if not sim.dt > 0.0:
        raise ValueError("Time step dt must be positive.")
    if not sim.t_end >= 0.0:
        raise ValueError("End time t_end must be nonnegative.")

    # Loop-invariant parameters
    dt = sim.dt
    g = sim.g
    m = ball.m
    r = ball.r
    k_n = contact.k_n
    c_n = contact.c_n
    u_ground = contact.u_ground
    record_energy = sim.record_energy

    # Time grid (the last sample may lie slightly past t_end because of the ceil)
    n_steps = int(np.ceil(sim.t_end / dt)) + 1
    t = np.linspace(0.0, dt * (n_steps - 1), n_steps)

    # Allocate arrays
    u = np.zeros(n_steps)  # position of ball center
    v = np.zeros(n_steps)  # velocity (for output/diagnostics)
    a = np.zeros(n_steps)  # acceleration (for output/diagnostics)
    f_c = np.zeros(n_steps)  # contact force
    pen = np.zeros(n_steps)  # penetration

    # Energies (optional)
    K = np.zeros(n_steps)  # kinetic
    Ug = np.zeros(n_steps) # gravitational potential (set 0 at u=0)
    Uc = np.zeros(n_steps) # contact spring energy
    Dd = np.zeros(n_steps) # damping dissipation (accumulated)

    # Initial conditions at n=0. The contact state is evaluated here as well,
    # so a ball that starts inside the ground feels the penalty force from
    # the very first step; for a ball that starts above the ground the force
    # is zero and a[0] reduces to free fall (-g).
    u[0] = ball.u0
    v[0] = ball.v0
    pen[0] = get_penetration(u[0], r, u_ground)
    f_c[0] = contact_force(pen[0], v[0], k_n, c_n)
    a[0] = -g + f_c[0] / m

    if record_energy:
        K[0] = 0.5 * m * v[0] * v[0]
        Ug[0] = m * g * u[0]
        Uc[0] = 0.5 * k_n * pen[0] * pen[0]
        Dd[0] = 0.0

    # Main loop
    for n in range(0, n_steps - 1):

        # Predictor
        u_pred = u[n] + dt * v[n] + 0.5 * (dt ** 2) * a[n]
        v_pred = v[n] + dt * a[n]

        # Contact at predicted step
        penetration = get_penetration(u_pred, r, u_ground)
        fc = contact_force(penetration, v_pred, k_n, c_n)

        # Corrector: the predicted position is final; the velocity swaps the
        # explicit a[n] contribution for the average of a[n] and a[n+1].
        u[n+1] = u_pred
        a[n+1] = -g + fc / m
        v[n+1] = v_pred + 0.5 * dt * (a[n+1] - a[n])

        f_c[n+1] = fc  # for info only
        pen[n+1] = penetration  # for info only

        # Energy accounting
        if record_energy:
            K[n+1] = 0.5 * m * v[n+1] * v[n+1]
            Ug[n+1] = m * g * u[n+1]
            Uc[n+1] = 0.5 * k_n * penetration * penetration
            # Incremental Rayleigh-like damping dissipation in contact,
            # based on the predicted velocity that entered the dashpot force.
            if penetration > 0.0 and c_n > 0.0:
                Dd[n+1] = Dd[n] + c_n * (v_pred ** 2) * dt
            else:
                Dd[n+1] = Dd[n]

    out = {
        "t": t,
        "u": u,
        "v": v,
        "a": a,
        "contact_force": f_c,
        "penetration": pen,
    }
    if record_energy:
        out.update({
            "K": K,
            "Ug": Ug,
            "Uc": Uc,
            "D_damp": Dd,
            "E_mech": K + Ug + Uc,
            "E_algo": K + Ug + Uc - dt**2 / 8 * m * a**2,  # includes algorithmic energy
            "E_total_mech": K + Ug + Uc + Dd,  # decreases with damping
            "E_total_algo": K + Ug + Uc + Dd - dt**2 / 8 * m * a**2,
        })
    return out

def plot_state(time: np.ndarray, y: np.ndarray, pen: np.ndarray) -> None:
    """Display the ball-center height and the penetration as lines over time.

    Both series share one length axis; the figure is shown immediately and
    nothing is returned.
    """
    curves = (
        ("y (center)", y),
        ("penetration", pen),
    )
    figure = go.Figure()
    for label, values in curves:
        figure.add_trace(go.Scatter(x=time, y=values, mode="lines", name=label))

    layout_options = {
        "title": "Bouncing ball (penalty contact)",
        "xaxis_title": "time [s]",
        "yaxis_title": "length [m]",
        "legend": {"x": 0.02, "y": 0.98},  # legend anchored near the top-left corner
        "template": "plotly_white",
    }
    figure.update_layout(**layout_options)
    figure.show()

def plot_energy(time: np.ndarray, energies: dict) -> None:
    """Display one line per entry of ``energies`` over time.

    ``energies`` maps a legend label to the series plotted under that label;
    traces are added in the mapping's iteration order. The figure is shown
    immediately and nothing is returned.
    """
    figure = go.Figure()
    for label in energies:
        series = energies[label]
        figure.add_trace(go.Scatter(x=time, y=series, mode="lines", name=label))

    layout_options = {
        "title": "Energies",
        "xaxis_title": "time [s]",
        "yaxis_title": "energy [J]",
        "legend": {"x": 0.02, "y": 0.98},  # legend anchored near the top-left corner
        "template": "plotly_white",
    }
    figure.update_layout(**layout_options)
    figure.show()


In [None]:

# --- Example run (SI units) -------------------------------------------------
# A point-mass ball (r = 0) dropped from rest at 0.5 m onto an undamped
# penalty ground; the time step is 40 % of the contact-based critical step.
ball = Ball(m=0.5, r=0.0, u0=0.50, v0=0.0)
contact = PenaltyContact(k_n=1e6, c_n=0.0, u_ground=0.0)
dt = get_timestep(ball, contact, safety=0.4)
sim = SimParams(g=9.81, dt=dt, t_end=10, record_energy=True)

print(f"Using dt = {sim.dt:.6e} s")
result = run(ball, contact, sim)

# --- Summary of the trajectory ----------------------------------------------
heights = result["u"]
umax, umin = heights.max(), heights.min()
in_contact = result["penetration"] > 0.0
n_impacts_steps = int(np.count_nonzero(in_contact))
print(f"Peak center height: {umax:.4f} m, min center height: {umin:.4f} m")
print(f"Time steps in contact: {n_impacts_steps} / {len(result['t'])}")

# --- Plots --------------------------------------------------------------------
plot_state(result["t"], heights, result["penetration"])

# Individual energy contributions, labelled for the legend.
energies = {
    label: result[key]
    for label, key in (
        ("Kinetic", "K"),
        ("Gravitational potential", "Ug"),
        ("Contact potential", "Uc"),
        ("Damping", "D_damp"),
    )
}
plot_energy(result["t"], energies)

# Drift of the total energies relative to their initial values.
energies = {
    key: result[key] - result[key][0]
    for key in ("E_total_mech", "E_total_algo")
}
plot_energy(result["t"], energies)