# In [3]:
from bisect import bisect_left
from numbers import Real

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
from scipy.interpolate import CubicSpline


def linear_interpolate(x_data, y_data, x_query, extrapolation='constant'):
    """
    Perform piecewise linear interpolation on data points.

    Given a set of data points (x_data, y_data), this function calculates
    the interpolated y-values at specified x-positions (x_query).

    Parameters:
    -----------
    x_data : iterable of numbers
        x-coordinates of input data points (sorted internally if needed)
    y_data : iterable of numbers
        y-coordinates of input data points, same length as x_data
    x_query : real number or iterable of numbers
        x value(s) where interpolation is desired. Any real scalar
        (int, float, NumPy integer/floating scalar, ...) is treated as a
        single query; anything else is iterated over.
    extrapolation : str, optional
        How to handle points outside the data range:
        - 'forbid': raise error if outside domain
        - 'constant': use endpoint value
        - 'linear': continue with the slope of the nearest end segment

    Returns:
    --------
    float or list: Interpolated y value(s). A scalar query returns a single
    value, an iterable query returns a list. A NaN query yields NaN.

    Raises:
    -------
    ValueError
        - x_data is empty, or x_data and y_data differ in length
        - a query lies outside the domain and extrapolation is 'forbid'
        - a query lies outside the domain and extrapolation is not one of
          the three supported names
        - 'linear' extrapolation is needed but fewer than two points exist

    Method:
    -------
    For each query point x:
      1. Find the two data points that bracket x (binary search)
      2. Calculate the slope between these points
      3. Use the linear equation: y = y₀ + slope × (x - x₀)
    """

    # Convert to lists for processing
    x_data = list(x_data)
    y_data = list(y_data)

    if not x_data:
        raise ValueError("x_data and y_data must contain at least one point.")
    if len(x_data) != len(y_data):
        raise ValueError(
            f"x_data and y_data must have the same length "
            f"(got {len(x_data)} and {len(y_data)})."
        )

    # Sort both lists together by x only when the data is not already
    # in non-decreasing order (the sort is stable, so ties keep their order)
    if any(left > right for left, right in zip(x_data, x_data[1:])):
        pairs = sorted(zip(x_data, y_data), key=lambda pair: pair[0])
        x_data = [pair[0] for pair in pairs]
        y_data = [pair[1] for pair in pairs]

    # numbers.Real also matches NumPy scalars, which are not all
    # subclasses of the built-in int/float
    scalar_input = isinstance(x_query, Real)
    queries = [x_query] if scalar_input else list(x_query)

    x_min = x_data[0]
    x_max = x_data[-1]

    y_interp = []
    for xq in queries:
        # NaN compares false against everything; propagate it instead of
        # letting it fall through the range checks and pick up y_data[0]
        if xq != xq:
            y_interp.append(float('nan'))
            continue

        if xq < x_min or xq > x_max:
            # Outside the data domain
            if extrapolation == 'forbid':
                raise ValueError(
                    f"x_query value {xq} is outside the domain "
                    f"[{x_min}, {x_max}]. Extrapolation is forbidden."
                )
            elif extrapolation == 'constant':
                # Use constant value from nearest endpoint
                y_interp.append(y_data[0] if xq < x_min else y_data[-1])
            elif extrapolation == 'linear':
                if len(x_data) < 2:
                    raise ValueError(
                        "Linear extrapolation requires at least two data points."
                    )
                # Continue the slope of the nearest end segment
                if xq < x_min:
                    slope = (y_data[1] - y_data[0]) / (x_data[1] - x_data[0])
                    y_interp.append(y_data[0] + slope * (xq - x_data[0]))
                else:
                    slope = (y_data[-1] - y_data[-2]) / (x_data[-1] - x_data[-2])
                    y_interp.append(y_data[-1] + slope * (xq - x_data[-1]))
            else:
                raise ValueError(f"Unknown extrapolation method: '{extrapolation}'")
            continue

        # Inside the domain: first index whose x is >= xq. Because
        # x_min <= xq <= x_max, idx is a valid index, and idx == 0 can only
        # occur on an exact match with the first point.
        idx = bisect_left(x_data, xq)
        if x_data[idx] == xq:
            # Exact match with a data point
            y_interp.append(y_data[idx])
        else:
            x0, x1 = x_data[idx - 1], x_data[idx]
            y0, y1 = y_data[idx - 1], y_data[idx]

            # Linear interpolation: y = y₀ + (y₁ - y₀) × (x - x₀) / (x₁ - x₀)
            slope = (y1 - y0) / (x1 - x0)
            y_interp.append(y0 + slope * (xq - x0))

    # Return scalar if input was scalar, list if input was list
    return y_interp[0] if scalar_input else y_interp
