In [1]:
from pathlib import Path

import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as onp
from numpy.polynomial.chebyshev import chebvander

# -------- functions --------
def f1(x):
    """Smooth but oscillatory target: sin(10·x), evaluated elementwise."""
    scaled = 10 * x
    return jnp.sin(scaled)
def f2(x):
    """Runge-type bump 1 / (1 + 25·x²): equals 1 at x = 0 and decays quickly."""
    denominator = 1.0 + 25.0 * x**2
    return 1.0 / denominator
def f3(x):
    """Heaviside step function H(x), evaluated elementwise.

    Returns 1.0 where x >= 0 (convention H(0) = 1) and 0.0 where x < 0.
    There is deliberately no upper cutoff: an extra `x <= 2` test would turn
    the step into a box that drops back to 0 beyond x = 2.
    """
    return jnp.where(x >= 0.0, 1.0, 0.0)

# Benchmark targets, one per row of each results figure:
#   (plot label, target function, domain start a, domain end b)
# NOTE(review): the textbook Runge example is 1/(1+25x^2) on [-1, 1]
# (equivalently 1/(1+x^2) on [-5, 5]); pairing 25x^2 with [-5, 5] gives a much
# sharper peak after normalization — confirm this is intended.
funcs = [
    ("f1: sin(10x)",                f1, -1.0, 1.0),
    ("f2: Runge function",          f2, -5.0, 5.0),
    ("f3: Heaviside step function", f3, -2.0, 2.0),
]

def model(theta, x):
    """Evaluate the monomial-basis polynomial sum_k theta[k] * x**k.

    Args:
        theta: coefficient vector, theta[k] multiplies x**k.
        x: 1-D array of evaluation points (indexed as x[:, None]).

    Returns:
        1-D array with one polynomial value per entry of x.
    """
    exponents = jnp.arange(len(theta))
    design = x[:, None] ** exponents  # column k holds x**k
    return (theta * design).sum(axis=1)

def rmse(y_true, y_pred):
    """Root-mean-square error between targets and predictions (scalar array)."""
    residual = y_true - y_pred
    mse = jnp.mean(residual ** 2)
    return jnp.sqrt(mse)

def r2_score(y_true, y_pred):
    """Coefficient of determination R² = 1 - SS_res / SS_tot.

    When y_true is constant (SS_tot == 0) R² is undefined. Following the
    scikit-learn convention, return 1.0 only for an exact prediction and 0.0
    otherwise, instead of reporting a perfect fit for any prediction.

    Args:
        y_true: observed targets (array-like).
        y_pred: predictions, same shape as y_true.

    Returns:
        Scalar JAX array holding R².
    """
    y_true = jnp.asarray(y_true)
    y_pred = jnp.asarray(y_pred)
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)
    degenerate = ss_tot == 0
    # Safe denominator so the discarded branch never divides by zero.
    r2 = 1.0 - ss_res / jnp.where(degenerate, 1.0, ss_tot)
    return jnp.where(degenerate, jnp.where(ss_res == 0, 1.0, 0.0), r2)

def cond_number(X):
    """2-norm condition number kappa(X) = sigma_max / sigma_min of a matrix.

    Computed with NumPy in float64. JAX defaults to float32, whose SVD cannot
    resolve singular values much below ~1e-7 * sigma_max, so kappa of
    ill-conditioned matrices (e.g. high-degree monomial Vandermonde matrices)
    would saturate around 1e7-1e8 and under-report the true conditioning.

    Args:
        X: 2-D array (JAX or NumPy).

    Returns:
        kappa as a Python float (inf, or a very large value, when X is
        numerically singular).
    """
    X64 = onp.asarray(X, dtype=onp.float64)
    return float(onp.linalg.cond(X64))


# -------- settings --------
key = random.PRNGKey(0)      # fixed seed -> reproducible samples across runs
m_values = [50, 200, 300]    # sample sizes; one figure is produced per value
n_values = [*range(1, 31)]   # polynomial degrees 1..30 inclusive

# -------- main --------
# Figures are written into fig/, the directory the git/Overleaf sync cell stages.
FIG_DIR = Path("fig")
FIG_DIR.mkdir(parents=True, exist_ok=True)

