In [7]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# ==== Load and split data ====
data_dir=
data_tensor = torch.load("./PINN_test/progress/testing/data_test.pt")  # Shape: (N, 7)
X_data = data_tensor[:, :5].numpy()
Y_data = data_tensor[:, 5:].numpy()

In [8]:
# ==== Apply StandardScaler ====
X_scaler = StandardScaler()
y_scaler = StandardScaler()

X_scaled = X_scaler.fit_transform(X_data)
Y_scaled = y_scaler.fit_transform(Y_data)

X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
Y_tensor = torch.tensor(Y_scaled, dtype=torch.float32)

# ==== Prepare DataLoader ====
dataset = TensorDataset(X_tensor, Y_tensor)
loader = DataLoader(dataset, batch_size=1024, shuffle=True)

In [13]:
# ==== Define PINN ====
class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(5, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 2)  # output: [omega, phi]
        )

    def forward(self, x):
        return self.net(x)

    def physics_loss(self, x_scaled, X_scaler, y_scaler):
        x_scaled = x_scaled.clone().detach().requires_grad_(True)
        y_pred_scaled = self.forward(x_scaled)

        d_y_scaled = torch.autograd.grad(
            outputs=y_pred_scaled,
            inputs=x_scaled,
            grad_outputs=torch.ones_like(y_pred_scaled),
            create_graph=True
        )[0]

        y_std = torch.tensor(y_scaler.scale_, dtype=torch.float32, device=x_scaled.device)
        x_std = torch.tensor(X_scaler.scale_, dtype=torch.float32, device=x_scaled.device)
        correction = y_std[:, None] / x_std[None, :]  # [2, 5]
        d_y_unscaled = d_y_scaled.unsqueeze(1) * correction  # [N, 2, 5]

        d_omega_dt = d_y_unscaled[:, 0, 0:1]
        d_phi_dt = d_y_unscaled[:, 1, 0:1]

        # also compute ground truth from true_system
        x_unscaled = inverse_transform_tensor(x_scaled, X_scaler)
        y_unscaled = inverse_transform_tensor(y_pred_scaled, y_scaler)

        t = x_unscaled[:, 0:1]
        Mc = x_unscaled[:, 3:4]
        eta = x_unscaled[:, 4:5]
        omega = y_unscaled[:, 0:1]
        phi = y_unscaled[:, 1:2]

        d_omega_true, d_phi_true = true_system(t, [omega, phi], Mc, eta)

        criterion = nn.MSELoss()
        ode_loss_omega = criterion(d_omega_dt, d_omega_true)
        ode_loss_phi = criterion(d_phi_dt, d_phi_true)

        return ode_loss_omega, ode_loss_phi, d_omega_dt, d_phi_dt

In [14]:
def inverse_transform_tensor(tensor_scaled, scaler):
    device = tensor_scaled.device
    dtype = tensor_scaled.dtype
    scale = torch.tensor(scaler.scale_, dtype=dtype, device=device)
    mean = torch.tensor(scaler.mean_, dtype=dtype, device=device)
    return tensor_scaled * scale + mean

In [15]:
# ==== Define true physical model ====
tsun = 4.92549095e-6  # seconds

def true_system(t, y, Mc, eta):
    omega, phi = y
    tN_omega = Mc * omega * tsun
    d_omega = (96 / 5) * omega**2 * tN_omega**(5 / 3)
    d_phi = omega
    return [d_omega, d_phi]



In [16]:
model = PINN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(10):
    model.train()
    total_loss = 0
    for xb, yb in loader:
        pred = model(xb)
        loss_data = criterion(pred, yb)

        ode_loss_omega, ode_loss_phi, d_omega_dt, d_phi_dt = model.physics_loss(xb, X_scaler, y_scaler)
        ode_loss = ode_loss_omega + ode_loss_phi

        loss = loss_data + 0.1 * ode_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Total={total_loss / len(loader):.6f}, Data={loss_data.item():.6f}, ODE ω={ode_loss_omega.item():.6f}, ODE φ={ode_loss_phi.item():.6f}")


Epoch 1: Total=1.416711, Data=0.856911, ODE ω=0.464880, ODE φ=0.530030
Epoch 2: Total=0.798628, Data=0.562594, ODE ω=0.684869, ODE φ=0.461203
Epoch 3: Total=0.552817, Data=0.317470, ODE ω=0.409969, ODE φ=1.154492
Epoch 4: Total=0.470870, Data=0.273974, ODE ω=0.503526, ODE φ=1.378404
Epoch 5: Total=0.447218, Data=0.262666, ODE ω=0.596990, ODE φ=1.312486
Epoch 6: Total=0.433727, Data=0.249572, ODE ω=0.525810, ODE φ=1.365673
Epoch 7: Total=0.417597, Data=0.251146, ODE ω=0.583335, ODE φ=1.090333
Epoch 8: Total=0.402581, Data=0.243582, ODE ω=0.545091, ODE φ=1.095857
Epoch 9: Total=0.387124, Data=0.242794, ODE ω=0.554257, ODE φ=0.973990
Epoch 10: Total=0.369097, Data=0.196302, ODE ω=0.552814, ODE φ=0.822364