def format_sci_notation(value, precision=2):
    """
    Render a number as 'coefficient × 10ⁿ' with a Unicode superscript exponent.

    The value is formatted in 'e' notation with ``precision`` digits after
    the decimal point, then the exponent is rewritten in superscript
    characters, e.g. 1.23e-04 → '1.23 × 10⁻⁴'.

    If the formatted text contains no 'e' (e.g. for inf or nan), it is
    returned unchanged.
    """
    formatted = f"{value:.{precision}e}"
    if 'e' not in formatted:
        return formatted

    mantissa, exponent_text = formatted.split('e')

    # Round-tripping through int() drops the '+' sign and leading zeros
    exponent_digits = str(int(exponent_text))

    to_superscript = str.maketrans('0123456789-+', '⁰¹²³⁴⁵⁶⁷⁸⁹⁻⁺')
    return f"{mantissa} × 10{exponent_digits.translate(to_superscript)}"
# ======================================================================
# MAIN: Harmonic oscillator energy with linear vs periodic cubic spline
# ======================================================================

def main():
    print("=" * 70)
    print("SIMPLE HARMONIC OSCILLATOR: ENERGY WITH LINEAR VS PERIODIC CUBIC SPLINE")
    print("=" * 70)

    # Parameters
    m = 1.0          # mass
    k = 1.0          # spring constant
    omega = np.sqrt(k/m)
    T = 2 * np.pi / omega   # one complete period
    N = 2000        # number of sample points
    A = 1.0         # amplitude

    # Sample times and exact SHM solution: x(t) = A cos(omega t)
    t_samples = np.linspace(0, T, N)
    x_samples = A * np.cos(omega * t_samples)
    v_samples = -A * omega * np.sin(omega * t_samples)

    # Fine time grid for evaluation
    t_fine = np.linspace(0, T, 5001)

    # Exact solution on fine grid
    x_true = A * np.cos(omega * t_fine)
    v_true = -A * omega * np.sin(omega * t_fine)
    E_true = 0.5 * m * v_true**2 + 0.5 * k * x_true**2

    # ============== LINEAR INTERPOLATION (CUSTOM FUNCTION) ==============
    print("\nPerforming linear interpolation using custom function...")
    print(f"Number of sample points: {N}")
    x_lin = np.array(linear_interpolate(t_samples, x_samples, list(t_fine)))
    v_lin = np.gradient(x_lin, t_fine)
    E_lin = 0.5 * m * v_lin**2 + 0.5 * k * x_lin**2
    KE_lin = 0.5 * m * v_lin**2
    PE_lin = 0.5 * k * x_lin**2

    # ============== CUBIC SPLINE INTERPOLATION (PERIODIC) ==============
    print("Performing cubic spline interpolation with bc_type='periodic'...")
    cs = CubicSpline(t_samples, x_samples, bc_type='periodic')
    x_cub = cs(t_fine)
    v_cub = cs(t_fine, 1)
    E_cub = 0.5 * m * v_cub**2 + 0.5 * k * x_cub**2
    KE_cub = 0.5 * m * v_cub**2
    PE_cub = 0.5 * k * x_cub**2

    # ============== COMPUTE ENERGY STATISTICS ==============
    E_true_mean = np.mean(E_true)
    E_lin_mean = np.mean(E_lin)
    E_cub_mean = np.mean(E_cub)

    E_true_var = E_true.max() - E_true.min()
    E_lin_var = E_lin.max() - E_lin.min()
    E_cub_var = E_cub.max() - E_cub.min()

    E_true_std = np.std(E_true)
    E_lin_std = np.std(E_lin)
    E_cub_std = np.std(E_cub)

    print("\n" + "-" * 70)
    print(f"ENERGY STATISTICS (mass=1, spring const=1, period=1, sample points={N})")
    print("-" * 70)
    print(f"\n{'Method':<30} {'Mean':<15} {'Variation':<20} {'Std Dev':<15}")
    print("-" * 70)
    print(f"{'Exact solution':<30} {E_true_mean:<15.6f} {E_true_var:<20.3e} {E_true_std:<15.3e}")
    print(f"{'Linear interpolation':<30} {E_lin_mean:<15.6f} {E_lin_var:<20.3e} {E_lin_std:<15.3e}")
    print(f"{'Cubic spline (periodic)':<30} {E_cub_mean:<15.6f} {E_cub_var:<20.3e} {E_cub_std:<15.3e}")
    print("-" * 70)

    print("\nCONCLUSION:")
    print(f"  Linear interpolation ΔE = {format_sci_notation(E_lin_var)} J")
    print(f"  Periodic cubic spline ΔE = {format_sci_notation(E_cub_var)} J")
    if E_cub_var > 0:
        print(f"  Cubic spline is about {E_lin_var / E_cub_var:.1f}× better (smaller energy variation)")

    # =================== PLOTS ===================
