In [None]:
import os
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
google_drive_path = ''
os.chdir(google_drive_path)
print("Current working directory:", os.getcwd())
!ls

!python -m pip install cripser==0.0.15

In [None]:
import os
import sys
import csv
import time
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Union

from torchsummary import summary
from scipy.interpolate import griddata
from tqdm import tqdm, trange
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tabulate import tabulate
import cripser

# Make packages located in the parent directory of the working directory importable.
# Guarded so that re-running this cell does not append duplicate sys.path entries.
_parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)


In [None]:
from Functions.utils import *
from Functions.Point_Sampling.point_sampler import *
from Functions.logging.BRIDGE_logging import *
from Functions.Computations.PH2D import *
from Functions.Point_Sampling.BS2D import *
from Functions.Training.enforce_Const_2D import *
from Functions.Plotting_functions.training_curves.BRIDGE_curves import *
from Functions.Plotting_functions.BRIDGE_results import *
from Functions.Training.ALM import *
from Functions.Computations.eval2D import *


from File_Paths.file_paths import interfaces_path, mesh_path
from Functions.Point_Sampling.point_sampler import Point_Sampler
from Models.GINN_Models.GINN import GINN


from Test_Cases.Bridge_around_object.BRIDGE_Master_object import BRIDGE_Master_Object

# Shared test-case object referenced by the model and loss classes below,
# built with Normalize=True and Symmetry=False; its interfaces are created once here.
BRIDGE = BRIDGE_Master_Object(Normalize=True, Symmetry=False)
BRIDGE.create_interfaces()

# Preferred compute device: CUDA GPU first, then Apple MPS, otherwise CPU.
device = torch.device(
    "cuda:0" if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)

In [None]:
class SDF_GINN(torch.nn.Module):
    """Thin wrapper that exposes a GINN network as a signed-distance field.

    The wrapped network is constructed for the notebook-global ``BRIDGE`` test case.

    Args:
        hparams_model: Model hyperparameters forwarded to ``GINN`` as ``Model_hyperparameters``.
        hparams_feature_expansion: Forwarded to ``GINN`` as ``feature_expansion``.
    """

    def __init__(self, hparams_model, hparams_feature_expansion):
        super().__init__()
        network = GINN(BRIDGE,
                       feature_expansion=hparams_feature_expansion,
                       Model_hyperparameters=hparams_model)
        self.model = network

    def forward(self, coords):
        """Return the raw network output at ``coords`` (interpreted as the SDF)."""
        return self.model(coords)


class density_GINN(torch.nn.Module):
    """GINN network mapped to a material density in [0, 1] via a shifted sigmoid.

    ``rho = sigmoid(density_alpha * f + log(v / (1 - v)))`` where ``f`` is the raw
    network output and ``v = volume_ratio``; hence ``f == 0`` maps to ``rho == v``.

    NOTE(review): ``Geometry_Net`` uses ``-density_alpha * SDF`` (material where the
    SDF is negative) while this class uses ``+density_alpha``. Confirm which sign
    convention is intended for this model.

    Args:
        hparams_model: Forwarded to ``GINN`` as ``Model_hyperparameters``.
        hparams_feature_expansion: Forwarded to ``GINN`` as ``feature_expansion``.
        density_alpha: Sharpness of the sigmoid.
        volume_ratio: Density assigned to a zero network output; must lie strictly
            inside (0, 1) for the log-odds shift to be finite.
    """

    def __init__(self,
                 hparams_model,
                 hparams_feature_expansion,
                 density_alpha,
                 volume_ratio):
        super().__init__()
        self.volume_ratio = volume_ratio
        self.density_alpha = density_alpha
        self.model = GINN(BRIDGE,
                          feature_expansion=hparams_feature_expansion,
                          Model_hyperparameters=hparams_model)

    def forward(self, coords):
        """Return the density at ``coords`` (same shape as the network output)."""
        field = self.model(coords)
        ratio = torch.as_tensor(self.volume_ratio, dtype=field.dtype, device=field.device)
        # Log-odds of the volume ratio, so a zero field yields rho == volume_ratio.
        log_odds = torch.log(ratio / (1.0 - ratio))
        return torch.sigmoid(self.density_alpha * field + log_odds).clamp(0.0, 1.0)


class Geometry_Net(torch.nn.Module):
    """GINN network returning both a density and the underlying SDF.

    ``rho = clamp(sigmoid(-density_alpha * SDF + log(v / (1 - v))), 1e-6, 1)`` with
    ``v = volume_ratio``. Negative SDF drives ``rho`` towards 1, and ``SDF == 0`` maps
    to ``rho == v``.

    Args:
        hparams_model: Forwarded to ``GINN`` as ``Model_hyperparameters``.
        hparams_feature_expansion: Forwarded to ``GINN`` as ``feature_expansion``.
        density_alpha: Sharpness of the sigmoid.
        volume_ratio: Density at the zero level set; must lie strictly inside (0, 1).

    Returns (from ``forward``):
        ``(rho, SDF)``, both with the network's output shape.
    """

    def __init__(self,
                 hparams_model,
                 hparams_feature_expansion,
                 density_alpha,
                 volume_ratio):
        super().__init__()
        self.volume_ratio = volume_ratio
        self.density_alpha = density_alpha
        self.model = GINN(BRIDGE,
                          feature_expansion=hparams_feature_expansion,
                          Model_hyperparameters=hparams_model)

    def forward(self, coords):
        sdf = self.model(coords)
        ratio = torch.as_tensor(self.volume_ratio, dtype=sdf.dtype, device=sdf.device)
        log_odds = torch.log(ratio / (1.0 - ratio))
        # Lower bound keeps rho strictly positive (presumably to avoid zero stiffness downstream).
        density = torch.sigmoid(-self.density_alpha * sdf + log_odds).clamp(1e-6, 1.0)
        return density, sdf


class PINN(torch.nn.Module):
    """Displacement network whose output is forced to zero on two vertical edges.

    The edges are the x-positions ``BRIDGE.edge_vertices[0, 0]`` and
    ``BRIDGE.edge_vertices[2, 0]``. The network output is multiplied by a tanh
    mollifier that vanishes there.

    Args:
        hparams_model: Forwarded to ``GINN`` as ``Model_hyperparameters``.
        hparams_feature_expansion: Forwarded to ``GINN`` as ``feature_expansion``.
        mollifier_alpha: Steepness of the tanh mollifier.
    """

    def __init__(self,
                 hparams_model,
                 hparams_feature_expansion,
                 mollifier_alpha):
        super().__init__()
        self.mollifier_alpha = mollifier_alpha
        self.model = GINN(BRIDGE,
                          feature_expansion=hparams_feature_expansion,
                          Model_hyperparameters=hparams_model)

    @staticmethod
    def enforce_dirichlet_BC(alpha, u, coords):
        """Scale ``u`` by ``tanh(alpha*|x - x_left|) * tanh(alpha*|x - x_right|)``.

        Args:
            alpha: Mollifier steepness.
            u: Network output; the per-point factor is broadcast as a column, so
               ``u`` is expected to be ``[N, k]``.
            coords: Tensor of points whose column 0 is the x-coordinate.

        Returns:
            ``u`` multiplied by the mollifier, which is zero at both edge x-positions.
        """
        vertices = BRIDGE.edge_vertices
        x_left, x_right = vertices[0, 0], vertices[2, 0]
        if isinstance(vertices, torch.Tensor):
            x_left, x_right = x_left.item(), x_right.item()

        x = coords[:, 0]
        # The factor is built from coords, so it already lives on coords.device.
        mollifier = (torch.tanh(alpha * torch.abs(x - x_left))
                     * torch.tanh(alpha * torch.abs(x - x_right)))
        return u * mollifier.unsqueeze(1)

    def forward(self, coords):
        """Evaluate the network and apply the Dirichlet mollifier."""
        raw = self.model(coords)
        return self.enforce_dirichlet_BC(self.mollifier_alpha, raw, coords)

