In [25]:
# EXPERIMENT 3: Physics Formula Approximation
# SHM: x(t) = A * cos(ω t)
# Compare: MLP vs KAN (with and without auxiliary θ = ω t)

import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import r2_score

torch.manual_seed(0)

<torch._C.Generator at 0x10f99e630>

In [26]:
# Genarate Synthetic Data
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)


def generate_shm_dataset(
    n_samples=6000,
    A_range=(0.5, 5.0),
    w_range=(0.5, 10.0),
    t_range=(0.0, 5.0),
    noise_std=0.0,
    device=device
):
    """
    Generate samples of x(t) = A * cos(ω t) with optional Gaussian noise.
    Returns:
        A, w, t, x  (each of shape [N, 1], on CPU)
    """
    A = torch.empty(n_samples, 1).uniform_(*A_range)
    w = torch.empty(n_samples, 1).uniform_(*w_range)
    t = torch.empty(n_samples, 1).uniform_(*t_range)

    x = A * torch.cos(w * t)

    if noise_std > 0:
        x = x + noise_std * torch.randn_like(x)

    return A, w, t, x

# Generate data
A, w, t, x = generate_shm_dataset(
    n_samples=6000,
    A_range=(0.5, 5.0),
    w_range=(0.5, 10.0),
    t_range=(0.0, 5.0),
    noise_std=0.0,  # set >0 to test robustness
)

# Train / test split
idx = torch.randperm(len(x))
n_train = 5000
train_idx = idx[:n_train]
test_idx  = idx[n_train:]

A_tr, w_tr, t_tr, x_tr = A[train_idx], w[train_idx], t[train_idx], x[train_idx]
A_te, w_te, t_te, x_te = A[test_idx],  w[test_idx],  t[test_idx],  x[test_idx]

# Push labels to device once
y_tr = x_tr.to(device)
y_te = x_te.to(device)

print("Train samples:", len(y_tr), "| Test samples:", len(y_te))


Using device: mps
Train samples: 5000 | Test samples: 1000


In [27]:
# 2 Helper Metrics

def rmse(y_true, y_pred):
    return torch.sqrt(nn.MSELoss()(y_pred, y_true)).item()

def r2_score_torch(y_true, y_pred):
    # y_* should be CPU numpy
    return r2_score(y_true.detach().cpu().numpy(),
                    y_pred.detach().cpu().numpy())



In [28]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)


# 3) MLP Baseline 
class SHMMLP(nn.Module):
    def __init__(self, in_dim=3, hidden=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        return self.layers(x)

mlp = SHMMLP().to(device)
opt_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

Xtr_mlp = torch.cat([A_tr, w_tr, t_tr], dim=1).to(device)
Xte_mlp = torch.cat([A_te, w_te, t_te], dim=1).to(device)

EPOCHS_MLP = 200
for ep in range(1, EPOCHS_MLP + 1):
    mlp.train()
    opt_mlp.zero_grad()
    pred = mlp(Xtr_mlp)
    loss = loss_fn(pred, y_tr)
    loss.backward()
    opt_mlp.step()
    if ep % 50 == 0:
        print(f"[MLP] epoch {ep}/{EPOCHS_MLP} | train_loss={loss.item():.6f}")

mlp.eval()
with torch.no_grad():
    y_hat_mlp = mlp(Xte_mlp)

mlp_rmse = rmse(y_te, y_hat_mlp)
mlp_r2   = r2_score_torch(y_te, y_hat_mlp)

print("\n=== MLP baseline (A, ω, t) ===")
print(f"RMSE: {mlp_rmse:.4f}")
print(f"R²  : {mlp_r2:.4f}")

Using device: mps
[MLP] epoch 50/200 | train_loss=4.417974
[MLP] epoch 100/200 | train_loss=4.213521
[MLP] epoch 150/200 | train_loss=4.057635
[MLP] epoch 200/200 | train_loss=3.959777

=== MLP baseline (A, ω, t) ===
RMSE: 2.0393
R²  : 0.1343


In [30]:
# 4 KAN Experiments
import os
os.environ.pop("MPLBACKEND", None)  

import matplotlib
matplotlib.use("Agg")              

import torch
import kan

try:
    #KAN class is directly under kan
    KAN = kan.KAN
except AttributeError:
    try:
        # kan.KAN is a module that contains the class KAN
        KAN = kan.KAN.KAN
    except AttributeError:
        # Fallback
        from kan.MultKAN import MultKAN as KAN

# θ = ω t
theta_tr = (w_tr * t_tr).to(device)
theta_te = (w_te * t_te).to(device)

Xtr_aux = torch.cat([A_tr.to(device), theta_tr], dim=1)  # shape [N,2]
Xte_aux = torch.cat([A_te.to(device), theta_te], dim=1)

kan_aux = KAN(
    width=[2, 32, 1],  # input dim 2: (A, θ)
    grid=5,
    k=3,
    seed=0,
    device=device,
)

dataset_aux = {
    "train_input": Xtr_aux,
    "train_label": y_tr,
    "test_input":  Xte_aux,
    "test_label":  y_te,
}

print("\nTraining KAN with auxiliary θ = ω t ...")
kan_aux.fit(dataset_aux, opt="LBFGS", steps=80, lamb=1e-6, lamb_l1=0.0, lamb_entropy=0.0, 
                update_grid=False,          # <--- add this
)

with torch.no_grad():
    y_hat_aux = kan_aux(Xte_aux)

kan_aux_rmse = rmse(y_te, y_hat_aux)
kan_aux_r2   = r2_score_torch(y_te, y_hat_aux)

print("\n=== KAN (with θ = ω t) ===")
print(f"RMSE: {kan_aux_rmse:.4f}")
print(f"R²  : {kan_aux_r2:.4f}")


checkpoint directory created: ./model
saving model version 0.0

Training KAN with auxiliary θ = ω t ...


| train_loss: 1.46e+00 | test_loss: 1.54e+00 | reg: 0.00e+00 | : 100%|█| 80/80 [01:11<00:00,  1.11it

saving model version 0.1

=== KAN (with θ = ω t) ===
RMSE: 1.5359
R²  : 0.5090





In [32]:

#  A_tr, w_tr, t_tr, x_tr, A_te, w_te, t_te, x_te are defined

Xtr_raw = torch.cat([A_tr, w_tr, t_tr], dim=1).to(device)
Xte_raw = torch.cat([A_te, w_te, t_te], dim=1).to(device)

kan_raw = KAN(width=[3, 32, 1], grid=5, k=3, seed=0, device=device)

dataset_raw = {
    "train_input": Xtr_raw,
    "train_label": x_tr.to(device),
    "test_input":  Xte_raw,
    "test_label":  x_te.to(device),
}

#  update_grid=False 
kan_raw.fit(dataset_raw,
            opt="LBFGS",
            steps=80,
            lamb=0.0,
            lamb_l1=0.0,
            lamb_entropy=0.0,
            update_grid=False)

with torch.no_grad():
    y_hat_raw = kan_raw(Xte_raw).cpu()

rmse = torch.sqrt(torch.nn.MSELoss()(y_hat_raw, x_te))

print("KAN (raw A, ω, t) — RMSE:", rmse.item())


checkpoint directory created: ./model
saving model version 0.0


| train_loss: 1.49e+00 | test_loss: 1.62e+00 | reg: 0.00e+00 | : 100%|█| 80/80 [00:59<00:00,  1.35it

saving model version 0.1
KAN (raw A, ω, t) — RMSE: 1.616328239440918



