In [1]:
import bisect
import numbers

import numpy as np
from scipy.interpolate import CubicSpline
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter


def linear_interpolate(x_data, y_data, x_query, extrapolation='constant'):
    """
    Perform linear interpolation on data points.

    Given a set of data points (x_data, y_data), this function calculates
    the interpolated y-values at specified x-positions (x_query) using
    piecewise linear interpolation.

    Parameters:
    -----------
    x_data : sequence of real numbers
        x-coordinates of input data points. They need not be sorted; if
        they are not, x_data and y_data are sorted together by x (stable,
        so duplicate x values keep their original relative order).
    y_data : sequence of real numbers
        y-coordinates of input data points; must have the same length as
        x_data.
    x_query : real number or iterable of real numbers
        x value(s) where interpolation is desired. Any ``numbers.Real``
        (including NumPy scalars such as ``np.int64``) is treated as a
        single scalar query; anything else is iterated.
    extrapolation : str, optional
        How to handle points outside the data range:
        - 'forbid': raise error if outside domain
        - 'constant': use endpoint value (default)
        - 'linear': continue with the slope of the nearest end segment

    Returns:
    --------
    float or list: Interpolated y value(s). A scalar query returns a scalar;
    any other query returns a list with one value per query point. A NaN
    query value yields NaN.

    Raises:
    -------
    ValueError
        If x_data is empty, x_data and y_data differ in length,
        `extrapolation` is not one of the methods above, a query lies
        outside the domain with extrapolation='forbid', or linear
        extrapolation is needed with fewer than two data points.

    Notes:
    ------
    With extrapolation='linear', duplicate x values at the very first or
    last two data points make the end slope undefined and raise
    ZeroDivisionError.

    Method:
    -------
    For each query point x:
      1. Find the two data points that bracket x (binary search)
      2. Calculate the slope between these points
      3. Use the linear equation: y = y₀ + slope × (x - x₀)
    """
    # Validate the mode up front so a typo fails even when every query is
    # inside the domain (previously it was only noticed on extrapolation).
    if extrapolation not in ('forbid', 'constant', 'linear'):
        raise ValueError(f"Unknown extrapolation method: '{extrapolation}'")

    x_data = list(x_data)
    y_data = list(y_data)
    if not x_data:
        raise ValueError("x_data and y_data must contain at least one point")
    if len(x_data) != len(y_data):
        raise ValueError(
            f"x_data and y_data must have the same length "
            f"(got {len(x_data)} and {len(y_data)})"
        )

    # Sort both sequences by x only when needed. sorted() is stable, so
    # duplicate x values keep their input order.
    if any(a > b for a, b in zip(x_data, x_data[1:])):
        order = sorted(range(len(x_data)), key=x_data.__getitem__)
        x_data = [x_data[i] for i in order]
        y_data = [y_data[i] for i in order]

    # numbers.Real covers int, float, Fraction and the NumPy scalar types.
    scalar_input = isinstance(x_query, numbers.Real)
    queries = [x_query] if scalar_input else list(x_query)

    x_min = x_data[0]
    x_max = x_data[-1]

    y_interp = []
    for xq in queries:
        if xq != xq:
            # NaN compares False against everything. Propagate it rather
            # than letting it fall through to an endpoint value.
            y_interp.append(float('nan'))
        elif xq < x_min or xq > x_max:
            if extrapolation == 'forbid':
                raise ValueError(
                    f"x_query value {xq} is outside the domain "
                    f"[{x_min}, {x_max}]. Extrapolation is forbidden."
                )
            elif extrapolation == 'constant':
                # Use constant value from nearest endpoint
                y_interp.append(y_data[0] if xq < x_min else y_data[-1])
            else:  # 'linear'
                if len(x_data) < 2:
                    raise ValueError(
                        "Linear extrapolation requires at least two data points"
                    )
                # Extend the slope of the nearest end segment.
                if xq < x_min:
                    slope = (y_data[1] - y_data[0]) / (x_data[1] - x_data[0])
                    y_interp.append(y_data[0] + slope * (xq - x_data[0]))
                else:
                    slope = (y_data[-1] - y_data[-2]) / (x_data[-1] - x_data[-2])
                    y_interp.append(y_data[-1] + slope * (xq - x_data[-1]))
        else:
            # First index whose x is >= xq. Because x_min <= xq <= x_max,
            # idx is a valid index, and idx == 0 only when xq == x_min.
            idx = bisect.bisect_left(x_data, xq)
            if x_data[idx] == xq:
                # Exact match with a data point
                y_interp.append(y_data[idx])
            else:
                # Strictly between x_data[idx - 1] and x_data[idx]:
                # y = y₀ + (y₁ - y₀) × (x - x₀) / (x₁ - x₀)
                x0, x1 = x_data[idx - 1], x_data[idx]
                y0, y1 = y_data[idx - 1], y_data[idx]
                slope = (y1 - y0) / (x1 - x0)
                y_interp.append(y0 + slope * (xq - x0))

    # Return scalar if input was scalar, list otherwise
    return y_interp[0] if scalar_input else y_interp