In [None]:
class GINN_losses(Properties):
    """Loss terms for training a GINN signed-distance-field (SDF) model.

    Sign convention used by the constraint losses below: a point counts as part of
    the shape where the SDF is negative. The interface-thickness region must have
    SDF <= 0; the prohibited region must have SDF >= 0; points outside the design
    envelope must have SDF > 0.

    Args:
        GINN_model: Network mapping coordinates to one SDF value per point; moved to ``self.device``.
        test_case: Test-case object. ``dim``, ``domain`` (indexed as
            ``[x_min, x_max, y_min, y_max]``) and ``interfaces`` are read from it.
        GINN_hparams: Dict providing ``envelope_extension_factor``,
            ``num_points_connectivity_loss``, ``num_points_interface_loss``,
            ``num_points_normals_loss``, ``num_points_envelope_loss``, ``clip_max_value``,
            ``clip_min_value``, ``max_curv`` and ``curv_start_epoch``.
        PH: Helper providing the connectivity / hole losses (presumably persistent homology).
        boundary_sampler: Object whose ``get_surface_pts()`` returns ``(points, weights)``
            used by the curvature loss.
        enforce_density: Stored as an attribute; not used by the methods of this class.
    """

    def __init__(self,
                 GINN_model,
                 test_case,
                 GINN_hparams,
                 PH,
                 boundary_sampler,
                 enforce_density):

        super().__init__(test_case=test_case)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
        self.GINN_model = GINN_model.to(self.device)
        self.GINN_hparams = GINN_hparams
        self.boundary_sampler = boundary_sampler
        self.test_case = test_case
        self.enforce_density = enforce_density
        self.dim = test_case.dim  # 2D or 3D
        self.envelope_extension_factor = self.GINN_hparams['envelope_extension_factor']
        self.num_points_connectivity_loss = self.GINN_hparams['num_points_connectivity_loss']
        self.num_points_interface_loss = self.GINN_hparams['num_points_interface_loss']
        self.num_points_normals_loss = self.GINN_hparams['num_points_normals_loss']
        self.num_points_envelope_loss = self.GINN_hparams['num_points_envelope_loss']
        self.clip_max_value = self.GINN_hparams['clip_max_value']
        self.clip_min_value = self.GINN_hparams['clip_min_value']
        self.max_curv = self.GINN_hparams['max_curv']
        self.curv_start_epoch = self.GINN_hparams['curv_start_epoch']
        self.PH = PH

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _zero_loss(device) -> torch.Tensor:
        """Scalar float32 zero with ``requires_grad=True``, returned when a constraint has no violations."""
        return torch.tensor(0.0, device=device, dtype=torch.float32, requires_grad=True)

    def _sample_outside_envelope(self) -> torch.Tensor:
        """Sample points in the enlarged domain and keep only those outside the design envelope.

        The domain is enlarged on every side by ``envelope_extension_factor`` times its width/height.

        Raises:
            ValueError: if no sampled point falls outside the design envelope.
        """
        design_envelope = self.test_case.domain
        x_min, x_max = design_envelope[0], design_envelope[1]
        y_min, y_max = design_envelope[2], design_envelope[3]

        extension_factor = self.GINN_hparams['envelope_extension_factor']
        extended_domain = np.array([x_min - extension_factor * (x_max - x_min),
                                    x_max + extension_factor * (x_max - x_min),
                                    y_min - extension_factor * (y_max - y_min),
                                    y_max + extension_factor * (y_max - y_min)])

        point_sampler = Point_Sampler(extended_domain,
                                      num_points_domain=self.num_points_envelope_loss)
        points = next(point_sampler).to(self.device)
        # NOTE(review): no gradient w.r.t. these points is taken here; kept in case GINN_model relies on it.
        points.requires_grad_(True)

        inside_x = torch.logical_and(points[:, 0] >= x_min, points[:, 0] <= x_max)
        inside_y = torch.logical_and(points[:, 1] >= y_min, points[:, 1] <= y_max)
        inside_envelope = torch.logical_and(inside_x, inside_y)
        points_outside_envelope = points[~inside_envelope]

        if points_outside_envelope.shape[0] == 0:
            raise ValueError("No points sampled outside the design envelope. Please increase the number of points or the extension factor.")
        return points_outside_envelope

    # ------------------------------------------------------------------ losses

    def surface_normal_loss(self):
        """MSE between normalized SDF gradients and the prescribed surface normals.

        Points and target normals come from
        ``test_case.interfaces.get_all_prescribed_surface_normals``. Only the Neumann
        and Dirichlet subsets are used, stacked Neumann first.

        Returns:
            Scalar loss tensor.
        """
        normals = self.test_case.interfaces.get_all_prescribed_surface_normals(num_points=self.num_points_normals_loss,
                                                                               type='torch_tensor',
                                                                               include_all=True)

        # 1. Gather the constrained surface points as a grad-enabled input.
        points = normals['points'].to(self.device)
        surface_pts = torch.vstack([points[normals['neumann_idx']],
                                    points[normals['dirichlet_idx']]]).to(self.device)
        surface_pts = surface_pts.requires_grad_(True)
        target_surface_normals = torch.vstack([normals['neumann_normals'],
                                               normals['dirichlet_normals']]).to(self.device)

        # 2. The predicted normal is the SDF gradient, normalized to unit length.
        sdf_values = self.GINN_model(surface_pts).view(-1)
        predicted_surface_normals = torch.autograd.grad(inputs=surface_pts,
                                                        outputs=sdf_values,
                                                        grad_outputs=torch.ones_like(sdf_values),
                                                        create_graph=True,
                                                        only_inputs=True)[0]
        predicted_surface_normals = F.normalize(predicted_surface_normals, p=2, dim=1)

        # 3. Mean squared error against the prescribed normals.
        return F.mse_loss(predicted_surface_normals, target_surface_normals)

    def prohibited_region_loss(self, coords):
        """Sum of squared SDF values at points inside the prohibited region that have SDF < 0.

        Args:
            coords: Points to test; the region mask comes from
                ``test_case.interfaces.is_inside_prohibited_region``.

        Returns:
            Scalar loss; a grad-enabled zero when nothing is violated.
        """
        _, inside_mask = self.test_case.interfaces.is_inside_prohibited_region(coords)

        # reshape(-1) (not squeeze) keeps a 1-D tensor even for a single point.
        sdf = self.GINN_model(coords).reshape(-1)
        sdf_inside = sdf[inside_mask].reshape(-1)

        # Material (SDF < 0) is not allowed inside the prohibited region.
        sdf_violations = sdf_inside[sdf_inside < 0]
        if sdf_violations.numel() > 0:
            return torch.square(sdf_violations).sum()
        return self._zero_loss(sdf.device)

    def prescribed_thickness_loss(self, coords):
        """Sum of squared SDF values at points inside the interface thickness that have SDF > 0.

        Args:
            coords: Points to test; the region mask comes from
                ``test_case.interfaces.is_inside_interface_thickness``.

        Returns:
            Scalar loss; a grad-enabled zero when nothing is violated.
        """
        _, inside_mask = self.test_case.interfaces.is_inside_interface_thickness(coords)

        sdf = self.GINN_model(coords).reshape(-1)
        sdf_inside = sdf[inside_mask].reshape(-1)

        # Points in the prescribed thickness must be material; SDF > 0 means "outside".
        sdf_violations = sdf_inside[sdf_inside > 0]
        if sdf_violations.numel() > 0:
            return torch.square(sdf_violations).sum()
        return self._zero_loss(sdf.device)

    def interface_loss(self):
        """MSE of the SDF (target 0) at points freshly sampled on all interfaces."""
        coords = self.test_case.interfaces.sample_points_from_all_interfaces(self.num_points_interface_loss,
                                                                            random_seed=None,
                                                                            output_type='torch_tensor')

        sdf_input = coords.detach().to(self.device)
        sdf_values = self.GINN_model(sdf_input).view(-1, 1)
        return F.mse_loss(sdf_values, torch.zeros_like(sdf_values))

    def eikonal_loss(self, coords):
        """Mean of ``(|grad SDF| - 1)^2`` over ``coords``.

        Side effect: sets ``requires_grad`` on the caller's ``coords`` tensor.
        """
        coords.requires_grad_(True)
        sdf = self.GINN_model(coords).squeeze()
        sdf_grad = torch.autograd.grad(outputs=sdf, inputs=coords, grad_outputs=torch.ones_like(sdf), create_graph=True)[0]
        sdf_grad_norm = torch.norm(sdf_grad, dim=1)
        return torch.mean((sdf_grad_norm - 1) ** 2)

    def design_envelope_loss(self):
        """Sum of squared SDF values at points outside the design envelope that have SDF <= 0.

        Raises:
            ValueError: if no sampled point lies outside the design envelope.
        """
        points_outside_envelope = self._sample_outside_envelope()
        sdf_values = self.GINN_model(points_outside_envelope).view(-1)

        # Outside the envelope the shape must be absent, i.e. SDF > 0.
        sdf_violations = sdf_values[sdf_values <= 0]
        if sdf_violations.numel() > 0:
            return torch.square(sdf_violations).sum()
        return self._zero_loss(sdf_values.device)

    def design_envelope_loss_from_density(self,
                                          density_model: torch.nn.Module,
                                          iso_level: float = 0.5) -> torch.Tensor:
        """Density-based analogue of ``design_envelope_loss``.

        Args:
            density_model: Must return a single tensor (it is ``.view(-1)``-ed), e.g.
                ``density_GINN``; ``Geometry_Net`` returns a tuple and does not fit.
            iso_level: Densities above this value outside the envelope are penalized.

        Returns:
            Sum of squared densities (not of ``rho - iso_level``) at violating points,
            or a grad-enabled zero.

        Raises:
            ValueError: if no sampled point lies outside the design envelope.
        """
        points_outside_envelope = self._sample_outside_envelope()
        rho_values = density_model(points_outside_envelope).view(-1)

        rho_violations = rho_values[rho_values > iso_level]
        if rho_violations.numel() > 0:
            return torch.square(rho_violations).sum()
        return self._zero_loss(rho_values.device)

    def smoothness_loss(self, epoch) -> torch.Tensor:
        """Curvature regularizer on surface points, hinged at ``max_curv``.

        Per surface point, ``E = (2*H)^2 - 2*K`` is computed from the SDF gradient and
        Hessian, where H is mean curvature and K is Gaussian curvature. E is clipped to
        ``[clip_min_value, clip_max_value]``, weighted and summed. The loss is
        ``max(0, sum - max_curv)``.

        Returns a grad-enabled zero before ``curv_start_epoch`` or when the boundary
        sampler yields no points.

        Raises:
            ValueError: if the points are not 2-D ``[B, D]`` or the weights do not match.
        """
        if epoch < self.curv_start_epoch:
            return self._zero_loss(self.device)

        surface_points, weights = self.boundary_sampler.get_surface_pts()
        if surface_points is None or surface_points.numel() == 0:
            print('Returning Zero Loss - No Surface Points Found')
            return self._zero_loss(self.device)

        pts = surface_points.to(self.device)
        if pts.dim() != 2:
            raise ValueError(f"surface_points must be [B, D], got {pts.shape}")

        w = weights.to(self.device)
        if w.dim() == 2 and w.size(-1) == 1:
            w = w.squeeze(-1)
        if w.dim() != 1 or w.shape[0] != pts.shape[0]:
            raise ValueError(f"weights must be [B] or [B,1] matching points. Got {weights.shape}")

        B, D = pts.shape

        # First-order derivatives of the SDF w.r.t. a fresh leaf copy of the points.
        pts = pts.clone().detach().requires_grad_(True)
        sdf = self.GINN_model(pts).view(-1)
        df_dx = torch.autograd.grad(
            outputs=sdf,
            inputs=pts,
            grad_outputs=torch.ones_like(sdf),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]                                                                 # [B, D]

        # Hessian w.r.t. x, assembled row by row.
        H_rows = []
        for d in range(D):
            g_comp = df_dx[:, d]                                             # [B]
            Hg = torch.autograd.grad(
                outputs=g_comp,
                inputs=pts,
                grad_outputs=torch.ones_like(g_comp),
                create_graph=True,
                retain_graph=True,
                only_inputs=True,
            )[0]                                                             # [B, D]
            H_rows.append(Hg.unsqueeze(1))                                   # [B, 1, D]
        H = torch.cat(H_rows, dim=1)                                         # [B, D, D]

        # Gaussian curvature: K = -det([[H, g], [g^T, 0]]) / |g|^4
        grad_sq = df_dx.square().sum(dim=1)                                  # [B]
        F4 = torch.clamp(grad_sq.square(), min=1.0e-15)
        top = torch.cat([H, df_dx.unsqueeze(2)], dim=2)                      # [B, D, D+1]
        bottom = torch.cat([df_dx.unsqueeze(1),
                            torch.zeros(B, 1, 1, device=self.device, dtype=H.dtype)], dim=2)
        aug = torch.cat([top, bottom], dim=1)                                # [B, D+1, D+1]
        gauss_curvatures = (-1.0) / F4 * torch.det(aug)                      # [B]

        # Mean curvature: -(g^T H g - |g|^2 tr H) / (2 |g|^3)
        FHFT = torch.einsum('bi,bij,bj->b', df_dx, H, df_dx)                 # [B]
        trH = torch.einsum('bii->b', H)                                      # [B]
        N = torch.clamp(grad_sq.sqrt(), min=1.0e-5)
        mean_curvatures = -(FHFT - (N.pow(2) * trH)) / (2.0 * N.pow(3))      # [B]

        # Per-point curvature energy, clipped, weighted and summed.
        E_strain = (2.0 * mean_curvatures).pow(2) - 2.0 * gauss_curvatures  # [B]
        E_strain = torch.clamp(E_strain, min=self.clip_min_value, max=self.clip_max_value)
        total = (E_strain * w).sum()

        # Hinge: only curvature energy above max_curv is penalized.
        zero = torch.tensor(0.0, device=self.device, dtype=torch.float32)
        return torch.maximum(zero, total - float(self.max_curv))

    def connectivity_loss(self) -> torch.Tensor:
        """Delegate to ``PH.connectedness_loss()``."""
        return self.PH.connectedness_loss()

    def connectivity_loss_from_density(self,
                                       density_model: torch.nn.Module,
                                       iso_level: float = 0.5) -> torch.Tensor:
        """Delegate to ``PH.connectedness_loss_from_density(density_model, iso_level)``."""
        return self.PH.connectedness_loss_from_density(density_model, iso_level)

    def holes_loss(self):
        """Delegate to ``PH.holes_loss()``."""
        return self.PH.holes_loss()

    def holes_loss_from_density(self,
                                density_model: torch.nn.Module,
                                iso_level: float = 0.5) -> torch.Tensor:
        """Delegate to ``PH.holes_loss_from_density(density_model, iso_level)``."""
        return self.PH.holes_loss_from_density(density_model, iso_level)

    def learn_original_geometry(self, coords):
        """Learn human-generated geometry for validation experiments.

        Returns the MSE between the model SDF and the reference SDF computed by
        ``self.test_case.interfaces.calculate_SDF``. The reference is cast to the
        prediction's dtype and device.
        """
        true_sdf = self.test_case.interfaces.calculate_SDF(coords)
        predicted_sdf = self.GINN_model(coords).squeeze()
        true_sdf = torch.as_tensor(true_sdf, dtype=predicted_sdf.dtype,
                                   device=predicted_sdf.device).squeeze()
        return F.mse_loss(predicted_sdf, true_sdf)


