# In [None]:
"""
PINN for exponential decay with parameter identification
ODE: dy/dt + k y = 0,  y(0)=y0
Learn both the function y(t) and the decay constant k from sparse (possibly noisy) data.
"""

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


# ----------------------------
# Model
# ----------------------------
class MLP(nn.Module):
    """Fully connected feed-forward network with tanh hidden activations.

    ``layers`` lists the width of every layer, input first and output last;
    e.g. ``[1, 10, 1]`` is one input, a single hidden layer of 10 units and
    one output. Every linear map except the last is followed by ``nn.Tanh``;
    the final (read-out) layer is purely linear.
    """

    def __init__(self, layers):
        super().__init__()
        modules = []
        # Hidden layers: consecutive (in, out) width pairs up to the penultimate width.
        for fan_in, fan_out in zip(layers[:-2], layers[1:-1]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.Tanh()])
        # Linear read-out, no activation.
        modules.append(nn.Linear(layers[-2], layers[-1]))
        self.net = nn.Sequential(*modules)

    def forward(self, t):
        """Apply the network to ``t`` (shape ``(N, 1)`` as used in this script)."""
        return self.net(t)


def make_synthetic_data(k_true=1.7, y0=2.0, t_max=2.0, n_data=20, noise_std=0.02, seed=0):
    """Sample noisy observations of y(t) = y0 * exp(-k_true * t).

    Times are drawn uniformly from [0, t_max) with a seeded NumPy Generator
    and sorted ascending; zero-mean Gaussian noise of standard deviation
    ``noise_std`` is added to the exact values.

    Returns:
        (t_data, y_noisy, y_clean), each a NumPy array of shape (n_data, 1).
    """
    generator = np.random.default_rng(seed)
    times = np.sort(generator.uniform(0.0, t_max, size=(n_data, 1)), axis=0)

    exact = y0 * np.exp(-k_true * times)
    noise = generator.normal(0.0, noise_std, size=exact.shape)

    return times, exact + noise, exact


def _physics_residual(model, t_f, k_hat):
    """Return the ODE residual dy/dt + k_hat * y of ``model`` at times ``t_f``.

    Args:
        model: network mapping an (N, 1) time tensor to (N, 1) predictions.
        t_f: collocation times, shape (N, 1); must have ``requires_grad=True``
            so the network output can be differentiated with respect to time.
        k_hat: current (positive) estimate of the decay constant.

    Returns:
        Residual tensor of shape (N, 1). ``create_graph=True`` keeps dy/dt in
        the autograd graph so the residual loss can be back-propagated into
        both the network weights and ``k_hat``.
    """
    y_f = model(t_f)
    # A vector-Jacobian product with ones yields the per-point derivative
    # because the model treats each row independently (true for MLP).
    # create_graph=True also retains the graph (retain_graph defaults to it).
    (dy_dt,) = torch.autograd.grad(
        outputs=y_f,
        inputs=t_f,
        grad_outputs=torch.ones_like(y_f),
        create_graph=True,
    )
    return dy_dt + k_hat * y_f


def _plot_results(t_test_np, y_true, y_pred, t_data_np, y_data_np, k_hist, k_true):
    """Plot the PINN fit against the exact solution, then the k_hat history."""
    plt.figure(dpi=300)
    plt.plot(t_test_np, y_true, label="Simulation")
    plt.plot(t_test_np, y_pred, label="PINN")
    plt.scatter(t_data_np, y_data_np.squeeze(), label="Data")
    # Matplotlib's built-in mathtext only recognises $...$ delimiters;
    # \(...\) would be drawn literally unless text.usetex is enabled.
    plt.xlabel(r"$t$")
    plt.ylabel(r"$y(t)$")
    plt.legend()

    plt.figure(dpi=300)
    plt.plot(np.arange(1, len(k_hist) + 1), k_hist, label="k_hat")
    plt.axhline(k_true, linestyle="--", label="true k")
    plt.xlabel("Epochs")
    plt.ylabel(r"Decay Constant $k$")
    plt.legend()

    plt.show()


