bouncing_ball_penalty.ipynb
------------------------
Bouncing ball in 1D (vertical) using a penalty method for contact with the ground
and a central difference (explicit) time integrator.

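- Ball: mass m, radius r
- Ground: horizontal plane at u = 0, where u is the gap between the bottom of the ball and the ground (u < 0 means penetration)
- Contact: penalty spring (stiffness $k$) and optional dashpot (damping $c$)
- Gravity acts downward (negative u direction; upward is positive)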

The scheme is the classical central difference method, written in its equivalent predictor/corrector (velocity-Verlet) form:

$$
\begin{cases}
    ma_{n+1} = -mg - f^\text{c}_{n+1} \quad \text{ with }
    f^\text{c}_{n+1}=\begin{cases}
        ku^*_{n+1} + cv^*_{n+1} &\text{if } u^*_{n+1} \le 0\\
        0 & \text{otherwise}
    \end{cases}\\
    u^*_{n+1} = u_n+\Delta t v_n+\frac{1}{2}\Delta t^2 a_n\\
    v^*_{n+1} = v_n + \Delta t a_n\\
    u_{n+1} = u^*_{n+1}\\
    v_{n+1} = v^*_{n+1} + \frac{1}{2}\Delta t\,(a_{n+1}-a_{n})
\end{cases}
$$

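Terms with the asterisk $^*$ are predictions. Note that the code stores the contact force with the opposite sign, `fc` $= -f^\text{c}_{n+1}$. It is therefore positive when the ground pushes the ball upward, and the acceleration reads `a = -g + fc/m`.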

Algorithm for each time step:
- predict displacement and velocity $u^*_{n+1}$, $v^*_{n+1}$
- compute the contact force $f^\text{c}_{n+1}$ from the predictions
- update the acceleration $a_{n+1} = -g - f^\text{c}_{n+1}/m$
- update displacement and velocity $u_{n+1}$, $v_{n+1}$

In [None]:
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import plotly.graph_objects as go

In [None]:
@dataclass
class Params:
    """
    Physical and numerical parameters of the bouncing-ball simulation.

    Raises:
        ValueError: if ``m`` or ``dt`` is not strictly positive, or if ``k``
            or ``c`` is negative (or if any of them is NaN).
    """
    m: float = 1.0              # mass [kg]
    r: float = 0.1              # radius [m] (not read by the 1D simulation code in this notebook)
    g_const: float = 9.81       # gravity constant [m/s^2]
    k: float = 1e4              # normal contact (penalty) stiffness [N/m]
    c: float = 0.0              # contact damping [N·s/m]
    dt: float = 1e-3            # time step [s]
    t_end: float = 2.0          # end time [s]

    def __post_init__(self) -> None:
        # Fail early with a clear message instead of a ZeroDivisionError (m, dt)
        # or an energy-injecting, unphysical contact model (negative k or c).
        # Writing `not x > 0` rather than `x <= 0` also rejects NaN.
        if not self.m > 0.0:
            raise ValueError(f"mass m must be > 0, got {self.m!r}")
        if not self.dt > 0.0:
            raise ValueError(f"time step dt must be > 0, got {self.dt!r}")
        if not self.k >= 0.0:
            raise ValueError(f"contact stiffness k must be >= 0, got {self.k!r}")
        if not self.c >= 0.0:
            raise ValueError(f"contact damping c must be >= 0, got {self.c!r}")

@dataclass
class State:
    """Instantaneous state of the ball; all vertical quantities are upward positive."""

    # Gap between the bottom of the ball and the ground [m]; <= 0 means contact.
    u: float
    # Vertical velocity [m/s].
    v: float
    # Vertical acceleration [m/s^2].
    a: float
    # Contact force acting on the ball [N].
    fc: float
    # Simulation time [s].
    t: float

@dataclass
class History:
    """
    Time series recorded by the integrator.

    Every field is a list with one entry per stored time level. ``append``
    extends all lists by exactly one entry, so equal indices refer to the
    same instant.
    """

    # Kinematics and contact flag.
    t: List[float] = field(default_factory=list)         # time [s]
    u: List[float] = field(default_factory=list)         # gap [m]
    v: List[float] = field(default_factory=list)         # velocity [m/s]
    a: List[float] = field(default_factory=list)         # acceleration [m/s^2]
    fc: List[float] = field(default_factory=list)        # contact force [N]
    contact: List[bool] = field(default_factory=list)    # in contact at this step?
    # Energy budget [J], as computed by the integrator.
    ekin: List[float] = field(default_factory=list)      # kinetic
    epotg: List[float] = field(default_factory=list)     # gravitational potential
    epotk: List[float] = field(default_factory=list)     # penalty-spring potential
    emech: List[float] = field(default_factory=list)     # ekin + epotg + epotk
    edis: List[float] = field(default_factory=list)      # cumulative dashpot dissipation
    ealg: List[float] = field(default_factory=list)      # algorithmic energy
    etot: List[float] = field(default_factory=list)      # emech + edis
    etot_alg: List[float] = field(default_factory=list)  # ealg + edis

    def append(self, s: State, is_contact: bool,
               ekin: float, epotg: float, epotk: float,
               emech: float, edis: float, ealg: float,
               etot: float, etot_alg: float) -> None:
        """Record state ``s``, its contact flag and its energy terms as one new sample."""
        sample = (
            ("t", s.t), ("u", s.u), ("v", s.v), ("a", s.a), ("fc", s.fc),
            ("contact", is_contact),
            ("ekin", ekin), ("epotg", epotg), ("epotk", epotk),
            ("emech", emech), ("edis", edis), ("ealg", ealg),
            ("etot", etot), ("etot_alg", etot_alg),
        )
        for series_name, value in sample:
            getattr(self, series_name).append(value)

    def as_arrays(self) -> Dict[str, np.ndarray]:
        """Return each recorded series as a NumPy array, keyed by its field name."""
        arrays: Dict[str, np.ndarray] = {}
        for series_name, values in vars(self).items():
            arrays[series_name] = np.asarray(values)
        return arrays


In [None]:
class BouncingBallPenalty:
    """
    Bouncing ball with penalty contact, integrated by the central difference method.

    Contact with the ground (u = 0) is enforced by a penalty spring ``k`` and a
    parallel dashpot ``c`` (both from ``Params``). They are active whenever the
    predicted gap is <= 0. The integrator is the explicit central difference
    scheme in predictor/corrector (velocity-Verlet) form:

        u_{n+1}  = u_n + dt v_n + dt^2/2 a_n
        v*_{n+1} = v_n + dt a_n                     (only used by the dashpot)
        fc_{n+1} = -k u_{n+1} - c v*_{n+1}  if u_{n+1} <= 0, else 0
        a_{n+1}  = -g + fc_{n+1} / m
        v_{n+1}  = v_n + dt/2 (a_n + a_{n+1})

    All quantities are positive upward, so ``fc`` is the force of the ground
    on the ball. No clamping is applied: with c > 0, ``fc`` can become negative
    (adhesive) while the ball separates with a positive velocity.

    The state and its energy budget are appended to ``self.hist`` for the
    initial state and after every step.
    """

    def __init__(self, params: Params, u0: float, v0: float):
        self.p = params
        # Compare dt with the stability limit of the contact oscillator.
        dt_rec = self.recommend_dt(params)
        if self.p.dt > dt_rec:
            print(f"Warning: time step {self.p.dt:.2e} exceeds recommended {dt_rec:.2e}")
        else:
            print(f"Info: time step {self.p.dt:.2e} within recommended {dt_rec:.2e} |"
                  f" Safety factor: {self.p.dt/dt_rec:.2f}")

        # Initial acceleration uses the same contact law as `step`. Gravity
        # alone applies unless the ball already starts at or below the ground.
        fc0, is_contact0 = self._contact_force(params, u0, v0)
        a0 = -self.p.g_const + fc0 / self.p.m  # upward positive
        self.state = State(u=u0, v=v0, a=a0, fc=fc0, t=0.0)
        self.hist = History()
        # Energies at n = 0; nothing has been dissipated yet.
        self._record(self.state, is_contact0, edis=0.0)

    def recommend_dt(self, params: Params, safety: float = 1.0) -> float:
        """
        Critical time step of this scheme for the (damped) contact oscillator.

        During contact the ball is a linear oscillator with omega = sqrt(k/m).
        The dashpot acts on the predicted velocity v* = v_n + dt a_n. For this
        scheme, v* equals the backward difference (3u_n - 4u_{n-1} + u_{n-2}) / (2 dt).
        The displacement recurrence then has the characteristic polynomial

            z^3 + (W^2 + 3xW - 2) z^2 + (1 - 4xW) z + xW,
            W = omega dt,  x = c / (2 sqrt(k m))  (damping ratio).

        Its z = -1 stability condition, W^2 + 8xW < 4, gives

            dt_crit = (2 / omega) * (sqrt(1 + zeta^2) - zeta),  zeta = c / sqrt(k m).

        This reduces to the classical 2 / omega for c = 0, and damping LOWERS
        the limit. Out of contact the ball is in free fall and the scheme has
        no stability limit.

        Returns:
            ``safety * dt_crit``, or ``inf`` when ``k <= 0`` (no penalty spring).
        """
        if params.k <= 0.0:
            return float("inf")
        omega = np.sqrt(params.k / params.m)
        zeta = params.c / np.sqrt(params.k * params.m)
        # sqrt(1 + z^2) - z == 1 / (sqrt(1 + z^2) + z); this form avoids
        # cancellation for heavy damping.
        dt_crit = 2.0 / (omega * (np.sqrt(1.0 + zeta**2) + zeta))
        return safety * dt_crit

    @staticmethod
    def _contact_force(params: Params, u: float, v: float) -> Tuple[float, bool]:
        """Return ``(fc, in_contact)`` of the penalty law for gap ``u`` and velocity ``v``."""
        in_contact = u <= 0.0
        fc = -params.k * u - params.c * v if in_contact else 0.0
        return fc, in_contact

    @staticmethod
    def _n_steps(t_end: float, dt: float) -> int:
        """
        Return the number of steps of size ``dt`` needed to reach ``t_end``.

        A ratio within a relative 1e-9 of an integer is snapped to that integer.
        Round-off such as 1.1 / 0.1 == 11.000000000000002 therefore does not add
        a spurious extra step. Non-positive ``t_end`` gives 0.
        """
        ratio = t_end / dt
        nearest = round(ratio)
        if np.isclose(ratio, nearest, rtol=1e-9, atol=0.0):
            return max(nearest, 0)
        return max(int(np.ceil(ratio)), 0)

    def _record(self, state: State, is_contact: bool, edis: float) -> None:
        """Compute the energy budget of ``state`` and append it to ``self.hist``."""
        p = self.p
        ekin = 0.5 * p.m * state.v**2
        epotg = p.m * p.g_const * state.u
        epotk = 0.5 * p.k * state.u**2 if is_contact else 0.0
        emech = ekin + epotg + epotk
        # Algorithmic energy: mechanical energy minus the dt^2/8 * m * a^2 term.
        ealg = emech - (p.dt**2) * 0.125 * p.m * (state.a**2)
        self.hist.append(state, is_contact,
                         ekin=ekin, epotg=epotg, epotk=epotk, emech=emech,
                         edis=edis, ealg=ealg,
                         etot=emech + edis, etot_alg=ealg + edis)

    def step(self) -> Tuple[State, bool]:
        """
        Advance one time step, store the new state and its energies.

        Returns:
            The state at t_{n+1} and whether the ball is in contact there.
        """
        s = self.state
        p = self.p

        # Predictor: displacement (final value) and velocity (provisional) at n+1.
        u_pred = s.u + p.dt * s.v + 0.5 * p.dt**2 * s.a
        v_pred = s.v + p.dt * s.a

        # Penalty contact force from the predicted kinematics.
        fc, is_contact = self._contact_force(p, u_pred, v_pred)

        # Acceleration at n+1 (upward positive).
        a_new = -p.g_const + fc / p.m

        # Corrector: trapezoidal velocity update.
        v_new = s.v + 0.5 * p.dt * (s.a + a_new)

        s_new = State(u=u_pred, v=v_new, a=a_new, fc=fc, t=s.t + p.dt)
        self.state = s_new

        # Cumulative dashpot work (rectangle rule using the predicted velocity).
        edis = self.hist.edis[-1]
        if is_contact:
            edis += p.c * v_pred**2 * p.dt
        self._record(s_new, is_contact, edis=edis)

        return s_new, is_contact

    def run(self) -> History:
        """
        Advance by ``ceil(t_end / dt)`` steps from the current state.

        A repeated call continues from where the previous call stopped.

        Returns:
            The accumulated history (the same object as ``self.hist``).
        """
        for _ in range(self._n_steps(self.p.t_end, self.p.dt)):
            self.step()
        return self.hist


def plot_history(hist: History) -> None:
    """
    Render a recorded simulation history as a sequence of Plotly figures.

    Figures, shown in this order:
      1. gap u(t), with samples flagged as in contact drawn as open markers;
      2. velocity v(t);
      3. contact force fc(t);
      4. energy terms (kinetic, gravitational, dissipated, penalty spring)
         plus their mechanical sum;
      5. variation of the total and algorithmic total energy relative to
         their first recorded values.

    Args:
        hist: recorded history. It must hold at least one sample, because
            figure 5 subtracts the first value of each series.
    """
    series = hist.as_arrays()
    time = series["t"]

    def _show(fig, title, y_title):
        # Shared layout: time on x, power-notation tick labels on both axes.
        fig.update_layout(
            title=title,
            xaxis_title="t [s]",
            yaxis_title=y_title,
            xaxis=dict(exponentformat="power"),
            yaxis=dict(exponentformat="power"),
        )
        fig.show()

    # 1. Gap to the ground, with in-contact samples highlighted.
    gap = series["u"]
    gap_fig = go.Figure()
    gap_fig.add_trace(go.Scatter(x=time, y=gap, mode="lines", name="u (gap) [m]"))
    touching = series["contact"].astype(bool)
    if touching.any():
        gap_fig.add_trace(
            go.Scatter(x=time[touching], y=gap[touching], mode="markers",
                       name="contact", marker=dict(symbol="circle-open"))
        )
    _show(gap_fig, "Bouncing ball — gap to ground", "u [m]")

    # 2. Velocity.
    vel_fig = go.Figure()
    vel_fig.add_trace(go.Scatter(x=time, y=series["v"], mode="lines", name="v [m/s]"))
    _show(vel_fig, "Velocity", "v [m/s]")

    # 3. Contact force (legend uses the multiplier symbol λ).
    force_fig = go.Figure()
    force_fig.add_trace(go.Scatter(x=time, y=series["fc"], mode="lines", name="λ"))
    _show(force_fig, "Contact force", "Fc [N]")

    # 4. Individual energy contributions, then their mechanical sum (dotted).
    balance_fig = go.Figure()
    for key, label in (("ekin", "E_kin"), ("epotg", "E_potg"),
                       ("edis", "E_dis"), ("epotk", "E_potk")):
        balance_fig.add_trace(go.Scatter(x=time, y=series[key], mode="lines", name=label))
    balance_fig.add_trace(
        go.Scatter(x=time, y=series["emech"], mode="lines", name="E_mech",
                   line=dict(color="black", dash="dot"))
    )
    _show(balance_fig, "Energy balance", "Energy [J]")

    # 5. Drift of the total energies with respect to the first sample.
    drift_fig = go.Figure()
    for key, label, colour in (("etot", "E_total", "black"),
                               ("etot_alg", "E_total_algo", "gray")):
        drift_fig.add_trace(
            go.Scatter(x=time, y=series[key] - series[key][0], mode="lines",
                       name=label, line=dict(color=colour))
        )
    _show(drift_fig, "Total energy variation", "Energy variation [J]")

### Run the simulation

In [None]:
params = Params(m=1.0, r=0.0, g_const=9.81, k=1e4, c=5.0, dt=1e-2, t_end=10)
u0 = 1.0  # [m]
v0 = 0.0  # [m/s]

sim = BouncingBallPenalty(params, u0=u0, v0=v0)
hist = sim.run()

plot_history(hist)