In [9]:
import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import csv
from dataclasses import dataclass

# Datasets (Gamma functions)
def f1(x):  # sin(10x) on [-1, 1]
    return jnp.sin(10.0 * x)

def f2(x):  # Runge function on [-5, 5]
    return 1.0 / (1.0 + 25.0 * x**2)

def f3(x):  # Step function on [-2, 2]: 1 in [0,2], else 0
    return jnp.where((x >= 0.0) & (x <= 2.0), 1.0, 0.0)

@dataclass
class Dataset:
    f: callable
    a: float
    b: float
    name: str

DATASETS = [
    Dataset(f=f1, a=-1.0, b= 1.0, name="sin(10x)"),
    Dataset(f=f2, a=-5.0, b= 5.0, name="Runge function"),
    Dataset(f=f3, a=-2.0, b= 2.0, name="Heaviside step function"),
]

# Losses
def huber_loss(r, delta=1.0):
    # 0.5 r^2 for |r|<=delta, delta(|r| - 0.5 delta) otherwise
    absr = jnp.abs(r)
    quad = 0.5 * r**2
    lin = delta * (absr - 0.5 * delta)
    return jnp.where(absr <= delta, quad, lin)

def lp_loss(r, p=1.5):
    return jnp.abs(r)**p

# Names for plotting/CSV
LOSS_SPECS = [
    ("huber", lambda r: huber_loss(r, delta=1.0)),
    ("lp",    lambda r: lp_loss(r, p=1.5)),
]

# Utilities
def normalize_to_unit_interval(x, a, b):
    return 2.0 * (x - a) / (b - a) - 1.0

def vandermonde(x_norm, degree):
    return jnp.vander(x_norm, degree + 1, increasing=True)

def initial_theta(A, y):
    try:
        return jnp.linalg.lstsq(A, y, rcond=None)[0]
    except AttributeError:
        return jnp.linalg.pinv(A) @ y

# IRLS core (weights per assignment: w = rho(r) / r^2)
def irls(A, y, rho_fn, max_iter=1000, tol=1e-8, eps=1e-8):
    """
    Solve min sum_i rho(r_i) via IRLS:
      - r = y - A @ theta
      - w_i = rho(r_i) / (r_i^2 + eps)
      - Solve weighted LS: min_theta || diag(sqrt(w)) (y - A theta) ||^2
    """
    theta = initial_theta(A, y)
    for _ in range(max_iter):
        r = y - A @ theta
        rho_val = rho_fn(r)
        w = rho_val / (r**2 + eps)
        # Weighted least squares: minimize || W^(1/2) (y - A theta) ||_2
        # Equivalent normal eq: (A^T W A) theta = A^T W y
        # We'll form W*A and W*y via diag weights:
        sqrtw = jnp.sqrt(w)
        Aw = A * sqrtw[:, None]
        yw = y * sqrtw
        try:
            theta_new = jnp.linalg.lstsq(Aw, yw, rcond=None)[0]
        except AttributeError:
            theta_new = jnp.linalg.pinv(Aw) @ yw

        if jnp.linalg.norm(theta_new - theta) < tol:
            return theta_new
        theta = theta_new
    return theta  # reached max_iter

# Experiment configuration
m = 100
degrees = list(range(1, 31))  # 1..30
num_test = 1000
key = random.PRNGKey(0)

# Run experiments
results = []  # rows: (dataset, loss, degree, max_error)

for (loss_name, rho_fn) in LOSS_SPECS:
    for ds in DATASETS:
        a, b, f, name = ds.a, ds.b, ds.f, ds.name

        # Training data
        key, subkey = random.split(key)

        x_train = jnp.linspace(a, b, m)
        y_train = f(x_train)

        # Normalize x to [-1,1] for stability; build test grid too
        x_train_norm = normalize_to_unit_interval(x_train, a, b)

        #x_test = jnp.linspace(a, b, num_test)
        x_test = random.uniform(subkey, shape=(m,), minval=a, maxval=b)
        x_test_norm = normalize_to_unit_interval(x_test, a, b)
        y_test_true = f(x_test)

        # Sweep degrees
        for n in degrees:
            # Design matrices
            A = vandermonde(x_train_norm, n)
            A_test = vandermonde(x_test_norm, n)

            # Initial LS (warm start), then IRLS
            theta0 = initial_theta(A, y_train)
            # IRLS using the assignment's weight rule
            theta = irls(A, y_train, rho_fn=rho_fn, max_iter=1000, tol=1e-8, eps=1e-8)

            # Evaluate required metric
            y_pred = A_test @ theta
            err_max = jnp.max(jnp.abs(y_test_true - y_pred))
            results.append((name, loss_name, int(n), float(err_max)))

            # Console progress
            print(f"{name:26s} | loss={loss_name:5s} | n={n:2d} | max_err={float(err_max):.6e}")


# Save CSV of results
csv_path = "exercise3_results.csv"
with open(csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["dataset", "loss", "degree_n", "max_error"])
    writer.writerows(results)
print(f"\nSaved results to {csv_path}")

# Plot error vs degree
import pandas as pd

df = pd.DataFrame(results, columns=["dataset", "loss", "degree_n", "max_error"])
for ds in DATASETS:
    sub = df[df["dataset"] == ds.name].copy()
    # Pivot to have columns per loss
    pivot = sub.pivot(index="degree_n", columns="loss", values="max_error").sort_index()

    plt.figure(figsize=(5, 4))
    if "huber" in pivot.columns:
        plt.plot(pivot.index.values, pivot["huber"].values, marker="o", label="Huber")
    if "lp" in pivot.columns:
        plt.plot(pivot.index.values, pivot["lp"].values, marker="s", label="L^1.5")

    plt.title(f"Max Error vs Degree — {ds.name}")
    plt.xlabel("Polynomial degree n")
    plt.ylabel("Max error on 1000 test points")
    plt.grid(True, alpha=0.4)
    plt.legend()
    out_png = f"exercise3_error_vs_degree_{ds.name.replace(' ', '_').replace('(', '').replace(')', '')}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()
    print(f"Saved plot: {out_png}")

sin(10x)                   | loss=huber | n= 1 | max_err=1.096320e+00
sin(10x)                   | loss=huber | n= 2 | max_err=1.096320e+00
sin(10x)                   | loss=huber | n= 3 | max_err=1.012815e+00
sin(10x)                   | loss=huber | n= 4 | max_err=1.012815e+00
sin(10x)                   | loss=huber | n= 5 | max_err=1.198919e+00
sin(10x)                   | loss=huber | n= 6 | max_err=1.198919e+00
sin(10x)                   | loss=huber | n= 7 | max_err=7.508762e-01
sin(10x)                   | loss=huber | n= 8 | max_err=7.508768e-01
sin(10x)                   | loss=huber | n= 9 | max_err=2.921317e-01
sin(10x)                   | loss=huber | n=10 | max_err=2.920859e-01
sin(10x)                   | loss=huber | n=11 | max_err=7.854843e-02
sin(10x)                   | loss=huber | n=12 | max_err=7.851791e-02
sin(10x)                   | loss=huber | n=13 | max_err=1.227188e-02
sin(10x)                   | loss=huber | n=14 | max_err=1.239395e-02
sin(10x)            