def main():
    """Identify the decay constant k of dy/dt + k y = 0 with a PINN and plot the result.

    Steps:
      1. Draw sparse, noisy samples of y(t) = y0 * exp(-k_true * t).
      2. Train an MLP y_hat(t) jointly with a trainable k_hat > 0 by minimising
         alpha * L_data + beta * L_eq + gamma * L_ic.
      3. Print the identified k and predictions at a few sample times.
      4. Plot the fitted curve and the k_hat training history.
    """
    # Reproducibility (the synthetic data uses its own seeded Generator).
    torch.manual_seed(0)
    np.random.seed(0)

    # True system
    k_true = 1.7
    y0 = 2.0
    t_max = 2.0

    # Data (sparse, noisy)
    t_data_np, y_data_np, _ = make_synthetic_data(
        k_true=k_true, y0=y0, t_max=t_max, n_data=20, noise_std=0.02, seed=0
    )

    # Collocation points where the ODE residual is enforced
    n_f = 200
    t_f_np = np.linspace(0.0, t_max, n_f).reshape(-1, 1)

    # Torch tensors
    device = torch.device("cpu")
    t_data = torch.tensor(t_data_np, dtype=torch.float32, device=device)
    y_data = torch.tensor(y_data_np, dtype=torch.float32, device=device)
    # requires_grad so dy/dt can be taken with respect to the collocation times
    t_f = torch.tensor(t_f_np, dtype=torch.float32, device=device, requires_grad=True)

    t0 = torch.tensor([[0.0]], dtype=torch.float32, device=device)
    y0_t = torch.tensor([[y0]], dtype=torch.float32, device=device)

    # Network + trainable parameter
    model = MLP(layers=[1, 10, 1]).to(device)

    # Unconstrained parameter; k_hat = softplus(raw_k) is always > 0.
    # raw_k = 0 corresponds to an initial guess k_hat = ln 2 (about 0.693).
    raw_k = nn.Parameter(torch.tensor([0.0], dtype=torch.float32, device=device))
    params = list(model.parameters()) + [raw_k]

    # Loss weights (you can tune these)
    alpha = 1.0   # data
    beta = 1.0    # equation
    gamma = 10.0  # initial condition

    optimizer = torch.optim.Adam(params, lr=1e-3)

    # Training
    epochs = 8000
    k_hist = []     # k_hat used in each epoch's loss (value before that epoch's step)
    loss_hist = []

    for ep in range(1, epochs + 1):
        optimizer.zero_grad()

        k_hat = F.softplus(raw_k)  # positive decay constant

        # Physics loss: L_eq = mean((dyhat/dt + k_hat * yhat)^2)
        residual = _physics_residual(model, t_f, k_hat)
        L_eq = torch.mean(residual**2)

        # Initial condition loss: L_ic = (yhat(0) - y0)^2
        L_ic = torch.mean((model(t0) - y0_t) ** 2)

        # Data loss: L_data = mean((yhat(t_i) - y_i)^2)
        L_data = torch.mean((model(t_data) - y_data) ** 2)

        loss = alpha * L_data + beta * L_eq + gamma * L_ic
        loss.backward()
        optimizer.step()

        k_hist.append(k_hat.item())
        loss_hist.append(loss.item())

        if ep % 1000 == 0 or ep == 1:
            print(
                f"epoch={ep:5d}  loss={loss_hist[-1]:.3e}  "
                f"L_data={L_data.item():.3e}  "
                f"L_eq={L_eq.item():.3e}  "
                f"L_ic={L_ic.item():.3e}  "
                f"k_hat={k_hist[-1]:.6f}"
            )

    # k_hist[-1] predates the final optimizer.step(); read the parameter
    # itself so the reported k matches the trained network evaluated below.
    with torch.no_grad():
        k_learned = F.softplus(raw_k).item()

    # Evaluation on a dense grid
    t_test_np = np.linspace(0.0, t_max, 300).reshape(-1, 1)
    t_test = torch.tensor(t_test_np, dtype=torch.float32, device=device)
    with torch.no_grad():
        y_pred = model(t_test).cpu().numpy().squeeze()

    y_true = (y0 * np.exp(-k_true * t_test_np)).squeeze()

    print("\nFinal results:")
    print(f"  true k  = {k_true:.6f}")
    print(f"  learned k_hat = {k_learned:.6f}")

    # Print a few predictions
    sample_times = np.array([0.0, 0.5, 1.0, 1.5, 2.0]).reshape(-1, 1)
    sample_t = torch.tensor(sample_times, dtype=torch.float32, device=device)
    with torch.no_grad():
        sample_pred = model(sample_t).cpu().numpy().squeeze()
    sample_true = (y0 * np.exp(-k_true * sample_times)).squeeze()

    print("\nPredictions at selected times:")
    for tt, yt, yp in zip(sample_times.squeeze(), sample_true, sample_pred):
        print(f"  t={tt:>4.1f}  y_true={yt:.6f}  y_pred={yp:.6f}")

    _plot_results(t_test_np, y_true, y_pred, t_data_np, y_data_np, k_hist, k_true)


# Script entry point: train, report and plot only when run directly, not on import.
if __name__ == "__main__":
    main()