for m in m_values:
    # Per-function metric curves: one list per function, one entry per degree.
    all_sig_mono, all_r2_mono, all_kappa_mono = [], [], []
    all_sig_cheb, all_r2_cheb, all_kappa_cheb = [], [], []
    func_names = []

    for fname, f, a, b in funcs:
        # Fresh subkey per (m, function): independent but reproducible samples.
        key, subkey = random.split(key)
        x_data = random.uniform(subkey, shape=(m,), minval=a, maxval=b)
        y_data = f(x_data)  # targets use the original (un-normalized) x

        # Map [a, b] -> [-1, 1]; both bases are built on x_norm.
        x_norm = 2.0 * (x_data - a) / (b - a) - 1
        x_norm_np = onp.asarray(x_norm)  # chebvander is a NumPy routine; hoisted out of the degree loop

        sig_mono, r2_mono, kap_mono = [], [], []
        sig_cheb, r2_cheb, kap_cheb = [], [], []

        for n in n_values:
            # ----- Monomial basis: columns 1, x, ..., x^n -----
            X_mono = jnp.vander(x_norm, N=n + 1, increasing=True)
            theta_m = jnp.linalg.pinv(X_mono) @ y_data  # least-squares coefficients
            y_hat_m = model(theta_m, x_norm)
            sig_mono.append(float(rmse(y_data, y_hat_m)))
            r2_mono.append(float(r2_score(y_data, y_hat_m)))
            kap_mono.append(cond_number(X_mono))

            # ----- Chebyshev basis: columns T_0(x), ..., T_n(x) -----
            X_cheb = jnp.asarray(chebvander(x_norm_np, n))
            theta_c = jnp.linalg.pinv(X_cheb) @ y_data
            y_hat_c = X_cheb @ theta_c  # prediction = linear combination of the columns
            sig_cheb.append(float(rmse(y_data, y_hat_c)))
            r2_cheb.append(float(r2_score(y_data, y_hat_c)))
            kap_cheb.append(cond_number(X_cheb))

        all_sig_mono.append(sig_mono); all_r2_mono.append(r2_mono); all_kappa_mono.append(kap_mono)
        all_sig_cheb.append(sig_cheb); all_r2_cheb.append(r2_cheb); all_kappa_cheb.append(kap_cheb)
        func_names.append(fname)

    # ---- Plotting: one row per function, one column per metric ----
    n_rows = len(func_names)
    fig, axes = plt.subplots(n_rows, 3, figsize=(16, 14 * n_rows / 3), sharex=True, squeeze=False)
    fig.suptitle(f"m = {m}: Monomial (Normalized) vs Chebyshev — σ, R², and κ vs degree", fontsize=18)

    # (monomial curves, Chebyshev curves, y-axis label, title metric name, log y-scale?)
    panels = [
        (all_sig_mono, all_sig_cheb, "σ (RMSE)", "RMSE", False),
        (all_r2_mono, all_r2_cheb, "R²", "R²", False),
        (all_kappa_mono, all_kappa_cheb, "Condition number", "Condition number", True),
    ]

    for i, fname in enumerate(func_names):
        for j, (mono_curves, cheb_curves, ylabel, metric, log_y) in enumerate(panels):
            ax = axes[i, j]
            ax.plot(n_values, mono_curves[i], marker='o', label="Monomial (Normalized)")
            ax.plot(n_values, cheb_curves[i], marker='s', label="Chebyshev")
            ax.set_ylabel(ylabel)
            ax.set_title(f"{fname} — {metric}, (m = {m})")
            if log_y:
                ax.set_yscale("log")  # κ spans many orders of magnitude
                ax.grid(True, which="both", alpha=0.3)
            else:
                ax.grid(True, alpha=0.3)
            ax.legend()

    # sharex=True already hides inner-row tick labels; label the shared axis once, on the bottom row.
    for ax in axes[-1]:
        ax.set_xlabel("Polynomial degree (n)")

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(FIG_DIR / f"results_m{m}.png", dpi=300)
    plt.close(fig)  # avoid accumulating open figures across the m loop

In [None]:
import subprocess

# Best-effort sync of fig/ to the Overleaf git remote. Failures are reported in
# the notebook (child-process output otherwise goes to the kernel's terminal,
# not the cell) but never raised, so Run All is not aborted.
def _run_git(*args):
    """Run `git <args>`; print its output if it exits non-zero. Returns the exit code."""
    result = subprocess.run(["git", *args], capture_output=True, text=True, check=False)
    if result.returncode != 0:
        print(f"git {' '.join(args)} failed (exit {result.returncode}):\n{result.stdout}{result.stderr}")
    return result.returncode

if _run_git("add", "fig/") == 0:
    # `git diff --cached --quiet` exits 1 when the index has staged changes;
    # skipping the commit otherwise avoids a "nothing to commit" error on re-runs.
    has_staged = subprocess.run(["git", "diff", "--cached", "--quiet"], check=False).returncode == 1
    if has_staged:
        _run_git("commit", "-m", "auto: update figures")
    _run_git("push", "overleaf", "HEAD:master")