# PINN Initialization and Finetuning

This notebook demonstrates **how to initialize and fine-tune Physics-Informed Neural Networks (PINNs)** using the [HyPINO multi-physics neural operator](https://arxiv.org/abs/2509.05117).

HyPINO maps a given PDE specification (its coefficients, source term, and boundary conditions) to a **set of PINN weights** that already approximate the corresponding solution field.
These weights can then be used to initialize a PINN for **task-specific adaptation or fine-tuning**.

### Workflow

1. **Generate an initial PINN** with weights $\theta^\star = \Phi(L, f, g, h)$ predicted by HyPINO,
   where $(L, f, g, h)$ denote the PDE operator, the source term, and the Dirichlet and Neumann boundary data.

2. **Evaluate the initialized model** to obtain the zero-shot solution $u_{\theta^\star}(x)$.

3. **Fine-tune the PINN** by minimizing the weighted loss $\mathcal{L}_{\text{PINN}} = \lambda_R \mathcal{L}_R + \lambda_D \mathcal{L}_D + \lambda_N \mathcal{L}_N$, which combines the PDE residual ($\mathcal{L}_R$), Dirichlet ($\mathcal{L}_D$), and Neumann ($\mathcal{L}_N$) terms.
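As a minimal sketch of step 3, assuming each term is the mean squared error over its points (a common PINN choice, not confirmed here for this repository) and reading $g$ and $h$ as the Dirichlet and Neumann data, the weighted loss can be computed from pointwise errors like this. The helper `pinn_loss` is hypothetical.

```python
import numpy as np

def pinn_loss(residual, dirichlet_err, neumann_err, lam_r=1.0, lam_d=1.0, lam_n=1.0):
    """Hypothetical helper: lam_r * L_R + lam_d * L_D + lam_n * L_N.

    Each argument is an array of pointwise errors; every term is taken as the
    mean squared error over its points (an assumption for this sketch).
    """
    loss_r = np.mean(np.square(residual))       # PDE residual L u - f at interior points
    loss_d = np.mean(np.square(dirichlet_err))  # u - g at Dirichlet boundary points (assumed)
    loss_n = np.mean(np.square(neumann_err))    # du/dn - h at Neumann boundary points (assumed)
    return lam_r * loss_r + lam_d * loss_d + lam_n * loss_n

# Residual MSE 1.0, Dirichlet MSE 0.25, Neumann MSE 0.0
loss = pinn_loss(np.array([1.0, -1.0]), np.array([0.5]), np.array([0.0]),
                 lam_r=0.01, lam_d=10, lam_n=5)
print(loss)  # 0.01 * 1.0 + 10 * 0.25 + 5 * 0.0 = 2.51
```

Weights like these let the boundary terms dominate early training even when the residual is large.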

In [None]:
import sys, os

project_root = os.path.abspath("..")
sys.path.append(project_root)

In [None]:
import os
import torch
import numpy as np

from src.models import HyPINO
from src.data.utils import to_tensor, plot_grids, encode_pde_str

In [None]:
if torch.cuda.is_available():
    device = f'cuda:{torch.cuda.current_device()}'
else:
    device = 'cpu'

## Load model weights

In [None]:
model = HyPINO.load_from_safetensors('../models/hypino.safetensors').to(device).eval()

## Load example PDE
See the `02_inference.ipynb` notebook for examples of how to load or create other PDEs.

In [None]:
inputs_path = '../assets/wave/arrays'

# grid-based PDE inputs
dirichlet_mask = np.load(os.path.join(inputs_path, 'dirichlet_mask.npy'))
dirichlet_conditions = np.load(os.path.join(inputs_path, 'dirichlet_conditions.npy'))
neumann_mask = np.load(os.path.join(inputs_path, 'neumann_mask.npy'))
neumann_conditions = np.load(os.path.join(inputs_path, 'neumann_conditions.npy'))
source_function = np.load(os.path.join(inputs_path, 'source_function.npy'))

# domain mask (passed to the model below)
domain_mask = np.load(os.path.join(inputs_path, 'domain_mask.npy'))

# Neumann normals are needed to compute correct Neumann boundary losses;
# fall back to zero normals on the same grid as the masks if none are provided
normals_file = os.path.join(inputs_path, 'neumann_normals.npy')
if os.path.exists(normals_file):
    neumann_normals = np.load(normals_file)
else:
    neumann_normals = np.zeros((2, *dirichlet_mask.shape))

# reference solution (optional; used here only for plotting)
reference_solution = np.load(os.path.join(inputs_path, 'reference_solution.npy'))

plot_grids([dirichlet_mask, dirichlet_conditions, neumann_mask, neumann_conditions, 
            neumann_normals[0], neumann_normals[1], source_function, domain_mask, reference_solution], 
           titles=['Dirichlet mask', 'Dirichlet boundary conditions', 'Neumann mask', 
                   'Neumann boundary conditions', 'Neumann normals x', 'Neumann normals y', 
                   'Source function', 'Domain mask', 'Reference solution'])

Create the grid-based inputs to HyPINO:

In [None]:
# stack the five grid inputs into a single (5, H, W) array
mat_inputs = np.stack([dirichlet_mask, neumann_mask,
                       dirichlet_conditions, neumann_conditions,
                       source_function], axis=0)
mat_inputs_tensor = to_tensor(mat_inputs)
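For orientation, here is a sketch of a stacked input of the same shape for a hypothetical unit square with Dirichlet conditions on all four edges and no Neumann boundary. The 224 × 224 grid matches the zero-normals fallback used above; the mask values, zero boundary data, and constant source are illustrative assumptions, not the encoding of the wave example. The `toy_` prefix keeps the notebook's real inputs from being overwritten.

```python
import numpy as np

n = 224  # grid resolution, matching the zero-normals fallback above

# Dirichlet on the whole outer edge of the square, no Neumann boundary (hypothetical)
toy_dirichlet_mask = np.zeros((n, n))
toy_dirichlet_mask[[0, -1], :] = 1.0
toy_dirichlet_mask[:, [0, -1]] = 1.0
toy_neumann_mask = np.zeros((n, n))

toy_dirichlet_conditions = np.zeros((n, n))  # u = 0 on the boundary (hypothetical)
toy_neumann_conditions = np.zeros((n, n))    # unused: no Neumann boundary
toy_source_function = np.ones((n, n))        # constant source f = 1 (hypothetical)

# Same channel order as the notebook cell above
toy_inputs = np.stack([toy_dirichlet_mask, toy_neumann_mask,
                       toy_dirichlet_conditions, toy_neumann_conditions,
                       toy_source_function], axis=0)
print(toy_inputs.shape)                 # (5, 224, 224)
print(int(toy_dirichlet_mask.sum()))    # 4 * 224 - 4 = 892 boundary points
```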

Create the vector of coefficients for HyPINO:

In [None]:
# operator 0.5 * u_yy - 2 * u_xx; reading y as time, this is a wave equation with speed c = 2
diff_operator = '0.5 * uyy - 2 * uxx'
pde_coeffs = encode_pde_str(diff_operator)
pde_coeffs_tensor = to_tensor(list(pde_coeffs.values()))
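`encode_pde_str` returns a dictionary whose values are turned into the coefficient tensor above. The sketch below is a hypothetical stand-in (`parse_linear_operator` is not part of the repository, and its output format is an assumption) that shows the idea: each derivative term of a linear operator string is mapped to its signed coefficient. HyPINO's actual encoding may use a fixed set of terms in a fixed order.

```python
import re

def parse_linear_operator(expr):
    """Hypothetical helper: map '0.5 * uyy - 2 * uxx' to {'uyy': 0.5, 'uxx': -2.0}.

    Handles sums of '<coef> * <term>' or bare '<term>' with plain decimal
    coefficients; scientific notation and products of terms are not supported.
    """
    coeffs = {}
    # Remove spaces and make sure every term carries an explicit sign
    expr = expr.replace(' ', '')
    if not expr.startswith(('+', '-')):
        expr = '+' + expr
    for sign, coef, term in re.findall(r'([+-])(?:([0-9.]+)\*)?([a-z]+)', expr):
        value = float(coef) if coef else 1.0
        coeffs[term] = coeffs.get(term, 0.0) + (value if sign == '+' else -value)
    return coeffs

print(parse_linear_operator('0.5 * uyy - 2 * uxx'))  # {'uyy': 0.5, 'uxx': -2.0}
```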

Prepare inputs as dictionary:

In [None]:
pde = {
    'pde_coeffs': pde_coeffs_tensor.to(device),
    'mat_inputs': mat_inputs_tensor.to(device),
    'neu_normals': to_tensor(neumann_normals).to(device),
    'pde_str': diff_operator,
    'domain_mask': to_tensor(domain_mask).to(device)
}

## Generate PINN
Generate the target PINN for the given PDE. Optionally, build an ensemble of PINNs: in each iteration, HyPINO generates a new PINN that corrects the residual of the current ensemble, and this PINN is added to the ensemble. See the `03_iterative_refinement.ipynb` notebook for more examples and explanations.

Set `num_iter=0` to skip iterative refinement and use just the first predicted PINN (ensemble of 1 expert).
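The refinement idea can be illustrated with a linear analogy. In the sketch below, an inexact diagonal (Jacobi-type) solver stands in for HyPINO, purely for illustration: each round solves approximately for the current residual `b - A @ x`, and the correction is added to the running sum, much as each new PINN is added to the ensemble. Because the problem is linear, the summed corrections converge towards the exact solution.

```python
import numpy as np

rng = np.random.default_rng(0)

# Diagonally dominant linear system A x = b (stand-in for a linear PDE L u = f)
n = 50
A = np.diag(np.full(n, 4.0)) + rng.normal(scale=0.05, size=(n, n))
b = rng.normal(size=n)

def approx_solve(rhs):
    """Inexact solver that only inverts the diagonal of A (stand-in for HyPINO)."""
    return rhs / np.diag(A)

x = approx_solve(b)                       # first expert only (like num_iter=0)
residual_norms = [np.linalg.norm(b - A @ x)]
for _ in range(5):                        # five refinement rounds (like num_iter=5)
    correction = approx_solve(b - A @ x)  # solve approximately for the current residual
    x = x + correction                    # running sum of all corrections
    residual_norms.append(np.linalg.norm(b - A @ x))

print([f"{r:.1e}" for r in residual_norms])
```

Each round shrinks the residual by roughly the same factor, which is why a few iterations already help and additional ones give diminishing returns.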

In [None]:
pinn_ensemble = model.iterative_refinement(pde, num_iter=5)

## Finetuning
We provide the `finetuning` function for convenience. It takes the PDE and the PINN (or ensemble) and trains it on the PDE residual together with the errors on the Dirichlet and Neumann boundaries. Internally, it creates a grid of collocation points and, in each iteration, computes the loss on a random subset of `num_collocation_points` of them. Set this value as high as GPU memory allows without running into out-of-memory (OOM) errors.
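A minimal sketch of the per-iteration subsampling described above, assuming a uniform 224 × 224 grid on $[0, 1]^2$ and sampling without replacement; the grid extent and sampling scheme used inside `finetuning` may differ, and `sample_batch` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Full grid of candidate collocation points (224 x 224 on [0, 1]^2, assumed)
n = 224
xs = np.linspace(0.0, 1.0, n)
X, Y = np.meshgrid(xs, xs, indexing='xy')
grid_points = np.stack([X.ravel(), Y.ravel()], axis=1)  # shape (n * n, 2)

num_collocation_points = 4096

def sample_batch():
    """Draw a fresh random subset of grid points for one optimization step."""
    idx = rng.choice(len(grid_points), size=num_collocation_points, replace=False)
    return grid_points[idx]

batch = sample_batch()
print(grid_points.shape, batch.shape)  # (50176, 2) (4096, 2)
```

Resampling every step lets the optimizer see the whole grid over time while keeping the memory cost of each step fixed by `num_collocation_points`.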

Training runs for `num_adam_iterations` iterations with the Adam optimizer and then switches to L-BFGS for `num_lbfgs_iterations` iterations.

Further optional arguments include `loss_weights` (e.g. `{'F': 0.1, 'D': 10, 'N': 5}`) to set the weight of each loss term (residual `F`, Dirichlet `D`, Neumann `N`), `eval_every` to set the evaluation interval, and `plot_path` to optionally save the plots.

Note that this function creates a grid of collocation points and uses the Dirichlet and Neumann masks from above to identify which of them lie on a boundary. Because the boundaries are curves on a two-dimensional grid, this yields far fewer boundary points than interior points. The imbalance can be a limitation, since PINN training often needs a higher density of boundary points for the boundary conditions to be well approximated. The `boundary_oversample` argument rebalances the number of collocation points on the boundaries and inside the domain: pass either the string `'balanced'` or a float between 0 and 1 denoting the proportion.
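The sketch below shows one way such oversampling can work, assuming `'balanced'` means an even split between boundary and interior points and a float means the fraction of each batch drawn from the boundary. Both readings, and the helper `sample_with_oversampling`, are assumptions for illustration rather than the exact behavior of `finetuning`.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_with_oversampling(boundary_mask, num_points, boundary_fraction):
    """Hypothetical helper: flat grid indices with a fixed share from the boundary.

    boundary_mask: boolean (H, W) array, True where a point lies on a boundary.
    boundary_fraction: 'balanced' (taken as 0.5 here) or a float in [0, 1].
    """
    if boundary_fraction == 'balanced':
        boundary_fraction = 0.5
    flat = boundary_mask.ravel()
    boundary_idx = np.flatnonzero(flat)
    interior_idx = np.flatnonzero(~flat)
    n_boundary = int(round(boundary_fraction * num_points))
    n_interior = num_points - n_boundary
    # Boundary points are scarce, so draw them with replacement
    from_boundary = rng.choice(boundary_idx, size=n_boundary, replace=True)
    from_interior = rng.choice(interior_idx, size=n_interior, replace=False)
    return np.concatenate([from_boundary, from_interior])

# Square domain on a 224 x 224 grid: only 892 of 50176 points lie on the boundary
mask = np.zeros((224, 224), dtype=bool)
mask[[0, -1], :] = True
mask[:, [0, -1]] = True

idx = sample_with_oversampling(mask, 4096, 'balanced')
share = mask.ravel()[idx].mean()
print(f"uniform share: {mask.mean():.3f}, oversampled share: {share:.3f}")
# uniform share: 0.018, oversampled share: 0.500
```

Uniform sampling would put under 2% of each batch on the boundary here; the balanced split raises that to half.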

In [None]:
from src.models.utils import finetuning

hist = finetuning(pde=pde, net=pinn_ensemble, num_adam_iterations=4500, num_lbfgs_iterations=500, 
                  num_collocation_points=4096, eval_every=10, loss_weights={'F': 0.01, 'D': 10, 'N': 5},
                  boundary_oversample='balanced')

In [None]:
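# saves the full pickled module; recent PyTorch versions need torch.load(..., weights_only=False) to load it back
torch.save(pinn_ensemble, 'finetuned_wave_ensemble.pth')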