# Vanilla PINN

## Libraries

In [None]:
import sys
sys.path.append("../src/")
import torch
import numpy as np
import matplotlib.pyplot as plt
from models.vanilla_pinn import PINN
from loss_functions.poisson import loss_fn, get_source_function
from tools.tools import print_trainable_parameters, sample_parameters_from_folder, plot_parameter_points
from tools.train_loops import fine_tune
# -- matplotlib styling: "fast" preset, LaTeX-rendered text, serif (Latin Modern) font
plt.style.use("fast")
plt.rc('text', usetex=True)
# family / serif face / base size set in one call (size: adjust if needed)
plt.rc('font', family='serif', serif='lmodern', size=12)

## Load dataset

In [None]:
# -- compute backend: prefer Apple "mps" (the original target), then CUDA, then CPU,
#    so the notebook still runs on machines without Metal support
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
dataset_path = "../dataset/poisson/0"
output_dir = "."
# -- keep only parameters strictly below this value to get a shorter range (7.5 = wider range)
parameter_max = 5.0

# -- get testing parameters
parameter_test_org, file_test_org = sample_parameters_from_folder(dataset_path, N=30, idx_start=1, type="test")
params_all = torch.tensor(parameter_test_org)  # build the tensor once, reuse for mask and selection
mask = params_all < parameter_max
parameter_test = params_all[mask]
file_test = [f for f, m in zip(file_test_org, mask) if m]
print("[Test] ", file_test)

# -- get training parameters
parameter_train_org, file_train_org = sample_parameters_from_folder(dataset_path, N=12, idx_start=0, type="train")
params_all = torch.tensor(parameter_train_org)
mask = params_all < parameter_max
parameter_train = params_all[mask]
file_train = [f for f, m in zip(file_train_org, mask) if m]
print("[Train] ", file_train)

# -- plot parameters
plot_parameter_points(parameter_train, parameter_test, output_dir, "vanilla_pinn/parameter_variation.png")

# -- fail with a clear message instead of an IndexError on file_*[0] below
if not file_train or not file_test:
    raise ValueError(
        f"No train/test files with parameter < {parameter_max} in {dataset_path}; "
        "raise `parameter_max` or sample more files."
    )

# -- load the actual data. Each archive is opened once and closed by the context manager
#    (assumes these are .npz archives with 'x'/'y' grids and a 'u' field — confirm).
dtype = torch.float32
with np.load(f"{dataset_path}/{file_train[0]}") as data:
    X_train = torch.tensor(data['x'], dtype=dtype).to(device)
    Y_train = torch.tensor(data['y'], dtype=dtype).to(device)
with np.load(f"{dataset_path}/{file_test[0]}") as data:
    X_test = torch.tensor(data['x'], dtype=dtype).to(device)
    Y_test = torch.tensor(data['y'], dtype=dtype).to(device)
# -- reference field 'u' for every retained test parameter, in file_test order
solutions = []
for fname in file_test:
    with np.load(f"{dataset_path}/{fname}") as data:
        solutions.append(torch.tensor(data['u'], dtype=dtype).to(device))

### Dataset Test

In [None]:
# -- visual sanity check: filled contour plot of one reference solution from the test set
idx = 4
field = solutions[idx].detach().cpu().numpy()
plt.title(str(file_test[idx]))
plt.contourf(field, levels=50, cmap='rainbow')
plt.colorbar()
plt.show()
plt.close()

## Train and evaluate

In [None]:
# Fix the RNG before anything below so the model initialisation is reproducible.
torch.manual_seed(3)

# ------------ Inputs -------------------------------
# Grid from the first test file; X_test / Y_test were already moved to `device` when loaded.
X, Y = X_test.to(device), Y_test.to(device)
solution_shape = X.shape
# Flatten each coordinate grid into a column vector.
x_flat, y_flat = X.reshape(-1, 1), Y.reshape(-1, 1)
inputs = [x_flat.to(device), y_flat.to(device)]
# Test parameters (sampled with N=30, then filtered by the cut-off when loading the dataset).
parameters = parameter_test.to(device)
source_function = get_source_function("desai_song")

# ---------- Model Init -----------------------------
model = PINN(n_layers=3, n_neurons=64).to(device)

# ---------- Optimization Parameters ----------------
lr, epochs = 1e-3, 10000

In [None]:
# Fine-tune the PINN over all retained test parameters against the reference `solutions`;
# `output_dir` is handed to fine_tune (presumably where it writes its artefacts — confirm).
model, metrics = fine_tune(model, inputs, source_function, loss_fn, parameters, epochs, solutions, solution_shape, lr=lr, output_dir="vanilla_pinn/", print_interval=100)
# np.save always writes the .npy format and appends ".npy" to any other suffix, so the
# previous name "metrics.npz" actually produced "metrics.npz.npy". Use a matching name.
# If `metrics` is a dict it is stored as a 0-d object array; reload it with
# np.load("vanilla_pinn/metrics.npy", allow_pickle=True).item().
np.save("vanilla_pinn/metrics.npy", metrics)