In [None]:
class PINN_losses(Properties):
    """Ritz (total potential energy) loss for the displacement networks.

    energy = internal strain energy - external work on the Neumann boundary.
    The strain energy uses SIMP-scaled Lame parameters
    ``clamp(rho, 0, 1) ** density_exponent * (lame_lambda, lame_mu)``.

    ``self.device``, ``self.interfaces``, ``self.force_vector``, ``self.lame_lambda``
    and ``self.lame_mu`` are not assigned here; presumably ``Properties`` provides them.

    Args:
        u_model, v_model: Networks for the x / y displacement components.
        GINN_model: Density network evaluated at the collocation points.
        n_opt_samples: Forwarded to ``enforce_density.apply``.
        training_hparams: Dict with ``'dirichlet_pts'``, ``'num_neumann_points'`` and
            ``'density_exponent'``.
        enforce_density: Object whose ``apply`` post-processes the densities.
    """

    def __init__(self, u_model, v_model, GINN_model, n_opt_samples, training_hparams, enforce_density):
        super().__init__(test_case=BRIDGE)
        self.u_model = u_model.to(self.device)
        self.v_model = v_model.to(self.device)
        self.GINN_model = GINN_model.to(self.device)
        self.n_opt_samples = n_opt_samples
        self.enforce_density = enforce_density
        hp = training_hparams
        self.dirichlet_pts = hp['dirichlet_pts']
        self.num_neumann_pts = hp['num_neumann_points']
        self.density_exponent = hp['density_exponent']

    def ritz_loss(self, x):
        """Return ``internal_energy - external_energy`` for collocation points ``x``.

        The external work assumes the Neumann points sample an arc of length
        ``pi * obstacle_radius`` uniformly, carrying a constant vertical traction
        ``force_vector[1] / arc_length``.
        """
        # ---- External work on the loaded boundary ----
        boundary_pts = self.interfaces.sample_points_on_neumann_boundary(
            self.num_neumann_pts, 'vertical', 'torch_tensor').to(self.device)
        boundary_pts.requires_grad_(True)
        radius = BRIDGE.interfaces.obstacle_radius
        arc_length = np.pi * radius
        ds = arc_length / boundary_pts.shape[0]

        traction = torch.zeros_like(boundary_pts)
        traction[:, 1] = self.force_vector[1] / arc_length

        boundary_disp = torch.stack([self.u_model(boundary_pts).squeeze(-1),
                                     self.v_model(boundary_pts).squeeze(-1)], dim=1)
        external_energy = torch.sum((traction * boundary_disp).sum(dim=1) * ds)

        # ---- Internal strain energy ----
        rho = self.GINN_model(x)
        rho, _ = self.enforce_density.apply(x, rho, None, self.n_opt_samples, BRIDGE.domain)

        coords = x.detach().clone().requires_grad_(True).to(self.device)
        u = self.u_model(coords)
        v = self.v_model(coords)
        du = torch.autograd.grad(u, coords, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)[0]
        dv = torch.autograd.grad(v, coords, grad_outputs=torch.ones_like(v), create_graph=True, retain_graph=True)[0]

        eps_xx = du[:, 0]
        eps_yy = dv[:, 1]
        eps_xy = 0.5 * (du[:, 1] + dv[:, 0])
        eps_trace = eps_xx + eps_yy

        # SIMP interpolation of the Lame parameters.
        simp = rho.squeeze().clamp(0.0, 1.0).pow(self.density_exponent)
        lam = simp * (self.lame_lambda * torch.ones_like(eps_xx))
        mu = simp * (self.lame_mu * torch.ones_like(eps_xx))

        sig_xx = 2 * (mu * eps_xx) + (lam * eps_trace)
        sig_yy = 2 * (mu * eps_yy) + (lam * eps_trace)
        sig_xy = 2 * (mu * eps_xy)

        energy_density = 0.5 * (sig_xx * eps_xx + 2 * sig_xy * eps_xy + sig_yy * eps_yy)
        internal_energy = BRIDGE.domain_volume * energy_density.mean()

        return internal_energy - external_energy

In [None]:
class topology_optimization(Properties):
    """Optimality-criteria (OC) topology-optimization helpers, following ntopo.

    Provides:
      * ``compute_sensitivities``: ``-d(strain energy)/d(density)`` with the displacement
        fields held fixed;
      * ``apply_sensitivity_filter_2d``: density-weighted sensitivity filter on a
        regular 2-D sample grid;
      * ``compute_target_densities``: OC density update with bisection on the
        volume Lagrange multiplier.

    ``self.device``, ``self.lame_lambda`` and ``self.lame_mu`` are not assigned here;
    presumably ``Properties`` provides them.

    Reference: https://github.com/JonasZehn/ntopo/tree/main/ntopo
    """

    def __init__(self,
                 u_model,
                 v_model,
                 GINN_model,
                 n_opt_samples,
                 training_hparams,
                 enforce_density):
        super().__init__(test_case=BRIDGE)
        self.u_model = u_model.to(self.device)
        self.v_model = v_model.to(self.device)
        self.GINN_model = GINN_model.to(self.device)
        self.n_opt_samples = n_opt_samples
        self.enforce_density = enforce_density
        self.density_exponent = training_hparams['density_exponent']

    def compute_sensitivities(self, coords):
        """Return ``(densities, sensitivities)`` at ``coords``, both detached.

        ``sensitivities = -d(internal strain energy)/d(densities)``. The sign matches
        the TF custom gradient in ntopo.

        Side effects:
          * switches ``u_model``, ``v_model`` and ``GINN_model`` to eval mode and does
            not restore the previous mode;
          * calls ``coords.requires_grad_(True)`` on the caller's tensor.
        """
        self.u_model.eval(); self.v_model.eval(); self.GINN_model.eval()
        with torch.enable_grad():
            densities = self.GINN_model(coords).detach()
            densities, _ = self.enforce_density.apply(coords, densities, None, self.n_opt_samples, BRIDGE.domain)
            densities.requires_grad_(True)

            coords.requires_grad_(True)
            u = self.u_model(coords)
            v = self.v_model(coords)

            # Only d/d(densities) is needed below; the strains enter as constants,
            # so no higher-order graph (create_graph) is required.
            grad_u = torch.autograd.grad(u, coords, grad_outputs=torch.ones_like(u))[0]
            grad_v = torch.autograd.grad(v, coords, grad_outputs=torch.ones_like(v))[0]

            epsilon_11 = grad_u[:, 0]
            epsilon_22 = grad_v[:, 1]
            epsilon_12 = 0.5 * (grad_u[:, 1] + grad_v[:, 0])
            trace_epsilon = epsilon_11 + epsilon_22

            # SIMP interpolation of the Lame parameters.
            densities = densities.squeeze()
            lame_lambda = self.lame_lambda * torch.ones_like(epsilon_11)
            lame_lambda = densities.clamp(0.0, 1.0).pow(self.density_exponent) * lame_lambda

            lame_mu = self.lame_mu * torch.ones_like(epsilon_11)
            lame_mu = densities.clamp(0.0, 1.0).pow(self.density_exponent) * lame_mu

            sigma_11 = 2 * (lame_mu * epsilon_11) + (lame_lambda * trace_epsilon)
            sigma_22 = 2 * (lame_mu * epsilon_22) + (lame_lambda * trace_epsilon)
            sigma_12 = 2 * (lame_mu * epsilon_12)

            strain_energy_density = 0.5 * (sigma_11 * epsilon_11 + 2 * sigma_12 * epsilon_12 + sigma_22 * epsilon_22)
            internal_energy = BRIDGE.domain_volume * strain_energy_density.mean()

            (d_rho,) = torch.autograd.grad(internal_energy, densities, retain_graph=False, create_graph=False, allow_unused=False)
            sensitivities = -d_rho  # match TF custom-gradient sign

        return densities.detach(), sensitivities.detach()

    @staticmethod
    def pad_replicate_2d(t: torch.Tensor, rep: int) -> torch.Tensor:
        """Replicate-pad the last two dims of ``t`` by ``rep`` on each side (sensitivity-filter helper, after ntopo)."""
        return F.pad(t, (rep, rep, rep, rep), mode="replicate")

    @staticmethod
    def pad_constant_2d(t: torch.Tensor, rep: int, value: float) -> torch.Tensor:
        """Constant-pad the last two dims of ``t`` by ``rep`` with ``value`` (sensitivity-filter helper, after ntopo)."""
        return F.pad(t, (rep, rep, rep, rep), mode="constant", value=value)

    def apply_sensitivity_filter_2d(self,
                                    coords: torch.Tensor,
                                    old_densities: torch.Tensor,
                                    sensitivities: torch.Tensor,
                                    n_samples: Tuple[int, int],
                                    domain: np.ndarray,
                                    radius: float) -> torch.Tensor:
        """PyTorch analogue of ntopo's ``apply_sensitivity_filter_3d`` for 2-D grids.

        Computes ``sum_i H_ei * x_i * dc_i / (max(x_e, gamma) * sum_i H_ei)``, where
        ``H_ei = max(0, r - |p_e - p_i|)`` over a ``(2*round(radius)+1)^2`` neighborhood.

        Args:
            coords: ``[nx*ny, 2]`` sample positions. The ``view(ny, nx, ...)``
                reshapes assume row-major order with the x index fastest.
            old_densities, sensitivities: One value per sample, in the same order.
            n_samples: ``(nx, ny)``.
            domain: ``[x_min, x_max, ...]``; the cell width is ``(x_max - x_min) / nx``.
            radius: Filter radius in cells.

        Returns:
            ``[nx*ny, 1]`` filtered sensitivities.
        """
        gamma = 1e-3
        nx, ny = n_samples
        N = nx * ny
        assert coords.shape[0] == N, f"expected {N} samples, got {coords.shape[0]}"

        cell_width = (domain[1] - domain[0]) / nx
        radius_space = radius * cell_width
        fsize = 2 * round(radius) + 1
        rep = fsize // 2

        dens = old_densities.view(ny, nx, 1).permute(2, 0, 1).unsqueeze(0)  # [1,1,ny,nx]
        sens = sensitivities.view(ny, nx, 1).permute(2, 0, 1).unsqueeze(0)  # [1,1,ny,nx]
        dens_p = self.pad_replicate_2d(dens, rep)
        sens_p = self.pad_replicate_2d(sens, rep)

        k = fsize
        dens_patch = F.unfold(dens_p, kernel_size=k, stride=1)  # [1,k*k,ny*nx]
        sens_patch = F.unfold(sens_p, kernel_size=k, stride=1)  # [1,k*k,ny*nx]

        # Padded positions are pushed far away so their filter weight becomes 0
        # (assumes domain coordinates are much closer than 1000 units to each other).
        pos = coords.view(ny, nx, 2).permute(2, 0, 1).unsqueeze(0)  # [1,2,ny,nx]
        pos_p = self.pad_constant_2d(pos, rep, value=-1000.0)
        pos_patch = F.unfold(pos_p, kernel_size=k, stride=1)  # [1,2*k*k,ny*nx]
        pos_patch = pos_patch.view(1, 2, k * k, ny * nx)
        center = pos.view(1, 2, 1, ny * nx)

        diff = pos_patch - center
        dists = torch.sqrt((diff ** 2).sum(dim=1) + 1e-35)  # [1,k*k,ny*nx]
        Hei = torch.clamp(radius_space - dists, min=0.0)

        Heixic = Hei * dens_patch * sens_patch
        sum_Heixic = Heixic.sum(dim=1)     # [1,ny*nx]
        sum_Hei = Hei.sum(dim=1)           # [1,ny*nx]

        old_r = old_densities.view(1, -1)  # [1,ny*nx]
        div = torch.clamp(old_r, min=gamma) * sum_Hei
        grads = (sum_Heixic / (div + 1e-35)).t()  # [ny*nx,1]
        return grads

    def apply_sensitivity_filter(self, coords, old_densities, sensitivities, n_samples, domain, radius):
        """Dimension dispatch; only the 2-D filter is implemented."""
        return self.apply_sensitivity_filter_2d(coords, old_densities, sensitivities, n_samples, domain, radius)

    @torch.no_grad()
    def compute_target_densities(self,
                                 old_densities_list: List[torch.Tensor],
                                 sensitivities_list: List[torch.Tensor],
                                 sample_volume: float,
                                 target_volume: float,
                                 max_move: float,
                                 damping_parameter: float):
        """Compute target densities using the OC update rule (derived from ntopo).

        For a multiplier ``lmbd``:
        ``target = clip(x * B**damping_parameter, x - max_move, x + max_move) within [0, 1]``,
        with ``B = -s / (dv * lmbd)``. ``lmbd`` is found by bisection so that
        ``sample_volume * mean(target) ≈ target_volume``.

        Raises:
            ValueError: if the density lists contain no elements.
        """
        total = sum(odi.numel() for odi in old_densities_list)
        if total == 0:
            raise ValueError("old_densities_list must contain at least one density value")
        dv = sample_volume / float(total)

        lb_list = [torch.clamp(odi - max_move, 0.0, 1.0) for odi in old_densities_list]
        ub_list = [torch.clamp(odi + max_move, 0.0, 1.0) for odi in old_densities_list]

        def targets_for_lambda(lmbd: float):
            # OC update for a fixed Lagrange multiplier; returns (volume, per-batch targets).
            targets = []
            flat_all = []
            for odi, s, lb, ub in zip(old_densities_list, sensitivities_list, lb_list, ub_list):
                Bi = s / (-(dv * lmbd) + 1e-20)
                tgt = odi * torch.pow(torch.clamp(Bi, min=1e-20), damping_parameter)
                tgt = torch.maximum(lb, torch.minimum(ub, tgt))
                tgt = torch.clamp(tgt, 0.0, 1.0)
                targets.append(tgt)
                flat_all.append(tgt.reshape(-1))
            vol = sample_volume * torch.mean(torch.cat(flat_all, dim=0))
            return vol, targets

        max_bisection_steps = 60
        rel_tol = 1e-3
        lam_lo, lam_hi = 0.0, 1e9
        for _ in range(max_bisection_steps):
            # Check convergence before evaluating, so a full target pass is not wasted on the final step.
            if (lam_hi - lam_lo) / (lam_hi + lam_lo + 1e-12) < rel_tol:
                break
            lam_mid = 0.5 * (lam_lo + lam_hi)
            vol_mid, _ = targets_for_lambda(lam_mid)
            # Too much material -> the multiplier must grow.
            if vol_mid > target_volume:
                lam_lo = lam_mid
            else:
                lam_hi = lam_mid

        _, targets = targets_for_lambda(0.5 * (lam_lo + lam_hi))
        return targets

