# 2D Helmholtz PINN with SIREN (Documented & Cleaned)

This notebook implements a simple **physics-informed neural network (PINN)** to solve the 2D scalar Helmholtz equation for the out-of-plane field \( E_z(x, y) \).  
We parameterize the complex field as two real-valued outputs, \(\Re(E_z)\) and \(\Im(E_z)\), from a **SIREN** network (sine-activated MLP).

**PDE:**  
\[ -\nabla^2 E_z(x,y) - \varepsilon(x,y)\,\omega^2\,E_z(x,y) = i\,\omega\,J_z(x,y). \]

- The source \(J_z\) is approximated as a **single-point** injection at the grid point closest to `source_position`.
- Relative permittivity \(\varepsilon\) can be uniform (free space) or have a circular dielectric inclusion.
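- We enforce a simple **Sommerfeld-like absorbing condition** on a thin strip at the domain boundaries (sign as implemented in `square_bc_loss`, i.e. the outgoing-wave choice for the \(+i\,\omega\,J_z\) source convention of the PDE above):
  \[ \frac{\partial E_z}{\partial n} - i\,\omega\,E_z = 0. \]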

**What’s in here:**  
- A documented SIREN implementation (`SineLayer`, `SIREN`).  
- A documented solver class (`HelmholtzSolver`) with:
  - grid creation
  - Laplacian via autodiff
  - PDE residual and boundary loss
  - training loop with optional LR scheduler
  - prediction and visualization
- Four example experiments (free-space, single dielectric, and two alternative configurations).
- Minor fixes/cleanups vs. the original:
  - Clarified parameters & defaults.
  - Fixed a typo: using `params4` (not `params2`) when constructing `solver4`.
  - For clarity, `solver4` demonstrates **one** dielectric (the class supports a single circle).  
    Extending to multiple dielectrics would require a small change to `get_permittivity` and plotting logic.

> Tip for your report: briefly explain why SIREN is a good fit (Fourier-like representation), how the Sommerfeld strip is used as a soft ABC, and how you validate (loss curves + field visuals).


In [None]:

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from tqdm.auto import trange

# ------------------------------
# Configuration parameters
# ------------------------------
# `scale` rescales the problem: every length below is multiplied by it and
# `omega` is divided by it.
scale = 1.0
params = {
    # --- Domain ---
    'domain_size': 5.0 * scale,     # Side length of the square domain, centered on the origin (arbitrary units)
    'grid_points': 128,             # Samples per axis (grid_points^2 collocation points in total)

    # --- Physics ---
    'omega': 20.0 / scale,          # Angular frequency in the Helmholtz equation
    'source_position': (-0.5 * scale, -1.5 * scale),  # Requested point-source location; snapped to a collocation point

    # --- Dielectric (single circular inclusion) ---
    'has_dielectric': False,        # When False the three dielectric_* entries are stored but not applied
    'dielectric_center': (1.0 * scale, 1.5 * scale),
    'dielectric_radius': 0.3 * scale,
    'dielectric_eps': 2.0,          # Relative permittivity inside the circle; outside is 1.0

    # --- Network (SIREN) ---
    'hidden_features': 256,         # Width of every Sine layer
    'hidden_layers': 3,             # Sine layers added AFTER the first Sine layer (the final linear layer is extra)
    'omega_0': 30.0,                # SIREN frequency scaling (first and hidden layers)

    # --- Training ---
    'num_epochs': 600,              # Full passes over the collocation grid
    'learning_rate': 2e-5,          # Initial Adam learning rate
    'batch_size': 512,              # Collocation points per optimizer step
    'print_every': 50,              # Print a log line every this many epochs

    # --- Scheduler (ReduceLROnPlateau on the epoch-mean total loss) ---
    'use_scheduler': True,
    'scheduler_patience': 100,      # `patience` argument of the scheduler
    'scheduler_factor': 0.5,        # `factor` argument (multiplicative LR decay)
    'scheduler_min_lr': 1e-9,       # `min_lr` argument (lower bound on the LR)
}


class SineLayer(nn.Module):
    """
    One SIREN building block: an affine map followed by ``sin(omega_0 * .)``.

    Args:
        in_features (int): size of the incoming feature vector
        out_features (int): size of the produced feature vector
        omega_0 (float): multiplier applied to the pre-activation before the sine
        is_first (bool): when True, use the first-layer weight range

    Notes:
        Only the weight matrix is re-drawn; the bias keeps the ``nn.Linear``
        default. Weight ranges:
        - first layer:  U(-1/in_features, 1/in_features)
        - other layers: U(-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0)
    """
    def __init__(self, in_features, out_features, omega_0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.linear = nn.Linear(in_features, out_features)

        # Choose the half-width of the uniform range, then draw the weights once.
        if is_first:
            limit = 1. / in_features
        else:
            limit = np.sqrt(6 / in_features) / omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, x):
        # Affine map first, then the frequency-scaled sine.
        pre_activation = self.linear(x)
        return torch.sin(self.omega_0 * pre_activation)


class SIREN(nn.Module):
    """
    Sine-activated MLP: one first ``SineLayer``, ``hidden_layers`` further
    ``SineLayer`` blocks, and a plain linear read-out.

    In this file it maps 2D coordinates (x, y) to (Re(Ez), Im(Ez)).

    Args:
        in_features (int): input dimensionality (2 for x,y)
        hidden_features (int): width of every Sine layer
        hidden_layers (int): number of Sine layers stacked after the first one
        out_features (int): output dimensionality (2 for real, imag)
        omega_0 (float): frequency scaling passed to every Sine layer
    """
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, omega_0):
        super().__init__()

        # Sine stack: the first layer uses the special init, the rest the generic one.
        sine_stack = [SineLayer(in_features, hidden_features, omega_0, is_first=True)]
        sine_stack.extend(
            SineLayer(hidden_features, hidden_features, omega_0, is_first=False)
            for _ in range(hidden_layers)
        )

        # Linear read-out (no sine); its weights use the same range as hidden Sine layers.
        head = nn.Linear(hidden_features, out_features)
        limit = np.sqrt(6 / hidden_features) / omega_0
        with torch.no_grad():
            head.weight.uniform_(-limit, limit)

        self.net = nn.Sequential(*sine_stack, head)

    def forward(self, x):
        return self.net(x)


class HelmholtzSolver:
    """
    Physics-Informed solver for the 2D scalar Helmholtz equation using a SIREN backbone.

    Solves for E_z such that:
        -ΔE_z - ε(x,y) * ω^2 * E_z = i * ω * J_z

    The field is represented by two outputs of the network: (Re(E_z), Im(E_z)).
    The PDE residual and boundary condition terms form the training loss.

    Args:
        param (dict): configuration dictionary (see `params` above)

    Attributes:
        model (nn.Module): SIREN model
        device (torch.device): CPU or CUDA
        grid_points_flat (Tensor): (N,2) input coordinates over the domain
    """
    def __init__(self, param):
        """
        Read the configuration, build the SIREN with its optimizer (and
        optional LR scheduler), and lay out the collocation grid.

        Args:
            param (dict): configuration dictionary; the keys read here are
                domain_size, grid_points, omega, hidden_features, hidden_layers,
                omega_0, source_position, has_dielectric, dielectric_center,
                dielectric_radius, dielectric_eps, learning_rate, use_scheduler,
                batch_size, print_every and, only when use_scheduler is true,
                scheduler_factor, scheduler_patience, scheduler_min_lr.
        """
        # Domain and physics scalars
        self.domain_size = float(param['domain_size'])
        self.grid_points = int(param['grid_points'])
        self.omega = float(param['omega'])

        # Use the GPU when one is visible, otherwise stay on the CPU
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device('cuda' if cuda_ok else 'cpu')

        # Network: (x, y) -> (Re, Im)
        network = SIREN(
            in_features=2,
            hidden_features=param['hidden_features'],
            hidden_layers=param['hidden_layers'],
            out_features=2,
            omega_0=param['omega_0'],
        )
        self.model = network.to(self.device)

        # Point source
        self.source_position = tuple(param['source_position'])

        # Single circular dielectric (stored even when disabled)
        self.has_dielectric = bool(param['has_dielectric'])
        self.dielectric_center = torch.tensor(
            param['dielectric_center'], dtype=torch.float32, device=self.device
        )
        self.dielectric_radius = float(param['dielectric_radius'])
        self.dielectric_eps = float(param['dielectric_eps'])

        # Adam, optionally wrapped by a plateau scheduler
        self.learning_rate = float(param['learning_rate'])
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        self.use_scheduler = bool(param['use_scheduler'])
        if self.use_scheduler:
            # self.scheduler exists only when the scheduler is enabled
            plateau_options = {
                'factor': param['scheduler_factor'],
                'patience': param['scheduler_patience'],
                'min_lr': param['scheduler_min_lr'],
            }
            self.scheduler = lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', **plateau_options
            )

        self.batch_size = int(param['batch_size'])
        self.print_every = int(param['print_every'])

        # Collocation points over the whole domain
        self.create_grid()

    def create_grid(self):
        """
        Create a uniform square grid of size (grid_points x grid_points) spanning
        [-domain_size/2, +domain_size/2] in both x and y.
        Stores:
            - self.xx, self.yy: meshgrids (numpy)
            - self.grid_points_flat: Tensor of shape (N, 2) on device
        """
        x = np.linspace(-self.domain_size / 2, self.domain_size / 2, self.grid_points)
        y = np.linspace(-self.domain_size / 2, self.domain_size / 2, self.grid_points)
        self.xx, self.yy = np.meshgrid(x, y)

        xy = np.stack([self.xx.flatten(), self.yy.flatten()], axis=1).astype(np.float32)
        self.grid_points_flat = torch.tensor(xy, device=self.device)

    def add_dielectric_circle(self, center, radius, eps):
        """
        Enable a single circular dielectric region.

        Args:
            center (tuple): (x0, y0)
            radius (float): circle radius
            eps (float): relative permittivity inside the circle
        """
        self.has_dielectric = True
        self.dielectric_center = torch.tensor(center, dtype=torch.float32, device=self.device)
        self.dielectric_radius = float(radius)
        self.dielectric_eps = float(eps)

    def get_permittivity(self, x):
        """
        Piecewise-constant ε(x,y): 1.0 everywhere, and `dielectric_eps` inside the circle if enabled.

        Args:
            x (Tensor): shape (N,2) coordinates
        Returns:
            Tensor: shape (N,1) of relative permittivity values
        """
        eps = torch.ones(x.shape[0], 1, device=self.device)
        if self.has_dielectric:
            dist = torch.linalg.norm(x - self.dielectric_center, dim=1, keepdim=True)
            eps = torch.where(dist < self.dielectric_radius,
                              self.dielectric_eps * torch.ones_like(eps),
                              eps)
        return eps

    def get_source(self, x):
        """
        Approximate a point source Jz at the nearest grid point to `source_position`.
        Returns:
            Tensor: shape (N, 2) -> (Jz_real, Jz_imag). We inject unit amplitude into the real part.
        """
        N = x.shape[0]
        jz_real = torch.zeros(N, 1, device=self.device)
        jz_imag = torch.zeros(N, 1, device=self.device)

        src = torch.tensor(self.source_position, dtype=torch.float32, device=self.device)
        # Index of the collocation point closest to the desired source position
        closest_idx = torch.argmin(torch.sum((x - src)**2, dim=1))
        jz_real[closest_idx] = 1.0  # real source; imag remains zero

        return torch.hstack((jz_real, jz_imag))

    def compute_laplacian(self, x):
        """
        Compute ∇² of (Re, Im) outputs via autograd by summing second derivatives.
        Args:
            x (Tensor): shape (N,2)
        Returns:
            Tensor: shape (N,2) containing (laplacian_real, laplacian_imag)
        """
        x = x.requires_grad_(True)
        y_pred = self.model(x)                   # (N,2)
        y_real = y_pred[:, 0:1]
        y_imag = y_pred[:, 1:2]

        # First gradients
        grad_y_real = torch.autograd.grad(y_real, x, grad_outputs=torch.ones_like(y_real), create_graph=True)[0]
        grad_y_imag = torch.autograd.grad(y_imag, x, grad_outputs=torch.ones_like(y_imag), create_graph=True)[0]

        # Sum of second partials for Laplacian
        laplacian_real = 0.0
        laplacian_imag = 0.0
        for i in range(x.shape[1]):
            # ∂/∂x_i of grad component i
            g2_real = torch.autograd.grad(grad_y_real[:, i:i+1], x,
                                          grad_outputs=torch.ones_like(grad_y_real[:, i:i+1]), create_graph=True)[0][:, i:i+1]
            g2_imag = torch.autograd.grad(grad_y_imag[:, i:i+1], x,
                                          grad_outputs=torch.ones_like(grad_y_imag[:, i:i+1]), create_graph=True)[0][:, i:i+1]
            laplacian_real = laplacian_real + g2_real
            laplacian_imag = laplacian_imag + g2_imag

        return torch.hstack((laplacian_real, laplacian_imag))

    def helmholtz_residual(self, x):
        """
        Compute the PDE residuals for (real, imag) parts:
            R_real = -ΔRe(Ez) - ε ω^2 Re(Ez) + ω * Im(Jz)
            R_imag = -ΔIm(Ez) - ε ω^2 Im(Ez) - ω * Re(Jz)
        Returns:
            Tensor: shape (N,2) residuals
        """
        y_pred = self.model(x)
        y_real, y_imag = y_pred[:, 0:1], y_pred[:, 1:2]

        laplacian = self.compute_laplacian(x)
        laplacian_real, laplacian_imag = laplacian[:, 0:1], laplacian[:, 1:2]

        eps = self.get_permittivity(x)
        jz = self.get_source(x)
        jz_real, jz_imag = jz[:, 0:1], jz[:, 1:2]

        residual_real = -laplacian_real - eps * (self.omega ** 2) * y_real + self.omega * jz_imag
        residual_imag = -laplacian_imag - eps * (self.omega ** 2) * y_imag - self.omega * jz_real
        return torch.cat([residual_real, residual_imag], dim=1)

    def square_bc_loss(self, x):
        """
        Sommerfeld-like absorbing boundary condition on a thin strip near the edges:
            ∂E/∂n + i ω E = 0  (applied in L2 sense on left/right/top/bottom strips)

        Args:
            x (Tensor): collocation points (N,2)
        Returns:
            Tensor: scalar loss
        """
        # Thickness of the boundary strip (10% of half-domain)
        boundary_width = 0.1 * (self.domain_size / 2.0)

        # Distances to each boundary
        dist_left   = torch.abs(x[:, 0:1] + self.domain_size / 2)
        dist_right  = torch.abs(x[:, 0:1] - self.domain_size / 2)
        dist_bottom = torch.abs(x[:, 1:2] + self.domain_size / 2)
        dist_top    = torch.abs(x[:, 1:2] - self.domain_size / 2)

        # Binary masks for points inside the strip
        left_mask   = (dist_left   < boundary_width).float()
        right_mask  = (dist_right  < boundary_width).float()
        bottom_mask = (dist_bottom < boundary_width).float()
        top_mask    = (dist_top    < boundary_width).float()

        # Network outputs and gradients
        y_pred = self.model(x)
        y_real, y_imag = y_pred[:, 0:1], y_pred[:, 1:2]

        x = x.requires_grad_(True)
        grad_y_real = torch.autograd.grad(y_real, x, grad_outputs=torch.ones_like(y_real), create_graph=True)[0]
        grad_y_imag = torch.autograd.grad(y_imag, x, grad_outputs=torch.ones_like(y_imag), create_graph=True)[0]

        # Helpers to compute (∂/∂n Re, ∂/∂n Im) with outward normals per side.
        def side_loss(mask, nx, ny):
            normal = torch.cat([nx, ny], dim=1)  # (N,2)
            dn_real = torch.sum(grad_y_real * normal, dim=1, keepdim=True)
            dn_imag = torch.sum(grad_y_imag * normal, dim=1, keepdim=True)
            # ‖(∂Re/∂n + ω Im)^2 + (∂Im/∂n - ω Re)^2‖ weighted by mask
            return mask * ((dn_real + self.omega * y_imag) ** 2 + (dn_imag - self.omega * y_real) ** 2)

        # Left: normal = [-1, 0]
        loss_left = side_loss(left_mask,  -left_mask, torch.zeros_like(left_mask))
        # Right: normal = [ 1, 0]
        loss_right = side_loss(right_mask,  right_mask, torch.zeros_like(right_mask))
        # Bottom: normal = [ 0,-1]
        loss_bottom = side_loss(bottom_mask, torch.zeros_like(bottom_mask), -bottom_mask)
        # Top: normal = [ 0, 1]
        loss_top = side_loss(top_mask,    torch.zeros_like(top_mask),      top_mask)

        bc_loss = loss_left + loss_right + loss_bottom + loss_top
        return bc_loss.mean()

    def train(self, num_epochs, print_every):
        """
        Optimize the network on (mean squared PDE residual + boundary loss),
        sweeping the whole collocation grid in shuffled mini-batches each epoch.

        Per-epoch values are batch-size-weighted means. After training the
        histories are kept on the instance as `losses`, `physical_losses`,
        `bc_losses` and `lr_history` for `plot_losses()`.

        Args:
            num_epochs (int): number of epochs
            print_every (int): print a log line every this many epochs
        Returns:
            list of floats: per-epoch total losses
        """
        self.model.train()
        total_hist, phys_hist, bc_hist, lr_hist = [], [], [], []

        points = self.grid_points_flat
        n_points = points.shape[0]
        batch_size = min(self.batch_size, n_points)
        num_batches = (n_points + batch_size - 1) // batch_size  # ceil division

        for epoch in trange(num_epochs, desc="Training PINN"):
            sum_total = 0.0
            sum_phys = 0.0
            sum_bc = 0.0

            # Fresh random order of the collocation points for this epoch
            order = torch.randperm(n_points, device=self.device)
            shuffled = points[order]

            for b in range(num_batches):
                self.optimizer.zero_grad()
                lo = b * batch_size
                hi = min((b + 1) * batch_size, n_points)
                batch = shuffled[lo:hi]

                physics_loss = torch.mean(self.helmholtz_residual(batch) ** 2)
                bc_loss = self.square_bc_loss(batch)

                loss = physics_loss + bc_loss
                loss.backward()
                self.optimizer.step()

                # The last batch may be shorter, so weight by the actual batch length
                count = hi - lo
                sum_total += loss.item() * count
                sum_phys += physics_loss.item() * count
                sum_bc += bc_loss.item() * count

            # Turn the weighted sums into per-point means
            denom = float(n_points)
            sum_total /= denom
            sum_phys /= denom
            sum_bc /= denom

            if self.use_scheduler:
                self.scheduler.step(sum_total)

            current_lr = self.optimizer.param_groups[0]['lr']
            lr_hist.append(current_lr)
            total_hist.append(sum_total)
            phys_hist.append(sum_phys)
            bc_hist.append(sum_bc)

            if (epoch + 1) % print_every != 0:
                continue
            print(f"Epoch {epoch + 1:4d} | loss={sum_total:.4e} | physics={sum_phys:.4e} | bc={sum_bc:.4e}")
            # Only reports a decay that happened on this very epoch
            if len(lr_hist) > 1 and lr_hist[-1] != lr_hist[-2]:
                print(f"  -> LR decayed to {current_lr:.3e}")

        # Keep the histories for plotting
        self.losses = total_hist
        self.physical_losses = phys_hist
        self.bc_losses = bc_hist
        self.lr_history = lr_hist
        return total_hist
    def plot_losses(self):
        """
        Draw the total / physics / boundary loss histories on a log axis and,
        when the scheduler is enabled, overlay the learning rate on a twin axis.

        Prints a hint and returns if `train()` has not stored any history yet.
        """
        if not hasattr(self, 'losses'):
            print("Train first, then plot_losses().")
            return

        fig, loss_ax = plt.subplots(figsize=(10, 6))
        epochs = np.arange(1, len(self.losses) + 1)

        # (values, legend label, line width) for each curve, in drawing order
        curves = (
            (self.losses, "Total Loss", 2),
            (self.physical_losses, "Physics Loss", 1.5),
            (self.bc_losses, "Boundary Loss", 1.5),
        )
        for values, label, width in curves:
            loss_ax.plot(epochs, values, label=label, lw=width)

        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.set_title("Loss Curves")
        loss_ax.set_yscale('log')
        loss_ax.grid(True, alpha=0.3)
        loss_ax.legend(loc="upper right")

        # The LR curve is shown only with a non-empty history and an active scheduler
        show_lr = bool(getattr(self, 'lr_history', None)) and self.use_scheduler
        if show_lr:
            lr_ax = loss_ax.twinx()
            lr_ax.plot(epochs, self.lr_history, label="Learning Rate", linestyle="--")
            lr_ax.set_ylabel("Learning Rate")
            lr_ax.tick_params(axis='y')
            lr_ax.legend(loc="lower right")

        plt.tight_layout()
        plt.show()

    def predict(self):
        """Run the forward model on the full grid and cache Re/Im/|E| for plotting."""
        self.model.eval()
        with torch.no_grad():
            out = self.model(self.grid_points_flat)  # (N,2)
            real = out[:, 0].reshape(self.grid_points, self.grid_points).cpu().numpy()
            imag = out[:, 1].reshape(self.grid_points, self.grid_points).cpu().numpy()
            mag = np.hypot(real, imag)

        self.pred_real = real
        self.pred_imag = imag
        self.pred_magnitude = mag

    def visualize(self):
        """
        Plot two heatmaps side by side: Imag(Ez) and |Ez|, each with a colorbar.
        When a dielectric is enabled its circle outline is drawn on both panels.

        Prints a hint and returns if `predict()` has not been called yet.
        """
        if not hasattr(self, 'pred_real'):
            print("Call predict() before visualize().")
            return

        half = self.domain_size/2
        extent = [-half, half, -half, half]
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # (data, colormap, title) per panel, left to right
        panels = (
            (self.pred_imag, 'RdBu', 'Imag(E_z)'),
            (self.pred_magnitude, 'viridis', '|E_z|'),
        )
        for ax, (data, cmap, title) in zip(axes, panels):
            image = ax.imshow(data, extent=extent, cmap=cmap, origin='lower')
            ax.set_title(title)
            plt.colorbar(image, ax=ax)

        for ax in axes:
            if self.has_dielectric:
                # Each axes needs its own patch instance
                outline = Circle(self.dielectric_center.cpu().numpy(),
                                 self.dielectric_radius, fill=False, color='black', linewidth=1.5)
                ax.add_patch(outline)
            ax.set_xlabel('x')
            ax.set_ylabel('y')

        plt.tight_layout()
        plt.show()


# ------------------------------
# Experiment 1: free space (the base `params` has the dielectric disabled)
# ------------------------------
solver = HelmholtzSolver(params)
if params['has_dielectric']:
    # Only reached if the base configuration is edited to enable the inclusion
    solver.add_dielectric_circle(
        params['dielectric_center'],
        params['dielectric_radius'],
        params['dielectric_eps'],
    )
solver.train(num_epochs=params['num_epochs'], print_every=params['print_every'])
solver.plot_losses()
solver.predict()
solver.visualize()


# ------------------------------
# Experiment 2: base configuration with its single dielectric switched on
# ------------------------------
params2 = {**params, 'has_dielectric': True}   # shallow copy with one override
solver2 = HelmholtzSolver(params2)
solver2.add_dielectric_circle(
    params2['dielectric_center'],
    params2['dielectric_radius'],
    params2['dielectric_eps'],
)
solver2.train(num_epochs=params2['num_epochs'], print_every=params2['print_every'])
solver2.plot_losses()
solver2.predict()
solver2.visualize()


# ------------------------------
# Experiment 3: large centered dielectric, source moved to (-1.0, -1.5)
# ------------------------------
params3 = dict(
    # domain / physics
    domain_size=5.0 * scale,
    grid_points=128,
    omega=20.0 / scale,
    source_position=(-1.0 * scale, -1.5 * scale),
    # dielectric
    has_dielectric=True,
    dielectric_center=(0.0 * scale, 0.0 * scale),
    dielectric_radius=1.0 * scale,
    dielectric_eps=2.0,
    # network
    hidden_features=256,
    hidden_layers=3,
    omega_0=30.0,
    # training
    num_epochs=600,
    learning_rate=2e-5,
    batch_size=512,
    print_every=50,
    # scheduler
    use_scheduler=True,
    scheduler_patience=100,
    scheduler_factor=0.5,
    scheduler_min_lr=1e-9,
)
solver3 = HelmholtzSolver(params3)
solver3.add_dielectric_circle(
    params3['dielectric_center'],
    params3['dielectric_radius'],
    params3['dielectric_eps'],
)
solver3.train(num_epochs=params3['num_epochs'], print_every=params3['print_every'])
solver3.plot_losses()
solver3.predict()
solver3.visualize()


# ------------------------------
# Experiment 4: another source/dielectric placement (built from params4)
# Note: a *single* dielectric is shown because the class stores one circle.
# ------------------------------
params4 = dict(
    # domain / physics
    domain_size=5.0 * scale,
    grid_points=128,
    omega=20.0 / scale,
    source_position=(-1.0 * scale, -0.5 * scale),
    # dielectric
    has_dielectric=True,
    dielectric_center=(0.5 * scale, 1.0 * scale),
    dielectric_radius=0.5 * scale,
    dielectric_eps=2.0,
    # network
    hidden_features=256,
    hidden_layers=3,
    omega_0=30.0,
    # training
    num_epochs=600,
    learning_rate=2e-5,
    batch_size=512,
    print_every=50,
    # scheduler
    use_scheduler=True,
    scheduler_patience=100,
    scheduler_factor=0.5,
    scheduler_min_lr=1e-9,
)

solver4 = HelmholtzSolver(params4)
solver4.add_dielectric_circle(
    params4['dielectric_center'],
    params4['dielectric_radius'],
    params4['dielectric_eps'],
)
solver4.train(num_epochs=params4['num_epochs'], print_every=params4['print_every'])
solver4.plot_losses()
solver4.predict()
solver4.visualize()