In [None]:
class model_training:
    def __init__(self,
                 u_model,
                 v_model,
                 density_GINN_model,
                 SDF_GINN_model,
                 training_hparams,
                 topo_hparams,
                 constraint_hparams,
                 test_case,
                 loss_weight_hparams,
                 GINN_hparams):
        """Set up the two-stage training driver: SDF GINN shape generation, then density/PINN optimization.

        Moves all four networks to the selected device and reads the
        hyper-parameter dicts. Then it builds the density constraints, one ALM
        weighting scheme per optimization stage, gradient clipping, the point
        and boundary samplers, and the persistent-homology (PH) helpers for
        the SDF and density fields. Finally it sets up the history buffers,
        timers and CSV log paths.

        Args:
            u_model, v_model: displacement networks; their input gradients are
                used as the x/y displacement derivatives elsewhere in the class.
            density_GINN_model: density network optimized in the second stage.
                It gets ``volume_ratio = 0.5`` if it has none.
            SDF_GINN_model: signed-distance network trained in the first stage.
            training_hparams: batch size, sample counts, clipping settings, ...
            topo_hparams: optimization schedule, learning rates, stress/volume
                settings, evaluation-grid settings and ``save_path``. The
                density plot threshold is read from ``'rho_threshold'``; the
                legacy misspelled key ``'rho_treshold'`` is still accepted.
            constraint_hparams: forwarded to ``Density_Constraints``.
            test_case: problem definition (domain, dim, interfaces). Every
                sampler and PH helper here is built from it.
            loss_weight_hparams: ALM settings (``alpha``, ``gamma``,
                ``epsilon``) and optional ``objective_function`` /
                ``objective_losses`` selection.
            GINN_hparams: loss weights used as initial ALM multipliers.
        """
        super().__init__()

        # ---------------- Device ----------------
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )

        # ---------------- Models ----------------
        self.u_model = u_model.to(self.device)
        self.v_model = v_model.to(self.device)
        self.density_GINN_model = density_GINN_model.to(self.device)
        self.SDF_GINN_model = SDF_GINN_model.to(self.device)

        # ---------------- Hparams ----------------
        self.topo_hparams = topo_hparams
        self.training_hparams = training_hparams
        self.test_case = test_case
        self.loss_weight_hparams = loss_weight_hparams
        self.GINN_hparams = GINN_hparams

        self.save_path = topo_hparams['save_path']
        self.volume_ratio = topo_hparams.get('volume_ratio', None)
        self.lr_PINN = topo_hparams['lr_PINN']
        self.lr_GINN = topo_hparams['lr_GINN']
        self.save_interval = topo_hparams['save_interval']
        self.filter_radius = topo_hparams['filter_radius']
        self.n_opt_iterations = topo_hparams['n_opt_iterations']
        self.n_sim_iterations = topo_hparams['n_sim_iterations']
        self.n_pre_training_iterations_PINN = topo_hparams['n_pre_training_iterations_PINN']
        self.n_pre_training_iterations_GINN = topo_hparams['n_pre_training_iterations_GINN']
        self.n_opt_batches = topo_hparams['n_opt_batches']
        self.seed = topo_hparams['seed']
        self.refine_study = training_hparams["refine_study"]
        self.batch_size = training_hparams['batch_size']
        self.num_points = training_hparams['total_sample_points']
        self.sigma_allow = float(topo_hparams.get('sigma_allow', 1.0))
        self.stress_percentile = float(topo_hparams.get('stress_percentile', 0.995))
        self.rho_mask_threshold = float(topo_hparams.get('rho_mask_threshold', 0.6))
        self.stress_tol = float(topo_hparams.get('stress_tol', 0.02))
        self.vol_frac_min = float(topo_hparams.get('vol_frac_min', 0.02))
        self.vol_frac_max = float(topo_hparams.get('vol_frac_max', 0.98))
        self.vol_step_min = float(topo_hparams.get('vol_step_min', 0.005))
        self.vol_step_max = float(topo_hparams.get('vol_step_max', 0.05))
        self.alpha_decrease = float(topo_hparams.get('alpha_decrease', 0.50))
        self.beta_increase = float(topo_hparams.get('beta_increase', 0.25))
        # Prefer the correctly spelled key; keep the legacy 'rho_treshold' working.
        self.plot_threshold = float(
            topo_hparams.get('rho_threshold', topo_hparams.get('rho_treshold', 0.3))
        )
        self.stress_metric = str(topo_hparams.get("stress_metric", "percentile"))
        self.ks_rho = float(topo_hparams.get("ks_rho", 50.0))
        self.ks_size_correction = bool(topo_hparams.get("ks_size_correction", True))

        if getattr(self.density_GINN_model, "volume_ratio", None) is None:
            self.density_GINN_model.volume_ratio = 0.5

        self.sigma_ema = None
        self.sigma_ema_alpha = 0.2

        self.enforce_density = Density_Constraints(
            constraint_hparams,
            self.test_case,
            self.test_case.interfaces.is_inside_interface_thickness,
            self.test_case.interfaces.is_inside_prohibited_region
        )

        # ---------------- ALM setup for first optimization stage ----------------
        self.use_objective_function = bool(loss_weight_hparams.get('objective_function', False))
        self.objective_losses = list(loss_weight_hparams.get('objective_losses', []))

        all_geom_loss_keys = [
            'Eikonal Loss',
            'Interface Loss',
            'Envelope Loss',
            'Connectivity Loss',
            'Prescribed Normals Loss',
            'Prescribed Thickness Loss',
            'Prohibited Region Loss',
            'Smoothness Loss',
            'Holes Loss',
        ]

        if self.use_objective_function:
            # Every geometric loss not folded into the objective becomes a constraint.
            constraint_keys = [k for k in all_geom_loss_keys if k not in self.objective_losses]
        else:
            # Default objective is Eikonal + Interface (see _compute_losses).
            constraint_keys = [
                'Envelope Loss',
                'Connectivity Loss',
                'Prescribed Normals Loss',
                'Prescribed Thickness Loss',
                'Prohibited Region Loss',
                'Smoothness Loss',
                'Holes Loss',
            ]

        self.scalar_keys = ['Objective'] + constraint_keys

        # Initial ALM multiplier per constraint: (GINN_hparams key, default value).
        lambda_defaults = {
            'Eikonal Loss':              ('eikonal_loss_weight', 0.0),
            'Interface Loss':            ('interface_loss_weight', 1.0),
            'Envelope Loss':             ('envelope_loss_weight', 1.0),
            'Connectivity Loss':         ('connectivity_loss_weight', 100.0),
            'Prescribed Normals Loss':   ('prescribed_normals_loss_weight', 1.0),
            'Prescribed Thickness Loss': ('prescribed_thickness_loss_weight', 1.0),
            'Prohibited Region Loss':    ('prohibited_region_loss_weight', 1.0),
            'Smoothness Loss':           ('smoothness_loss_weight', 1e-4),
            'Holes Loss':                ('holes_loss_weight', 100.0),
        }
        lambda_init = {'Objective': 1.0}
        for key in constraint_keys:
            hparam_name, default = lambda_defaults[key]
            lambda_init[key] = GINN_hparams.get(hparam_name, default)

        self.alm = ALM(
            loss_keys=self.scalar_keys,
            objective_key='Objective',
            lambda_dict=lambda_init,
            alpha=loss_weight_hparams['alpha'],
            gamma=loss_weight_hparams['gamma'],
            epsilon=loss_weight_hparams['epsilon'],
            device=self.device
        )

        # ---------------- ALM setup for second optimization stage ----------------
        self.scalar_keys_rho = [
            'Topo Objective',
            'Density Envelope Loss',
            'Density Connectivity Loss'
        ]

        lambda_init_rho = {
            'Topo Objective':            1.0,
            'Density Envelope Loss':     GINN_hparams.get('envelope_loss_weight_density', 1.0),
            'Density Connectivity Loss': GINN_hparams.get('connectivity_loss_weight_density', 100.0),
        }

        self.alm_rho = ALM(
            loss_keys=self.scalar_keys_rho,
            objective_key='Topo Objective',
            lambda_dict=lambda_init_rho,
            alpha=loss_weight_hparams['alpha'],
            gamma=loss_weight_hparams['gamma'],
            epsilon=loss_weight_hparams['epsilon'],
            device=self.device
        )

        # ---------------- Gradient clipping ----------------
        self.grad_clipper = AutoClip(
            grad_clipping_on=training_hparams["grad_clipping_on"],
            grad_clip=training_hparams["grad_clip"],
            auto_clip_on=training_hparams["auto_clip_on"],
            auto_clip_percentile=training_hparams["auto_clip_percentile"],
            auto_clip_min_len=training_hparams["auto_clip_min_len"],
            auto_clip_hist_len=training_hparams["auto_clip_hist_len"]
        )

        # ---------------- Samplers ----------------
        self.point_sampler = Point_Sampler(
            test_case.domain,
            test_case.interfaces,
            num_points_domain=self.num_points,
            num_points_interface=0
        )

        inside_env = make_inside_envelope_fn_from_domain(
            domain=self.test_case.domain, nx=self.test_case.dim, device=self.device
        )

        # PH set-up (SDF field, zero level set)
        self.PH = PH(
            nx=test_case.dim,
            bounds=test_case.domain,
            model=self.SDF_GINN_model,
            n_grid_points=128,
            iso_level=0.0,
            target_betti=[1, 0, 0],
            maxdim=1,
            is_density=False,
            inside_envelope_fn=inside_env,
            group_size_fwd_no_grad=32768,
            add_frame=True,
            hole_level=0.06,
            test_case=self.test_case
        )

        # PH for the density field (thresholded at the plot threshold)
        self.PH_rho = PH(
            nx=test_case.dim,
            bounds=test_case.domain,
            model=self.density_GINN_model,
            n_grid_points=128,
            iso_level=self.plot_threshold,
            target_betti=[1, 0, 0],
            maxdim=1,
            is_density=True,
            inside_envelope_fn=inside_env,
            group_size_fwd_no_grad=32768,
            add_frame=True,
            hole_level=0.06,
            test_case=self.test_case
        )

        # Use the test case passed in, not the notebook-global BRIDGE, so the
        # trainer works for any test case.
        interface_points = self.test_case.interfaces.sample_points_from_all_interfaces(
            10001, output_type='torch_tensor'
        )
        self.boundary_sampler = Boundary_Sampler(
            dim=self.test_case.dim,
            bounds=self.test_case.domain,
            model=self.SDF_GINN_model,
            x_interface=interface_points,
            n_points_surface=10001,
            interface_cutoff=0.05
        )

        self.boundary_sampler_rho = Boundary_Sampler(
            dim=self.test_case.dim,
            bounds=self.test_case.domain,
            model=self.density_GINN_model,
            x_interface=interface_points,
            n_points_surface=10001,
            interface_cutoff=0.05,
            level_set = 0.5,
        )

        # History buffers
        self.hist_iters = []
        self.hist_compliance = []
        self.hist_volume = []
        self.hist_sigma_metric = []
        self.hist_sigma_metric_KS = []
        self.ginn_hist_iters = []
        self.ginn_hist_eik = []
        self.ginn_hist_env = []
        self.ginn_hist_connect = []
        self.ginn_hist_hole = []
        self.ginn_hist_int = []
        self.ginn_hist_norm = []
        self.ginn_hist_thick = []
        self.ginn_hist_prohib = []
        self.ginn_hist_smooth = []
        self.ginn_hist_total = []
        self.topo_hist_iters = []
        self.topo_hist_topo = []
        self.topo_hist_env_rho = []
        self.topo_hist_conn_rho = []
        self.topo_hist_total = []
        self.topo_step_counter = 0
        self.time_ginn_pretrain = 0.0
        self.time_topology_phase = 0.0
        self.time_density_updates = 0.0
        self.time_total = 0.0
        self._density_update_calls = 0
        self._density_update_time_buffer = []

        # ---------------- paths ----------------
        self.csv_time_path = os.path.join(self.save_path, "timing.csv")
        self.csv_ginn_loss_path = os.path.join(self.save_path, "ginn_losses.csv")
        self.csv_topo_loss_path = os.path.join(self.save_path, "topo_density_losses.csv")
        self.csv_opt_metrics_path = os.path.join(self.save_path, "topo_metrics.csv")
        self.csv_opt_loss_path = os.path.join(self.save_path, "optimization_losses.csv")
        self.csv_eval_metrics_path = os.path.join(self.save_path, "evaluation_metrics.csv")

        # ---------------- Evaluation-metric settings ----------------
        self.eval_metrics_enable = bool(self.topo_hparams.get("eval_metrics_enable", True))
        self.eval_grid_nx = int(self.topo_hparams.get("eval_metrics_grid_nx", 256))
        self.eval_grid_ny = int(self.topo_hparams.get("eval_metrics_grid_ny", 128))
        self.eval_batch = int(self.topo_hparams.get("eval_metrics_batch", 200_000))
        self.eval_boundary_n = int(self.topo_hparams.get("eval_metrics_boundary_n", 2_000))
        self.eval_interface_n = int(self.topo_hparams.get("eval_metrics_interface_n", 1_024))
        self.eval_max_boundary_pts = int(self.topo_hparams.get("eval_metrics_max_boundary_pts", 20_000))
        self.eval_cdist_chunk = int(self.topo_hparams.get("eval_metrics_cdist_chunk", 4_096))

        # ---------------- Model checkpoint interval ----------------
        self.model_save_interval = int(topo_hparams.get("model_save_interval", self.save_interval))


    # ------ helpers ---------------------

    def loss_ramp_up(self, epoch: int, start_epoch: int, ramp_epochs: int) -> float:
        """Gradually ramps up the strength of a loss to avoid numerical issues
        when introducing a new loss during training."""
        if epoch < start_epoch:
            return 0.0
        return float(min(1.0, (epoch - start_epoch + 1) / max(1, ramp_epochs)))

    # Evaluation metrics. These are star-imported helpers bound as class
    # attributes, presumably functions taking the trainer as first argument,
    # so they are called as ``self._name(...)``.
    _compute_eval_metrics_2d = compute_eval_metrics_2d
    _maybe_log_eval_metrics = maybe_log_eval_metrics

    # Checkpointing.
    _save_models = save_models

    # Timing bookkeeping.
    _log_timing = log_timing
    _record_density_update_time = record_density_update_time

    # CSV loss / metric logging.
    _log_ginn_losses_csv = log_ginn_losses_csv
    _log_topo_density_losses_csv = log_topo_density_losses_csv
    _log_optimization_metrics_csv = log_optimization_metrics_csv
    _log_optimization_losses_csv = log_optimization_losses_csv

    def _compute_losses(self, coords, epoch):
        """Evaluate the geometric SDF-GINN losses for one batch and pack them for the ALM.

        Each raw loss from ``GINN_losses`` is scaled by its weight in
        ``self.GINN_hparams`` and passed through ``relu``. The PH-based losses
        (connectivity, holes) are ramped in linearly over
        ``GINN_hparams['ph_ramp_epochs']`` epochs starting at
        ``GINN_hparams['ph_ramp_start_epoch']``. The defaults 550 and 0
        reproduce the previously hard-coded schedule. The smoothness loss is
        ramped in over ``curv_ramp_epochs`` starting at ``curv_start_epoch``.

        Args:
            coords: collocation points for the point-wise losses (eikonal,
                prescribed thickness, prohibited region). The caller passes a
                detached copy with ``requires_grad=True``.
            epoch: current training iteration, used for the ramp schedules.

        Returns:
            Dict with ``'Objective'`` plus every constraint key in
            ``self.scalar_keys``, each mapped to a non-negative scalar tensor.
            The objective is the sum of ``self.objective_losses`` when
            ``use_objective_function`` is set, otherwise Eikonal + Interface.

        Raises:
            KeyError: if an objective or constraint key is not a known loss.
        """
        geometry = GINN_losses(
            self.SDF_GINN_model,
            self.test_case,
            self.GINN_hparams,
            self.PH,
            self.boundary_sampler,
            self.enforce_density
        )

        w = self.GINN_hparams
        ph_ramp = self.loss_ramp_up(
            epoch,
            start_epoch=int(w.get('ph_ramp_start_epoch', 0)),
            ramp_epochs=int(w.get('ph_ramp_epochs', 550)),
        )
        smooth_ramp = self.loss_ramp_up(epoch, w['curv_start_epoch'], w['curv_ramp_epochs'])

        # Raw losses, evaluated in a fixed order (PH-based ones may share caches).
        eik_raw     = geometry.eikonal_loss(coords)
        int_raw     = geometry.interface_loss()
        env_raw     = geometry.design_envelope_loss()
        conn_raw    = geometry.connectivity_loss()
        norm_raw    = geometry.surface_normal_loss()
        thick_raw   = geometry.prescribed_thickness_loss(coords)
        prohib_raw  = geometry.prohibited_region_loss(coords)
        smooth0_raw = geometry.smoothness_loss(epoch)
        holes_raw   = geometry.holes_loss()

        losses_all = {
            'Eikonal Loss':              torch.relu(eik_raw * w['eikonal_loss_weight']),
            'Interface Loss':            torch.relu(int_raw * w['interface_loss_weight']),
            'Envelope Loss':             torch.relu(env_raw * w['envelope_loss_weight']),
            'Connectivity Loss':         torch.relu(conn_raw * (w['connectivity_loss_weight'] * ph_ramp)),
            'Prescribed Normals Loss':   torch.relu(norm_raw * w['prescribed_normals_loss_weight']),
            'Prescribed Thickness Loss': torch.relu(thick_raw * w['prescribed_thickness_loss_weight']),
            'Prohibited Region Loss':    torch.relu(prohib_raw * w['prohibited_region_loss_weight']),
            'Smoothness Loss':           torch.relu(smooth_ramp * smooth0_raw * w['smoothness_loss_weight']),
            'Holes Loss':                torch.relu(holes_raw * (w['holes_loss_weight'] * ph_ramp)),
        }

        # --------- Build single objective scalar ----------
        if self.use_objective_function and len(self.objective_losses) > 0:
            for name in self.objective_losses:
                if name not in losses_all:
                    raise KeyError(
                        f"Objective loss '{name}' not found. "
                        f"Available: {list(losses_all.keys())}"
                    )
            objective = sum(losses_all[name] for name in self.objective_losses)
        else:
            objective = losses_all['Eikonal Loss'] + losses_all['Interface Loss']

        losses = {'Objective': torch.relu(objective)}

        for key in self.scalar_keys:
            if key == 'Objective':
                continue
            if key not in losses_all:
                raise KeyError(
                    f"Constraint key '{key}' not found in losses_all. "
                    f"Available: {list(losses_all.keys())}"
                )
            losses[key] = losses_all[key]

        return losses

    # ----------------------------------------------------------------------
    @staticmethod
    def _safe_clip_ratio(v, lo, hi, eps=1e-3):
        lo = max(lo, eps)
        hi = min(hi, 1.0 - eps)
        return float(np.clip(v, lo, hi))

    def _estimate_compliance_internal_energy(self, n_opt_samples):
        """Estimate the elastic strain energy of the current design on a regular grid.

        Evaluates the density and the displacement networks at the grid-cell
        centres of ``self.test_case.domain`` (``n_opt_samples = (nx, ny)``). It
        applies the density constraints and clamps rho to [0, 1]. It forms the
        small-strain tensor from the input gradients of ``u`` and ``v``, and
        scales the Lame parameters by ``rho ** density_exponent``
        (SIMP-style). The strain-energy density ``0.5 * sigma : eps`` is
        integrated by the midpoint rule (mean times domain volume).

        All three networks are left in eval mode.

        Returns:
            The integrated strain energy as a Python float.
        """
        self.u_model.eval()
        self.v_model.eval()
        self.density_GINN_model.eval()

        # Use the trainer's own test case (previously the notebook-global BRIDGE).
        test_case = self.test_case

        nx, ny = n_opt_samples
        xs = get_grid_centers(test_case.domain, [nx, ny]).astype(np.float32)
        xt = torch.tensor(xs, dtype=torch.float32, device=self.device, requires_grad=True)

        rho = self.density_GINN_model(xt)
        rho, _ = self.enforce_density.apply(xt, rho, None, n_opt_samples, test_case.domain)
        r = rho.view(-1).clamp(0.0, 1.0)

        u = self.u_model(xt)
        v = self.v_model(xt)

        gu = torch.autograd.grad(u, xt, grad_outputs=torch.ones_like(u),
                                 create_graph=False, retain_graph=True)[0]
        gv = torch.autograd.grad(v, xt, grad_outputs=torch.ones_like(v),
                                 create_graph=False, retain_graph=False)[0]

        # Small-strain components from the displacement gradients.
        eps11 = gu[:, 0]
        eps22 = gv[:, 1]
        eps12 = 0.5 * (gu[:, 1] + gv[:, 0])
        tr = eps11 + eps22

        p = float(self.training_hparams['density_exponent'])
        props = Properties(test_case)
        penal = r.pow(p)
        lam = penal * props.lame_lambda
        mu = penal * props.lame_mu

        s11 = 2.0 * mu * eps11 + lam * tr
        s22 = 2.0 * mu * eps22 + lam * tr
        s12 = 2.0 * mu * eps12

        sed = 0.5 * (s11 * eps11 + 2.0 * s12 * eps12 + s22 * eps22)
        internal_energy = test_case.domain_volume * sed.mean()
        return float(internal_energy.detach().cpu().item())


    # Plotting helpers. These are star-imported and bound as class attributes,
    # presumably functions taking the trainer as first argument, so they are
    # called as ``self._name(...)``.
    _save_ginn_training_curves = save_ginn_training_curves
    _save_topology_density_curves = save_topology_density_curves
    _save_training_curves = save_training_curves
    _save_stress_histogram = save_stress_histogram



    # ---------------- Stress/volume measurement ----------------
    def _measure_stress_volume_percentile_2d(self):
        """Measure a high-quantile stress and the solid volume from the predicted images.

        Stress and density images come from ``predict_uv_sigma_image_2d_binary``
        on ``self.test_case.domain``. Pixels with ``rho >= self.plot_threshold``
        count as solid. If there are none, the median density is used as a
        fallback threshold. If that still selects nothing (e.g. NaN
        densities), the stress quantile is taken over all pixels.

        Returns:
            ``(sigma_metric, sigma_max, current_volume)``:
            - ``sigma_metric``: the ``self.stress_percentile`` quantile
              (clipped to [0, 1]) of the selected stresses.
            - ``sigma_max``: the maximum of the selected stresses.
            - ``current_volume``: the solid pixel fraction times
              ``self.test_case.domain_volume``.

        All three networks are left in eval mode.
        """
        self.u_model.eval()
        self.v_model.eval()
        self.density_GINN_model.eval()

        # Domain and volume now come from the same test case passed to the
        # predictor (previously mixed with the notebook-global BRIDGE).
        _, _, sigma_img, rho_img = predict_uv_sigma_image_2d_binary(
            self.u_model,
            self.v_model,
            self.density_GINN_model,
            self.test_case.domain,
            test_case=self.test_case,
            rho_threshold_plot=self.plot_threshold
        )

        solid = (rho_img >= self.plot_threshold)
        if not np.any(solid):
            # Nothing above the threshold: fall back to the median density.
            t = float(np.quantile(rho_img.reshape(-1), 0.5))
            solid = (rho_img >= t)

        sigma_vals = sigma_img[solid] if np.any(solid) else sigma_img.reshape(-1)

        q = float(np.clip(self.stress_percentile, 0.0, 1.0))
        sigma_metric = float(np.quantile(sigma_vals, q))
        sigma_max = float(np.max(sigma_vals))

        current_volume = self.test_case.domain_volume * (
            float(np.count_nonzero(solid)) / solid.size
        )
        return sigma_metric, sigma_max, current_volume

    def _measure_stress_volume_KS_2d(self):
        """Measure a Kreisselmeier-Steinhauser (KS) aggregated stress and the solid volume.

        Solid pixels are selected as in the percentile measurement: rho at or
        above ``self.plot_threshold``, falling back to the median density.
        With ``g = sigma / sigma_allow - 1`` over the solid pixels, the KS value
        is ``max(g) + log(sum(exp(ks_rho * (g - max(g))))) / ks_rho``. When
        ``ks_size_correction`` is set, ``log(n) / ks_rho`` is subtracted. That
        removes the offset KS adds for ``n`` equal values. The reported metric
        is ``sigma_allow * (1 + KS)``.

        Returns:
            ``(sigma_metric, sigma_max, volume_abs)``, or ``(0.0, 0.0, 0.0)``
            if no pixel is selected. ``volume_abs`` is the solid pixel
            fraction times ``self.test_case.domain_volume``.

        All three networks are left in eval mode.
        """
        self.u_model.eval()
        self.v_model.eval()
        self.density_GINN_model.eval()

        # Domain and volume now come from the same test case passed to the
        # predictor (previously mixed with the notebook-global BRIDGE).
        _, _, sigma_img, rho_img = predict_uv_sigma_image_2d_binary(
            self.u_model,
            self.v_model,
            self.density_GINN_model,
            self.test_case.domain,
            test_case=self.test_case,
            rho_threshold_plot=self.plot_threshold
        )

        solid = (rho_img >= self.plot_threshold)
        if not np.any(solid):
            t = float(np.quantile(rho_img.reshape(-1), 0.5))
            solid = (rho_img >= t)

        sigma_vals = sigma_img[solid]
        if sigma_vals.size == 0:
            return 0.0, 0.0, 0.0

        volume_abs = self.test_case.domain_volume * (
            float(np.count_nonzero(solid)) / solid.size
        )

        s = sigma_vals.astype(np.float64)
        g = s / float(self.sigma_allow) - 1.0
        gmax = float(np.max(g))

        # Log-sum-exp shifted by gmax so every exponent is <= 0 (no overflow).
        lse = gmax + (1.0 / self.ks_rho) * float(
            np.log(np.sum(np.exp(self.ks_rho * (g - gmax))) + 1e-300)
        )
        if self.ks_size_correction:
            lse -= (1.0 / self.ks_rho) * float(np.log(s.size))

        sigma_metric = float(self.sigma_allow) * (1.0 + lse)
        sigma_max = float(np.max(s))

        return sigma_metric, sigma_max, volume_abs

    def _measure_stress_volume_sigma_max_2d(self):
        """Use the peak solid-region stress itself as the stress metric.

        Delegates to the percentile measurement and discards its quantile.

        Returns:
            ``(sigma_max, sigma_max, volume)``.
        """
        _quantile, peak, volume = self._measure_stress_volume_percentile_2d()
        return peak, peak, volume

    def _measure_stress_volume_2d(self):
        """Dispatch to the stress/volume measurement selected by ``self.stress_metric``.

        Supported values (case-insensitive, surrounding whitespace ignored) are
        ``'ks'``, ``'percentile'`` (used when the attribute is missing) and
        ``'sigma_max'``.

        Returns:
            ``(sigma_metric, sigma_max, volume)`` from the selected measurement.

        Raises:
            ValueError: if the metric name is not supported. The message now
                lists all three options; it previously omitted 'sigma_max'.
        """
        metric = str(getattr(self, "stress_metric", "percentile")).strip().lower()

        measurements = {
            "ks": self._measure_stress_volume_KS_2d,
            "percentile": self._measure_stress_volume_percentile_2d,
            "sigma_max": self._measure_stress_volume_sigma_max_2d,
        }
        measure = measurements.get(metric)
        if measure is None:
            raise ValueError(
                f"Unknown stress metric: {metric}. "
                f"Choose between 'KS', 'percentile' and 'sigma_max'."
            )
        return measure()

    # ----------------------------------------------------------------------
    # training steps
    # ----------------------------------------------------------------------
    def shape_generation_step(self, optimizer, epoch):
        """Run one pass of ALM-weighted SDF-GINN training over a fresh point batch.

        Draws new domain points from ``self.point_sampler`` and splits them
        into shuffled mini-batches of ``self.batch_size``. For each mini-batch
        it:
        - invalidates the PH cache;
        - evaluates the geometric losses and backpropagates the ALM-combined
          loss;
        - if clipping is enabled, records the norm of the finite gradients and
          clips to the clipper's current value;
        - steps ``optimizer`` and updates the ALM multipliers.

        Args:
            optimizer: optimizer over the SDF GINN parameters.
            epoch: current training iteration, forwarded to the loss ramps.
        """
        loader = DataLoader(
            TensorDataset(next(self.point_sampler)),
            batch_size=self.batch_size,
            shuffle=True,
        )

        for (batch_cpu,) in loader:
            optimizer.zero_grad()
            batch = batch_cpu.to(self.device)

            # PH results depend on the current network weights.
            self.PH.invalidate_cache()

            batch_losses = self._compute_losses(batch.clone().detach().requires_grad_(True), epoch)

            combined = self.alm.build(batch_losses)
            combined.backward()

            clipper = getattr(self, "grad_clipper", None)
            if clipper is not None and clipper.grad_clip_enabled:
                # Non-finite gradients are left out of the norm history.
                sq_norm = sum(
                    param.grad.data.float().pow(2).sum().item()
                    for param in self.SDF_GINN_model.parameters()
                    if param.grad is not None and torch.isfinite(param.grad.data).all()
                )
                clipper.update_gradient_norm_history(math.sqrt(sq_norm))
                limit = clipper.get_clip_value()
                if np.isfinite(limit):
                    torch.nn.utils.clip_grad_norm_(self.SDF_GINN_model.parameters(), limit)

            optimizer.step()
            self.alm.update(batch_losses)


    def density_PINN_initialization_step(self, coords, density_GINN_optimizer):
        """Take one optimizer step that fits the density GINN to the SDF GINN's shape.

        The SDF GINN output at ``coords`` is mapped to a near-binary density
        ``sigmoid(-1000 * sdf)``, which is about 1 where ``sdf < 0``. It is
        detached and capped, and the density GINN is regressed onto it with
        an MSE loss.

        When ``self.refine_study`` is set:
        - the SDF GINN weights are first loaded from a checkpoint;
        - the target is capped at 0.7.
        Otherwise the current SDF GINN is used and the cap is 0.5.

        The checkpoint path is ``topo_hparams['refine_checkpoint_path']``,
        falling back to the previously hard-coded Colab path. The file is read
        from disk only once and cached. Its weights are still re-applied on
        every call, as before.

        Args:
            coords: points at which both networks are evaluated.
            density_GINN_optimizer: optimizer over the density GINN parameters.
        """
        default_checkpoint = "/content/gdrive/Othercomputers/My Mac/Code/hard_GINN_results/model-000250.pt"

        self.density_GINN_model.train()
        self.SDF_GINN_model.eval()

        if self.refine_study:  # initialize geometry from an existing geometry
            path = self.topo_hparams.get("refine_checkpoint_path", default_checkpoint)
            cached = getattr(self, "_refine_checkpoint_cache", None)
            if cached is None or cached[0] != path:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                cached = (path, torch.load(path, map_location=self.device))
                self._refine_checkpoint_cache = cached
            self.SDF_GINN_model.load_state_dict(cached[1]['model_state_dict'])
            density_cap = 0.7
        else:  # normal training setup
            density_cap = 0.5

        sdf = self.SDF_GINN_model(coords)
        target_density = torch.clamp(torch.sigmoid(-1000 * sdf).detach(), 0.0, density_cap)

        predicted_density = self.density_GINN_model(coords)
        loss = F.mse_loss(predicted_density, target_density)

        density_GINN_optimizer.zero_grad(set_to_none=True)
        loss.backward()
        density_GINN_optimizer.step()

    def PINN_update_step(self,
                         coords,
                         PINN_optimizer,
                         n_opt_samples):
        """Take one gradient step on the displacement networks with the density frozen.

        Puts ``u_model``/``v_model`` in train mode and the density GINN in eval
        mode. It minimizes ``PINN_losses(...).ritz_loss(coords)`` with
        ``PINN_optimizer``.

        Args:
            coords: collocation points passed to ``ritz_loss``.
            PINN_optimizer: optimizer over the u/v network parameters.
            n_opt_samples: sample-count spec forwarded to ``PINN_losses``.

        Returns:
            None.
        """
        for displacement_net in (self.u_model, self.v_model):
            displacement_net.train()
        self.density_GINN_model.eval()

        physics = PINN_losses(
            self.u_model,
            self.v_model,
            self.density_GINN_model,
            n_opt_samples,
            self.training_hparams,
            self.enforce_density
        )

        PINN_optimizer.zero_grad(set_to_none=True)
        ritz_energy = physics.ritz_loss(coords)
        ritz_energy.backward()
        PINN_optimizer.step()

    def topology_optimization_step(self,
                                   density_GINN_optimizer,
                                   coords,
                                   target_densities):
        """Take one ALM-weighted step fitting the density GINN to OC target densities.

        The objective is the MSE between the density GINN at ``coords`` and
        ``target_densities``. It is combined by ``self.alm_rho`` with the
        density envelope and density connectivity losses, both evaluated at
        ``self.plot_threshold``. Raw loss values are appended to the
        ``topo_hist_*`` buffers and written to the topo-density CSV. Gradients
        are optionally clipped, and the step time is recorded.

        The density losses are given ``self.boundary_sampler_rho``, the sampler
        built on the density GINN, instead of the SDF GINN's
        ``self.boundary_sampler``. If the density sampler is absent, the SDF
        sampler is used.

        Args:
            density_GINN_optimizer: optimizer over the density GINN parameters.
            coords: points at which the density GINN is fitted.
            target_densities: OC target densities at ``coords``.

        Returns:
            None.
        """
        t0 = time.perf_counter()

        self.density_GINN_model.train()
        density_GINN_optimizer.zero_grad(set_to_none=True)

        densities = self.density_GINN_model(coords)
        topology_optim_loss = F.mse_loss(densities, target_densities)

        # PH results depend on the current density weights.
        self.PH_rho.invalidate_cache()

        rho_sampler = getattr(self, "boundary_sampler_rho", None)
        if rho_sampler is None:
            rho_sampler = self.boundary_sampler

        geometry_rho = GINN_losses(
            self.density_GINN_model,
            self.test_case,
            self.GINN_hparams,
            self.PH_rho,
            rho_sampler,
            self.enforce_density
        )

        env_rho = geometry_rho.design_envelope_loss_from_density(
            density_model=self.density_GINN_model,
            iso_level=float(self.plot_threshold)
        )

        conn_rho = geometry_rho.connectivity_loss_from_density(
            density_model=self.density_GINN_model,
            iso_level=float(self.plot_threshold)
        )

        losses_rho = {
            'Topo Objective':             torch.relu(topology_optim_loss),
            'Density Envelope Loss':      torch.relu(env_rho),
            'Density Connectivity Loss':  torch.relu(conn_rho),
        }

        total_loss = self.alm_rho.build(losses_rho)

        # Read each scalar once (each .item() forces a device sync).
        topo_val = float(topology_optim_loss.item())
        env_val = float(env_rho.item())
        conn_val = float(conn_rho.item())
        total_raw = float(topo_val + env_val + conn_val)

        self.topo_step_counter += 1
        self.topo_hist_iters.append(self.topo_step_counter)
        self.topo_hist_topo.append(topo_val)
        self.topo_hist_env_rho.append(env_val)
        self.topo_hist_conn_rho.append(conn_val)
        self.topo_hist_total.append(total_raw)

        self._log_topo_density_losses_csv(
            step=self.topo_step_counter,
            topo_obj=topo_val,
            env_rho=env_val,
            conn_rho=conn_val,
            total_raw=total_raw
        )

        total_loss.backward()

        if getattr(self, "grad_clipper", None) is not None and self.grad_clipper.grad_clip_enabled:
            # Non-finite gradients are left out of the norm history.
            total_sq = 0.0
            for p in self.density_GINN_model.parameters():
                if p.grad is not None:
                    g = p.grad.data
                    if torch.isfinite(g).all():
                        total_sq += g.float().pow(2).sum().item()
            grad_norm = math.sqrt(total_sq)
            self.grad_clipper.update_gradient_norm_history(grad_norm)
            clip_val = self.grad_clipper.get_clip_value()
            if np.isfinite(clip_val):
                torch.nn.utils.clip_grad_norm_(self.density_GINN_model.parameters(), clip_val)

        density_GINN_optimizer.step()

        self.alm_rho.update(losses_rho)

        dt = time.perf_counter() - t0
        self._record_density_update_time(dt)

        return None

    # ----------------------------------------------------------------------
    # main driver
    # ----------------------------------------------------------------------
    def generate_geometry(self):
        import os
        import time
        import matplotlib.pyplot as plt

        set_random_seed(self.seed)
        os.makedirs(self.save_path, exist_ok=True)

        t_total0 = time.perf_counter()

        DEFAULT_N_SIM_SAMPLES_2D = 150 * 50
        DEFAULT_N_OPT_SAMPLES_2D = 150 * 50

        n_sim_samples = get_default_sample_counts(BRIDGE.domain, DEFAULT_N_SIM_SAMPLES_2D)
        n_opt_samples = get_default_sample_counts(BRIDGE.domain, DEFAULT_N_OPT_SAMPLES_2D)

        PINN_model_params_list = list(self.u_model.parameters()) + list(self.v_model.parameters())
        PINN_optimizer = torch.optim.Adam(PINN_model_params_list, lr=self.lr_PINN, betas=(0.9, 0.99))
        density_GINN_optimizer = torch.optim.Adam(self.density_GINN_model.parameters(), lr=self.lr_GINN, betas=(0.8, 0.9))
        SDF_GINN_optimizer = torch.optim.Adam(self.SDF_GINN_model.parameters(), lr=self.lr_GINN)

        PINN_point_sampler = gen_samples(BRIDGE.domain, n_sim_samples)
        GINN_point_sampler = gen_samples(BRIDGE.domain, n_opt_samples)

        # ===================== First optimization stage =====================
        print("=== Generating Geometry: ===")
        t_ginn0 = time.perf_counter()

        for it in tqdm(range(self.n_pre_training_iterations_GINN), desc="GINN Training"):
            self.shape_generation_step(SDF_GINN_optimizer, epoch=it)

            if (it % 40) == 0:
                fig = plot_GINN_geometry(self.test_case, 80000, self.SDF_GINN_model, 0.5)
                fig.savefig(os.path.join(self.save_path, f"ginn-geometry-{it:06d}.png"), dpi=150)
                plt.close(fig)

                geometry_losses = GINN_losses(
                    self.SDF_GINN_model,
                    self.test_case,
                    self.GINN_hparams,
                    self.PH,
                    self.boundary_sampler,
                    self.enforce_density
                )

                coords = next(self.point_sampler).to(self.device)
                eik_loss = geometry_losses.eikonal_loss(coords)
                env_loss = geometry_losses.design_envelope_loss()
                connect_loss = geometry_losses.connectivity_loss()
                hole_loss = geometry_losses.holes_loss()
                int_loss = geometry_losses.interface_loss()
                norm_loss = geometry_losses.surface_normal_loss()
                thick_loss = geometry_losses.prescribed_thickness_loss(coords)
                prohib_loss = geometry_losses.prohibited_region_loss(coords)
                smooth_loss = geometry_losses.smoothness_loss(it)

                smooth_loss_scaled = self.loss_ramp_up(
                    it,
                    self.GINN_hparams['curv_start_epoch'],
                    self.GINN_hparams['curv_ramp_epochs']
                ) * smooth_loss

                total_loss = (
                    eik_loss + env_loss + connect_loss + hole_loss + int_loss +
                    norm_loss + thick_loss + prohib_loss + smooth_loss_scaled
                )

                print(f"Eikonal Loss: {eik_loss.item():.6f}")
                print(f"Envelope Loss: {env_loss.item():.6f}")
                print(f"Connectivity Loss: {connect_loss.item():.6f}")
                print(f"Holes Loss: {hole_loss.item():.6f}")
                print(f"Interface Loss: {int_loss.item():.6f}")
                print(f"Surface Normal Loss: {norm_loss.item():.6f}")
                print(f"Prescribed Thickness Loss: {thick_loss.item():.6f}")
                print(f"Prohibited Region Loss: {prohib_loss.item():.6f}")
                print(f"Smoothness Loss: {smooth_loss_scaled.item():.6f}")
                print(f"Total Loss: {total_loss.item():.6f}")
                print(' ')
                print(' ')

                self.ginn_hist_iters.append(it)
                self.ginn_hist_eik.append(eik_loss.item())
                self.ginn_hist_env.append(env_loss.item())
                self.ginn_hist_connect.append(connect_loss.item())
                self.ginn_hist_hole.append(hole_loss.item())
                self.ginn_hist_int.append(int_loss.item())
                self.ginn_hist_norm.append(norm_loss.item())
                self.ginn_hist_thick.append(thick_loss.item())
                self.ginn_hist_prohib.append(prohib_loss.item())
                self.ginn_hist_smooth.append(smooth_loss_scaled.item())
                self.ginn_hist_total.append(total_loss.item())

                # CSV log for GINN losses
                self._log_ginn_losses_csv(
                    it,
                    eik_loss.item(),
                    env_loss.item(),
                    connect_loss.item(),
                    hole_loss.item(),
                    int_loss.item(),
                    norm_loss.item(),
                    thick_loss.item(),
                    prohib_loss.item(),
                    smooth_loss_scaled.item(),
                    total_loss.item()
                )

                self._save_ginn_training_curves(it)

                # periodic checkpoint during GINN pretrain
                if (it % self.model_save_interval) == 0:
                    self._save_models(tag=f"ginn-{it:06d}")
                self._maybe_log_eval_metrics(phase="GINN", it=it, model_kind="sdf")


        self.time_ginn_pretrain = time.perf_counter() - t_ginn0
        self._log_timing("ginn_pretrain_total", self.time_ginn_pretrain)

        # Initialize volume ratio from SDF
        grid_xy = get_grid_centers(BRIDGE.domain, n_opt_samples).astype(np.float32)
        xt = torch.tensor(grid_xy, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            sdf_field = self.SDF_GINN_model(xt).squeeze()
            occ = torch.sigmoid(-1000 * sdf_field).clamp(0, 1)
            init_vol_frac = float(occ.mean().item())

        start_ratio = self.volume_ratio if self.volume_ratio is not None else init_vol_frac
        self.volume_ratio = self._safe_clip_ratio(
            start_ratio, self.vol_frac_min, self.vol_frac_max
        )
        self.density_GINN_model.volume_ratio = float(self.volume_ratio)
        target_volume = float(self.volume_ratio) * float(BRIDGE.domain_volume)

        # ===================== PINN pre-training =====================
        print("=== Pre-Training PINN: ===")
        for _ in tqdm(range(self.n_pre_training_iterations_PINN), desc="PINN Pre-Training"):
            coords = next(PINN_point_sampler)
            self.PINN_update_step(coords, PINN_optimizer, n_opt_samples)
        print(' ')
        print(' ')

        # ===================== density GINN pre-training =====================
        print("=== Pre-Training density GINN: ===")
        for _ in tqdm(range(self.n_pre_training_iterations_PINN), desc="GINN Training"):
            coords = next(GINN_point_sampler)
            self.density_PINN_initialization_step(coords, density_GINN_optimizer)
        print(' ')
        print(' ')

        # ---- Initial saves ---------------------------
        rho_img0, _ = predict_densities_image_2d(
            self.density_GINN_model, BRIDGE.domain, n_opt_samples, self.enforce_density
        )
        save_densities_to_file(
            rho_img0,
            os.path.join(self.save_path, f"density-{'%06d' % 0}.png"),
            BRIDGE.domain
        )
        print('saving initial geometry image to', self.save_path)

        u0, v0, s0, r0 = predict_uv_sigma_image_2d_binary(
            self.u_model,
            self.v_model,
            self.density_GINN_model,
            BRIDGE.domain,
            test_case=self.test_case,
            rho_threshold_plot=self.plot_threshold
        )
        save_uv_sigma_to_file(
            u0, v0, s0, r0,
            os.path.join(self.save_path, f"uvsigma-{'%06d' % 0}.png"),
            BRIDGE.domain,
            test_case = self.test_case,
            rho_threshold=self.plot_threshold
        )

        save_density_scatter_iso_to_file(
            self.density_GINN_model,
            BRIDGE.domain,
            os.path.join(self.save_path, f"iso-{'%06d' % 0}.png"),
            grid_resolution=600,
            iso=self.plot_threshold
        )

        sigma0, sigma0_max, vol0 = self._measure_stress_volume_2d()
        comp0 = self._estimate_compliance_internal_energy(n_opt_samples)
        self.hist_iters.append(0)
        self.hist_sigma_metric.append(sigma0)
        self.hist_volume.append(vol0)
        self.hist_compliance.append(comp0)
        self._save_training_curves(0)
        self._save_stress_histogram(0, s0, r0, sigma0, sigma0_max)
        self._log_optimization_metrics_csv(0, sigma0, sigma0_max, vol0, comp0)
        self._maybe_log_eval_metrics(phase="TOPO", it=0, model_kind="density")
        self._save_models(tag="init")

        # ===================== MAIN OPT LOOP =====================
        t_topo_phase0 = time.perf_counter()

        for it in range(1, self.n_opt_iterations + 1):
            print(f"\n=== Optimization iteration {it}/{self.n_opt_iterations} ===")

            # ---- Update PINN ----
            for _ in tqdm(range(self.n_sim_iterations), desc="Updating PINN"):
                coords = next(PINN_point_sampler)
                self.PINN_update_step(coords, PINN_optimizer, n_opt_samples)

            # ---- Build OC targets ----
            old_densities_list = []
            sensitivities_list = []
            coords_list = []

            topo_funcs = topology_optimization(
                self.u_model,
                self.v_model,
                self.density_GINN_model,
                n_opt_samples,
                self.training_hparams,
                self.enforce_density
            )

            for _ in tqdm(range(self.n_opt_batches), desc="Computing Target Densities"):
                coords = next(GINN_point_sampler)
                densities, sensitivities = topo_funcs.compute_sensitivities(coords)

                densities, sensitivities = self.enforce_density.apply(
                    coords,
                    densities,
                    sensitivities,
                    n_opt_samples,
                    BRIDGE.domain
                )

                filtered_sensitivities = topo_funcs.apply_sensitivity_filter(
                    coords,
                    densities,
                    sensitivities,
                    n_samples=n_opt_samples,
                    domain=BRIDGE.domain,
                    radius=self.filter_radius
                )

                old_densities_list.append(densities)
                sensitivities_list.append(filtered_sensitivities)
                coords_list.append(coords)

            target_density_list = topo_funcs.compute_target_densities(
                old_densities_list,
                sensitivities_list,
                sample_volume=BRIDGE.domain_volume,
                target_volume=target_volume,
                max_move=0.2,
                damping_parameter=0.5
            )

            # ---- Apply density update pass ----
            for coords, target_density in tqdm(
                    list(zip(coords_list, target_density_list)),
                    desc="Optimizing density field",
                    total=len(coords_list)):
                self.topology_optimization_step(
                    density_GINN_optimizer,
                    coords,
                    target_density
                )

            # ---- Stress & volume on binary geometry ----
            sigma_metric, sigma_max, current_volume = self._measure_stress_volume_2d()


            # --- EMA to fight oscillations ---
            if self.sigma_ema is None:
                self.sigma_ema = sigma_metric
            else:
                a = self.sigma_ema_alpha
                self.sigma_ema = (1.0 - a) * self.sigma_ema + a * sigma_metric

            u_img, v_img, sigma_img, rho_img = predict_uv_sigma_image_2d_binary(
                self.u_model,
                self.v_model,
                self.density_GINN_model,
                BRIDGE.domain,
                test_case=self.test_case,
                rho_threshold_plot=self.plot_threshold
            )

            # ---- Compliance on current SIMP design ----
            comp_now = self._estimate_compliance_internal_energy(n_opt_samples)

            # ---- Log history ----
            self.hist_iters.append(it)
            self.hist_sigma_metric.append(sigma_metric)
            self.hist_volume.append(current_volume)
            self.hist_compliance.append(comp_now)

            # ---- CSV metrics log every iteration ----
            self._log_optimization_metrics_csv(
                it, sigma_metric, sigma_max, current_volume, comp_now
            )

            # ---- Periodic curve save ----
            if (it % self.save_interval) == 0:
                self._save_training_curves(it)

            # ---- Volume ratio controller ----
            r = self.sigma_ema / (self.sigma_allow + 1e-20)

            if r < 1.0 - self.stress_tol:
                gap = 1.0 - r
                step = float(np.clip(self.alpha_decrease * gap,
                                     self.vol_step_min,
                                     self.vol_step_max))
                new_ratio = float(self.volume_ratio) * (1.0 - step)
            elif r > 1.0 + self.stress_tol:
                gap = r - 1.0
                step = float(np.clip(self.beta_increase * gap,
                                     self.vol_step_min,
                                     self.vol_step_max))
                new_ratio = float(self.volume_ratio) * (1.0 + step)
            else:
                new_ratio = float(self.volume_ratio)

            self.volume_ratio = self._safe_clip_ratio(
                new_ratio, self.vol_frac_min, self.vol_frac_max
            )
            self.density_GINN_model.volume_ratio = float(self.volume_ratio)
            target_volume = float(self.volume_ratio) * float(BRIDGE.domain_volume)

            print(
                f"[iter {it}] C={comp_now:.4g} | σ_metric={sigma_metric:.4g} (σ_allow={self.sigma_allow:.4g}) "
                f"| σ_max={sigma_max:.4g} | V={current_volume:.6g} "
                f"| vol_ratio→{self.volume_ratio:.4f} "
                f"(target={target_volume/BRIDGE.domain_volume:.4f}·V_domain)"
            )

            # ---- Periodic saves ----
            if (it % self.save_interval) == 0:
                print('Saving images to', self.save_path)
                self._maybe_log_eval_metrics(phase="TOPO", it=it, model_kind="density")

                save_density_scatter_iso_to_file(
                    self.density_GINN_model,
                    BRIDGE.domain,
                    os.path.join(self.save_path, f"iso-{it:06d}.png"),
                    grid_resolution=600,
                    iso=self.plot_threshold
                )

                rho_img_save, _ = predict_densities_image_2d(
                    self.density_GINN_model,
                    BRIDGE.domain,
                    n_opt_samples,
                    self.enforce_density
                )
                save_densities_to_file(
                    rho_img_save,
                    os.path.join(self.save_path, f"density-{it:06d}.png"),
                    BRIDGE.domain
                )

                save_uv_sigma_to_file(
                    u_img, v_img, sigma_img, rho_img,
                    os.path.join(self.save_path, f"uvsigma-{it:06d}.png"),
                    BRIDGE.domain,
                    test_case = self.test_case,
                    rho_threshold=self.plot_threshold
                )

                self._save_stress_histogram(it,
                                            sigma_img,
                                            rho_img,
                                            sigma_metric,
                                            sigma_max)

                self._save_topology_density_curves(it)

                self._log_optimization_losses_csv(
                    it, comp_now, sigma_metric, sigma_max, current_volume
                )

                # ---- Model checkpoint ----
                if (it % self.model_save_interval) == 0:
                    self._save_models(tag=f"opt-{it:06d}")

        self.time_topology_phase = time.perf_counter() - t_topo_phase0
        self._log_timing("topology_phase_total", self.time_topology_phase)
        self._log_timing("density_updates_total", self.time_density_updates)

        # ---- Final density & iso ----
        rho_img, _ = predict_densities_image_2d(
            self.density_GINN_model, BRIDGE.domain, n_opt_samples, self.enforce_density
        )
        save_densities_to_file(
            rho_img,
            os.path.join(self.save_path, f"density-final.png"),
            BRIDGE.domain
        )

        save_density_scatter_iso_to_file(
            self.density_GINN_model,
            BRIDGE.domain,
            os.path.join(self.save_path, f"iso-final.png"),
            grid_resolution=600,
            iso=self.plot_threshold
        )

        # Final curves snapshot
        if len(self.hist_iters) > 0:
            self._save_training_curves(self.hist_iters[-1])

        # ---- Final total time ----
        self.time_total = time.perf_counter() - t_total0
        self._log_timing(
            "final_total",
            self.time_total,
            extra={
                "ginn_pretrain": self.time_ginn_pretrain,
                "topology_phase": self.time_topology_phase,
                "density_updates": self.time_density_updates
            }
        )

        # Final checkpoint
        self._save_models(tag="final")
        self._maybe_log_eval_metrics(phase="TOPO", it=int(self.n_opt_iterations), model_kind="density")

        print("Done. Results in:", self.save_path)


In [None]:
# ---------------------------------------------------------------------------
# Hyperparameter configuration for this run.
# These dictionaries are passed to the network constructors and to
# `model_training` in the next cell; their key names (including exact
# spelling) are read by code defined outside this cell.
# ---------------------------------------------------------------------------


def _siren_hparams(w0_initial, w0):
    """Build a fresh SIREN config: 4 hidden layers of 180 units, 2-D input, skips on.

    Every call returns a new dict with a new `layers` list, so configurations
    created from it never share mutable state.
    """
    return {
        'Model_type': 'SIREN',
        'layers': [180] * 4,
        'dimensionality': 2,
        'w0_initial': w0_initial,
        'w0': w0,
        'skip_connection': True,
    }


hparams_model = dict(
    Model_type='SIREN',
    num_hidden_layers=3,
    num_hidden_neurons=280,
    # Passed to the density network and to the u/v PINNs in the next cell.
    SIREN_hparams=_siren_hparams(w0_initial=1, w0=1),
    # Passed to the SDF GINN; uses larger w0_initial / w0 than SIREN_hparams.
    GINN_SIREN_hparams=_siren_hparams(w0_initial=15, w0=2),
    WIRE_hparams=dict(
        Model_type='WIRE',
        layers=[80] * 4,
        dimensionality=2,
        w0_initial=30,
        w0=10,
        sigma0=10,
        sigma0_initial=10,
        layer_type='real_gabor',
        trainable=False,
        skip_connection=True,
    ),
    MLP_hparams=dict(
        Model_type='MLP',
        layers=[180] * 4,
        dimensionality=2,
        activation_function='softplus',
        use_bias=True,
        use_batch_norm=False,
        use_dropout=False,
        dropout_rate=0.1,
        skip_connection=True,
    ),
)

# Keys contain spaces, so this config stays a dict literal.
hparams_feature_expansion = {
    'Feature Type': 'None',  # the string 'None', not the None object
    'Num Frequencies': 3,
    'Max Frequency': 100,
}

density_constraint_hparams = dict(
    enabled=True,
    priority="keep_over_prohibit",
)

topo_hparams = dict(
    save_path="./B_DELETE",
    volume_ratio=0.5,
    # Optimiser learning rates.
    lr_PINN=3e-4,
    lr_GINN=3e-4,
    save_interval=5,
    filter_radius=2.0,
    # Iteration budgets.
    n_opt_iterations=200,
    n_sim_iterations=1000,
    n_pre_training_iterations_PINN=2500,
    n_pre_training_iterations_GINN=500,  # previously 800
    n_opt_batches=50,
    seed=43,
    # NOTE(review): key is spelled 'treshold'; the consumer presumably reads
    # this exact spelling, so do not rename it here in isolation.
    rho_treshold=0.4,
    # Stress-constraint / volume-ratio controller settings.
    sigma_allow=2000.0,
    stress_metric='percentile',  # 'KS' or 'percentile'
    stress_percentile=0.995,
    rho_mask_threshold=0.6,
    stress_tol=0.02,
    ks_rho=50.0,
    vol_frac_min=0.02,
    vol_frac_max=0.98,
    vol_step_min=0.005,
    vol_step_max=0.05,
    alpha_decrease=0.50,
    beta_increase=0.25,
)

training_hparams = dict(
    # Sampling.
    total_sample_points=10_000,
    batch_size=5000,
    num_neumann_points=10_000,
    dirichlet_pts=5000,
    # Model / material parameters.
    mollifier_alpha=1,
    density_alpha=2,
    density_exponent=3,
    # Gradient clipping.
    grad_clipping_on=True,
    grad_clip=0.5,
    auto_clip_on=True,
    auto_clip_percentile=0.9,
    auto_clip_hist_len=100,
    auto_clip_min_len=10,
    refine_study=False,
)

GINN_hparams = dict(
    # Loss weights.
    eikonal_loss_weight=0.1,
    envelope_loss_weight=1,
    connectivity_loss_weight=1,
    holes_loss_weight=1,
    interface_loss_weight=1,
    prescribed_normals_loss_weight=1,
    prescribed_thickness_loss_weight=1,
    prohibited_region_loss_weight=1,
    smoothness_loss_weight=1e-4,
    envelope_loss_weight_density=1,
    connectivity_loss_weight_density=1,
    # Sample counts per loss term.
    num_points_envelope_loss=10_000,
    num_points_connectivity_loss=50_000,
    envelope_extension_factor=0.2,
    num_points_interface_loss=4096,
    num_points_normals_loss=4096,
    clip_max_value=1.0e6,
    clip_min_value=0,
    max_curv=0,
    # The smoothness loss is ramped in via loss_ramp_up(curv_start_epoch, curv_ramp_epochs).
    curv_start_epoch=200,
    curv_ramp_epochs=100,
)

adaptive_weighting_hparams = dict(
    use_adaptive_weighting=True,
    alpha=0.90,
    gamma=1e-2,
    epsilon=1e-8,
    objective_function=True,
    objective_losses=['Smoothness Loss'],
)


In [None]:

# Build the networks and the training driver for this run.
#
# Seed every RNG *before* any network is constructed. Otherwise the initial
# weights of the density GINN, the SDF GINN and both PINNs are drawn from an
# unseeded generator. `topo_hparams['seed']` only reaches `model_training`
# after these networks already exist, so it cannot make their initialisation
# reproducible on its own.
model_init_seed = topo_hparams['seed']
random.seed(model_init_seed)
np.random.seed(model_init_seed)
torch.manual_seed(model_init_seed)  # seeds the RNG on all devices (CPU and CUDA)

# Density field network; volume_ratio seeds its initial target volume fraction.
rho_GINN = density_GINN(hparams_model['SIREN_hparams'],
                        hparams_feature_expansion,
                        density_alpha=training_hparams['density_alpha'],
                        volume_ratio=topo_hparams['volume_ratio'])

# Signed-distance-function network used for the initial geometry generation.
sdf_GINN = SDF_GINN(hparams_model['GINN_SIREN_hparams'],
                    hparams_feature_expansion)

# One PINN per displacement component (u and v).
u_model = PINN(hparams_model['SIREN_hparams'],
               hparams_feature_expansion,
               training_hparams['mollifier_alpha'])

v_model = PINN(hparams_model['SIREN_hparams'],
               hparams_feature_expansion,
               training_hparams['mollifier_alpha'])

Shape_Generator = model_training(u_model,
                                 v_model,
                                 rho_GINN,
                                 sdf_GINN,
                                 training_hparams,
                                 topo_hparams,
                                 density_constraint_hparams,
                                 BRIDGE,
                                 adaptive_weighting_hparams,
                                 GINN_hparams,
                                 )


In [None]:
# Launch the full run on the configured Shape_Generator. This is long-running.
# NOTE(review): presumably this is the model_training method earlier in the
# notebook that prints "=== Generating Geometry: ===" and runs SDF-GINN
# training, PINN / density pre-training and the optimisation loop, writing
# images, CSV logs and checkpoints under its save_path. Confirm.
Shape_Generator.generate_geometry()