In [None]:
# Colab setup: mount Google Drive and make the project folder the working directory.
# A later cell adds the parent of this directory to sys.path for project imports.
import os
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
# TODO(review): set this to the project folder inside the mounted drive
# (e.g. '/content/gdrive/MyDrive/<project>/<notebook folder>'). With the empty
# string committed here, os.chdir('') raises FileNotFoundError.
google_drive_path = '' 
os.chdir(google_drive_path)
print("Current working directory:", os.getcwd())
!ls

# Extra dependencies installed into the Colab runtime.
# NOTE(review): only cripser is pinned; consider pinning trimesh and rtree too so
# Restart & Run All reproduces the same environment.
!python -m pip install trimesh
!python -m pip install rtree
!python -m pip install cripser==0.0.15

In [None]:
import os
import sys
import csv
import time
import math
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import trimesh
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import cripser

from pathlib import Path
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Union

from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib import colors as mcolors
from skimage import measure

from torchsummary import summary
from scipy.interpolate import griddata
from tqdm import tqdm, trange
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tabulate import tabulate

# Put the parent of the working directory on sys.path so the project packages
# imported in the next cell (Test_Cases, Functions, Models, File_Paths) resolve.
sys.path.append(str(Path(os.getcwd()).parent))


In [None]:
# Build the Jet Engine Bracket (JEB) test case. The model and loss classes below read
# this module-level `JEB` object directly (geometry, bolt/pin data, domain).
from Test_Cases.Jet_Engine_Bracket.JEB_Master_object import JEB_Master_object
JEB = JEB_Master_object(Normalize=True, Symmetry=False, Expand = False, expansion_factor = 1.1)
# NOTE(review): presumably populates JEB.interfaces used by the samplers/losses;
# expansion_factor is presumably ignored while Expand=False — confirm.
JEB.create_interfaces()

from Functions.utils import *
from Functions.Point_Sampling.point_sampler import *
from Functions.Point_Sampling.BS3D import *
from Functions.Plotting_functions.JEB_results import *
from Functions.Plotting_functions.training_curves.JEB_curves import *
from Functions.Computations.PH3D import *
from Functions.Computations.eval3D import *
from Functions.Training.enforce_Const_3D import *
from Functions.Training.ALM import *
from Functions.logging.JEB_logging import *
from Functions.Training.Properties import *

from Models.GINN_Models.GINN import GINN
from File_Paths.file_paths import interfaces_path, mesh_path

# Pick the fastest available backend: first CUDA GPU, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
class SDF_GINN(torch.nn.Module):
    """GINN network whose raw output is used directly as the signed distance field.

    Thin wrapper around ``GINN`` bound to the module-level ``JEB`` test case.

    Args:
        hparams_model: forwarded to ``GINN`` as ``Model_hyperparameters``.
        hparams_feature_expansion: forwarded to ``GINN`` as ``feature_expansion``.
    """

    def __init__(self,
                 hparams_model,
                 hparams_feature_expansion):
        super().__init__()
        self.model = GINN(
            JEB,
            feature_expansion=hparams_feature_expansion,
            Model_hyperparameters=hparams_model,
        )

    def forward(self, coords):
        """Return the wrapped network's output for ``coords`` unchanged."""
        return self.model(coords)


class density_GINN(torch.nn.Module):
    """GINN network mapped to a density in [0, 1] through a shifted sigmoid.

    ``rho = sigmoid(density_alpha * f(x) + logit(volume_ratio))``, so a zero network
    output maps exactly to ``volume_ratio``.

    Args:
        hparams_model: forwarded to ``GINN`` as ``Model_hyperparameters``.
        hparams_feature_expansion: forwarded to ``GINN`` as ``feature_expansion``.
        density_alpha: slope applied to the network output before the sigmoid.
        volume_ratio: density produced for a zero network output.
    """

    def __init__(self,
                 hparams_model,
                 hparams_feature_expansion,
                 density_alpha,
                 volume_ratio):
        super().__init__()
        self.volume_ratio = volume_ratio
        self.density_alpha = density_alpha
        self.model = GINN(
            JEB,
            feature_expansion=hparams_feature_expansion,
            Model_hyperparameters=hparams_model,
        )

    def forward(self, coords):
        """Return densities in [0, 1] with the same shape as the network output."""
        raw = self.model(coords)
        ratio = torch.as_tensor(self.volume_ratio, dtype=raw.dtype, device=raw.device)
        # logit(volume_ratio): shifts the sigmoid so that raw == 0 gives volume_ratio.
        bias = torch.log(ratio / (1.0 - ratio))
        density = torch.sigmoid(self.density_alpha * raw + bias)
        return density.clamp(0.0, 1.0)


class PINN(torch.nn.Module):
    """Displacement-component network with a hard zero condition at the bolt holes.

    The raw ``GINN`` output is multiplied by a mollifier that is exactly zero on the
    circles of radius ``JEB.bolt_interface_radius`` around the four bolt centroids in
    the x-y plane (i.e. on the bolt cylinders for every z).

    Args:
        hparams_model: forwarded to ``GINN`` as ``Model_hyperparameters``.
        hparams_feature_expansion: forwarded to ``GINN`` as ``feature_expansion``.
        mollifier_alpha: steepness of the ``tanh`` mollifier; larger values make the
            multiplier approach 1 closer to the bolt surfaces.
    """

    def __init__(self,
                 hparams_model,
                 hparams_feature_expansion,
                 mollifier_alpha):
        super().__init__()
        self.mollifier_alpha = mollifier_alpha
        self.model = GINN(JEB,
                          feature_expansion=hparams_feature_expansion,
                          Model_hyperparameters=hparams_model)

    @staticmethod
    def enforce_dirichlet_BC(alpha, u, coords):
        """Scale ``u`` by ``prod_k tanh(alpha * | ||xy - c_k|| - r |)``.

        Args:
            alpha: mollifier steepness.
            u: network output; the per-point multiplier is broadcast as (N, 1).
            coords: tensor with N rows; only columns 0 and 1 (x, y) are used.

        Returns:
            ``u * multiplier.unsqueeze(1)``, zero wherever a point lies on a bolt circle.
        """
        radius = JEB.bolt_interface_radius
        bolt_centroids = (JEB.bolt1_centroid, JEB.bolt2_centroid,
                          JEB.bolt3_centroid, JEB.bolt4_centroid)
        # Match coords' dtype so float64 inputs are not compared with float32-rounded centroids.
        centroids_xy = torch.tensor([c[0:2] for c in bolt_centroids],
                                    dtype=coords.dtype, device=u.device)

        xy_coords = coords[:, 0:2]
        multiplier_total = torch.ones(coords.shape[0], dtype=coords.dtype, device=u.device)
        for centroid in centroids_xy:
            dist = torch.norm(xy_coords - centroid, dim=1)
            # Distance to the bolt circle is |dist - radius| both inside and outside
            # the hole, so a single expression covers every point.
            multiplier_total = multiplier_total * torch.tanh(alpha * torch.abs(dist - radius))

        return u * multiplier_total.unsqueeze(1)

    def forward(self, coords):
        """Network output at ``coords`` with the bolt-hole condition applied."""
        u = self.model(coords)
        return self.enforce_dirichlet_BC(self.mollifier_alpha, u, coords)

In [None]:
class GINN_Losses:
    """Geometric loss terms for training a GINN shape (SDF) model on a test case.

    Each method samples or receives points, queries ``GINN_model`` and returns a scalar
    loss tensor. When a one-sided constraint has no violations, the method returns a
    zero tensor created with ``requires_grad=True``.

    Args:
        GINN_model: network mapping coordinates to SDF values (moved to ``self.device``).
        test_case: object exposing ``dim``, ``domain`` ([x_min, x_max, y_min, y_max,
            z_min, z_max]) and ``interfaces``.
        GINN_hparams: dict providing the sampling sizes, ``envelope_extension_factor``,
            curvature clipping (``clip_min_value``/``clip_max_value``), ``max_curv``
            and ``curv_start_epoch``.
        PH: persistent-homology helper providing the connectedness/holes losses.
        boundary_sampler: object whose ``get_surface_pts()`` returns ``(points, weights)``.
        enforce_density: stored on the instance; not used by the methods of this class.
    """

    def __init__(self,
                 GINN_model,
                 test_case,
                 GINN_hparams,
                 PH,
                 boundary_sampler,
                 enforce_density):
        super().__init__()
        # Module-level device chosen in the setup cell.
        self.device = device
        self.GINN_model = GINN_model.to(self.device)
        self.GINN_hparams = GINN_hparams
        self.boundary_sampler = boundary_sampler
        self.test_case = test_case
        self.enforce_density = enforce_density
        self.dim = test_case.dim
        self.envelope_extension_factor = self.GINN_hparams['envelope_extension_factor']
        self.num_points_connectivity_loss = self.GINN_hparams['num_points_connectivity_loss']
        self.num_points_interface_loss = self.GINN_hparams['num_points_interface_loss']
        self.num_points_normals_loss = self.GINN_hparams['num_points_normals_loss']
        self.num_points_envelope_loss = self.GINN_hparams['num_points_envelope_loss']
        self.clip_max_value = self.GINN_hparams['clip_max_value']
        self.clip_min_value = self.GINN_hparams['clip_min_value']
        self.max_curv = self.GINN_hparams['max_curv']
        self.curv_start_epoch = self.GINN_hparams['curv_start_epoch']
        self.PH = PH

    # ------------------------------------------------------------------ helpers

    def _zero_loss(self, device=None):
        """Scalar float32 zero that participates in autograd (used when nothing is violated)."""
        return torch.tensor(0.0, device=self.device if device is None else device,
                            dtype=torch.float32, requires_grad=True)

    def _points_outside_envelope(self, extension_factor, dtype=None, **sampler_kwargs):
        """Sample the envelope grown by ``extension_factor`` and keep points outside the envelope.

        Each axis of ``test_case.domain`` is extended on both sides by ``extension_factor``
        times its length. ``dtype`` is used for the extended-domain array, and
        ``sampler_kwargs`` are forwarded to ``Point_Sampler``.

        Returns:
            Sampled points on ``self.device`` that lie outside the original envelope
            (possibly an empty tensor).
        """
        domain = self.test_case.domain
        bounds = [(domain[2 * axis], domain[2 * axis + 1]) for axis in range(3)]

        extended = []
        for lo, hi in bounds:
            span = hi - lo
            extended += [lo - extension_factor * span, hi + extension_factor * span]
        extended_domain = np.array(extended, dtype=dtype)

        point_sampler = Point_Sampler(extended_domain,
                                      num_points_domain=self.num_points_envelope_loss,
                                      **sampler_kwargs)
        points = next(point_sampler).to(self.device)

        inside = torch.ones(points.shape[0], dtype=torch.bool, device=points.device)
        for axis, (lo, hi) in enumerate(bounds):
            inside &= (points[:, axis] >= lo) & (points[:, axis] <= hi)
        return points[~inside]

    # ------------------------------------------------------------------ SDF losses

    def surface_normal_loss(self):
        """MSE between the normalized SDF gradient and the prescribed interface normals.

        Neumann points are stacked before Dirichlet points, and the target normals are
        stacked in the same order.
        """
        normals = self.test_case.interfaces.get_all_prescribed_surface_normals(
            num_points=self.num_points_normals_loss,
            type='torch_tensor',
            include_all=True
        )
        points = normals['points'].to(self.device)
        neumann_points = points[normals['neumann_idx']]
        dirichlet_points = points[normals['dirichlet_idx']]
        inp = torch.vstack([neumann_points, dirichlet_points]).to(self.device)
        inp = inp.requires_grad_(True)

        target_surface_normals = torch.vstack([normals['neumann_normals'],
                                               normals['dirichlet_normals']]).to(self.device)

        SDF_values = self.GINN_model(inp).view(-1)
        predicted_surface_normals = torch.autograd.grad(
            inputs=inp,
            outputs=SDF_values,
            grad_outputs=torch.ones_like(SDF_values),
            create_graph=True,
            only_inputs=True
        )[0]
        predicted_surface_normals = F.normalize(predicted_surface_normals, p=2, dim=1)
        return F.mse_loss(predicted_surface_normals, target_surface_normals)

    def prohibited_region_loss(self, coords):
        """Sum of squared negative SDF values at ``coords`` that fall in prohibited regions."""
        _, inside_mask = self.test_case.interfaces.is_inside_prohibited_region(coords)
        SDF = self.GINN_model(coords).squeeze()
        SDF_inside = SDF[inside_mask].view(-1)
        SDF_violations = SDF_inside[SDF_inside < 0]
        if SDF_violations.numel() > 0:
            return torch.square(SDF_violations).sum()
        return self._zero_loss(SDF.device)

    def prescribed_thickness_loss(self, coords):
        """Sum of squared positive SDF values at ``coords`` inside the interface thickness."""
        _, inside_mask = self.test_case.interfaces.is_inside_interface_thickness(coords)
        SDF = self.GINN_model(coords).squeeze()
        SDF_inside = SDF[inside_mask].view(-1)
        SDF_violations = SDF_inside[SDF_inside > 0]
        if SDF_violations.numel() > 0:
            return torch.square(SDF_violations).sum()
        return self._zero_loss(SDF.device)

    def interface_loss(self):
        """MSE of the SDF against zero on points sampled from all interfaces."""
        coords = self.test_case.interfaces.sample_points_from_all_interfaces(
            self.num_points_interface_loss,
            random_seed=None,
            output_type='torch_tensor'
        )
        points = coords.detach().to(self.device)
        sdf_values = self.GINN_model(points).view(-1, 1)
        return F.mse_loss(sdf_values, torch.zeros_like(sdf_values))

    def eikonal_loss(self, coords):
        """Mean of (||grad SDF|| - 1)^2 at ``coords``.

        Note: sets ``requires_grad`` on the caller's ``coords`` in place.
        """
        coords.requires_grad_(True)
        SDF = self.GINN_model(coords).squeeze()
        SDF_grad = torch.autograd.grad(outputs=SDF, inputs=coords,
                                       grad_outputs=torch.ones_like(SDF), create_graph=True)[0]
        SDF_grad_norm = torch.norm(SDF_grad, dim=1)
        return torch.mean((SDF_grad_norm - 1) ** 2)

    def design_envelope_loss(self):
        """Sum of squared non-positive SDF values at points sampled outside the design envelope.

        Raises:
            ValueError: if no sampled point falls outside the envelope.
        """
        points_outside_envelope = self._points_outside_envelope(
            self.GINN_hparams['envelope_extension_factor'])
        if points_outside_envelope.shape[0] == 0:
            raise ValueError("No points sampled outside the design envelope. Increase points or extension.")

        sdf_values = self.GINN_model(points_outside_envelope).view(-1)
        SDF_constraint_violations = sdf_values[sdf_values <= 0]
        if SDF_constraint_violations.numel() > 0:
            return torch.square(SDF_constraint_violations).sum()
        return self._zero_loss(sdf_values.device)

    def smoothness_loss(self, epoch) -> torch.Tensor:
        """Hinge on the weighted surface strain energy  sum_i w_i * ((2H)^2 - 2K).

        H and K are the mean and Gaussian curvature of the SDF level set, computed from
        the SDF gradient and Hessian at the boundary sampler's surface points. Per-point
        energies are clipped to [clip_min_value, clip_max_value], weighted, summed, and
        only the excess over ``max_curv`` is penalized.

        Returns a zero loss before ``curv_start_epoch`` or when no surface points exist.

        Raises:
            ValueError: if points are not 2-D or weights do not match them.
        """
        if epoch < self.curv_start_epoch:
            return self._zero_loss()

        surface_points, weights = self.boundary_sampler.get_surface_pts()
        if surface_points is None or surface_points.numel() == 0:
            print('Returning Zero Loss - No Surface Points Found')
            return self._zero_loss()

        pts = surface_points.to(self.device)
        if pts.dim() != 2:
            raise ValueError(f"surface_points must be [B, D], got {pts.shape}")
        w = weights.to(self.device)
        if w.dim() == 2 and w.size(-1) == 1:
            w = w.squeeze(-1)
        if w.dim() != 1 or w.shape[0] != pts.shape[0]:
            raise ValueError(f"weights must be [B] or [B,1] matching points. Got {weights.shape}")

        B, D = pts.shape
        pts = pts.clone().detach().requires_grad_(True)
        sdf = self.GINN_model(pts).view(-1)

        df_dx = torch.autograd.grad(
            outputs=sdf, inputs=pts, grad_outputs=torch.ones_like(sdf),
            create_graph=True, retain_graph=True, only_inputs=True,
        )[0]  # [B, D]

        # Hessian row by row: d/dx of each gradient component.
        H_rows = []
        for d in range(D):
            g_comp = df_dx[:, d]
            Hg = torch.autograd.grad(
                outputs=g_comp, inputs=pts, grad_outputs=torch.ones_like(g_comp),
                create_graph=True, retain_graph=True, only_inputs=True,
            )[0]
            H_rows.append(Hg.unsqueeze(1))
        H = torch.cat(H_rows, dim=1)  # [B, D, D]

        grad_sq = (df_dx.square()).sum(dim=1)
        F4 = torch.clamp(grad_sq.square(), min=1.0e-15)

        # Gaussian curvature of an implicit surface: K = -det([[H, g], [g^T, 0]]) / |g|^4.
        top = torch.cat([H, df_dx.unsqueeze(2)], dim=2)
        bottom = torch.cat([df_dx.unsqueeze(1),
                            torch.zeros(B, 1, 1, device=pts.device, dtype=H.dtype)], dim=2)
        aug = torch.cat([top, bottom], dim=1)
        gauss_curvatures = (-1.0) / F4 * torch.det(aug)

        # Mean curvature: (|g|^2 tr(H) - g^T H g) / (2 |g|^3).
        FHFT = torch.einsum('bi,bij,bj->b', df_dx, H, df_dx)
        trH = torch.einsum('bii->b', H)
        N = torch.clamp(grad_sq.sqrt(), min=1.0e-5)
        mean_curvatures = -(FHFT - (N.pow(2) * trH)) / (2.0 * N.pow(3))

        E_strain = (2.0 * mean_curvatures).pow(2) - 2.0 * gauss_curvatures
        E_strain = torch.clamp(E_strain, min=self.clip_min_value, max=self.clip_max_value)
        total = (E_strain * w).sum()

        zero = torch.tensor(0.0, device=self.device, dtype=torch.float32)
        return torch.maximum(zero, total - float(self.max_curv))

    def connectivity_loss(self) -> torch.Tensor:
        """Connectedness loss delegated to the persistent-homology helper."""
        return self.PH.connectedness_loss()

    def holes_loss(self):
        """Holes loss delegated to the persistent-homology helper."""
        return self.PH.holes_loss()

    # ------------------------------------------------------------------ density losses

    def design_envelope_loss_from_density(self,
                                          density_model: torch.nn.Module,
                                          iso_level: float = 0.5) -> torch.Tensor:
        """Sum of squared excess ``rho - iso_level`` at points outside the design envelope.

        Unlike ``design_envelope_loss``, a sample with no outside points yields a zero loss
        instead of an error.
        """
        points_outside = self._points_outside_envelope(self.envelope_extension_factor,
                                                       dtype=np.float32,
                                                       num_points_interface=0)
        if points_outside.shape[0] == 0:
            return self._zero_loss()

        rho_out = torch.clamp(density_model(points_outside).view(-1), 0.0, 1.0)
        rho_viol = rho_out[rho_out > iso_level]
        if rho_viol.numel() > 0:
            return torch.square(rho_viol - iso_level).sum()
        return self._zero_loss()

    def connectivity_loss_from_density(self,
                                       density_model: torch.nn.Module,
                                       iso_level: float = 0.5) -> torch.Tensor:
        """Connectedness loss of the ``iso_level`` density set via the PH helper."""
        return self.PH.connectedness_loss_from_density(density_model, iso_level)


In [None]:
class PINN_losses(Properties):
    """Deep Ritz loss (total potential energy) for the three displacement networks.

    The returned energy is ``internal_energy - external_energy`` for a linear-elastic
    body whose Lamé parameters are scaled by ``rho ** density_exponent``, with ``rho``
    taken from ``GINN_model`` after ``enforce_density.apply``. ``lame_lambda``,
    ``lame_mu``, ``interfaces``, ``force_vector`` and ``device`` are read from ``self``
    and presumably provided by ``Properties`` — TODO confirm.
    """

    def __init__(self,
                 u_model,
                 v_model,
                 w_model,
                 GINN_model,
                 n_opt_samples,
                 training_hparams,
                 enforce_density):
        super().__init__(test_case=JEB)
        self.u_model = u_model.to(self.device)
        self.v_model = v_model.to(self.device)
        self.w_model = w_model.to(self.device)
        self.GINN_model = GINN_model.to(self.device)
        self.n_opt_samples = n_opt_samples
        self.enforce_density = enforce_density
        self.dirichlet_pts = training_hparams['dirichlet_pts']
        self.num_neumann_pts = training_hparams['num_neumann_points']
        self.density_exponent = training_hparams['density_exponent']

    def _strain_energy_density(self, coords, densities):
        """Pointwise strain energy density ``0.5 * sigma : epsilon``.

        Args:
            coords: points with ``requires_grad=True`` and at least 3 columns (x, y, z);
                the small-strain tensor is built from the displacement gradients here.
            densities: one density per point (squeezed here), clamped to [0, 1] and raised
                to ``density_exponent`` to scale both Lamé parameters.
        """
        u = self.u_model(coords)
        v = self.v_model(coords)
        w = self.w_model(coords)

        grad_u = torch.autograd.grad(u, coords, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)[0]
        grad_v = torch.autograd.grad(v, coords, grad_outputs=torch.ones_like(v), create_graph=True, retain_graph=True)[0]
        grad_w = torch.autograd.grad(w, coords, grad_outputs=torch.ones_like(w), create_graph=True, retain_graph=True)[0]

        epsilon_11 = grad_u[:, 0]  # du/dx
        epsilon_22 = grad_v[:, 1]  # dv/dy
        epsilon_33 = grad_w[:, 2]  # dw/dz
        epsilon_12 = 0.5 * (grad_u[:, 1] + grad_v[:, 0])  # shear xy
        epsilon_13 = 0.5 * (grad_u[:, 2] + grad_w[:, 0])  # shear xz
        epsilon_23 = 0.5 * (grad_v[:, 2] + grad_w[:, 1])  # shear yz
        trace_epsilon = epsilon_11 + epsilon_22 + epsilon_33

        # SIMP-style stiffness interpolation: E(rho) = rho^p * E0.
        densities = densities.squeeze()
        lame_lambda = self.lame_lambda * torch.ones_like(epsilon_11)
        lame_lambda = densities.clamp(0.0, 1.0).pow(self.density_exponent) * lame_lambda
        lame_mu = self.lame_mu * torch.ones_like(epsilon_11)
        lame_mu = densities.clamp(0.0, 1.0).pow(self.density_exponent) * lame_mu

        sigma_11 = 2 * lame_mu * epsilon_11 + lame_lambda * trace_epsilon
        sigma_22 = 2 * lame_mu * epsilon_22 + lame_lambda * trace_epsilon
        sigma_33 = 2 * lame_mu * epsilon_33 + lame_lambda * trace_epsilon
        sigma_12 = 2 * lame_mu * epsilon_12
        sigma_13 = 2 * lame_mu * epsilon_13
        sigma_23 = 2 * lame_mu * epsilon_23

        # Off-diagonal terms appear twice in sigma:epsilon, hence the factor 2.
        return 0.5 * (
            sigma_11 * epsilon_11 +
            sigma_22 * epsilon_22 +
            sigma_33 * epsilon_33 +
            2 * sigma_12 * epsilon_12 +
            2 * sigma_13 * epsilon_13 +
            2 * sigma_23 * epsilon_23
        )

    def ritz_loss(self, x):
        """Total potential energy of the current displacement fields.

        Args:
            x: interior collocation points used for the strain energy (presumably the
               ``n_opt_samples`` grid expected by ``enforce_density.apply`` — TODO confirm).

        Returns:
            Scalar tensor ``internal_energy - external_energy``.
        """
        # ---- External work: uniform vertical traction on the Neumann (pin) surface ----
        neumann_pts = self.interfaces.sample_points_on_neumann_boundary(
            self.num_neumann_pts, 'vertical', 'torch_tensor').to(self.device)

        R = JEB.pinn_interface_radius
        width = JEB.pinn_interface_width
        arc_area = 2 * np.pi * R * width        # lateral area of a cylinder of radius R, length width
        ds = arc_area / neumann_pts.shape[0]    # area represented by each sample

        prescribed_traction = torch.zeros_like(neumann_pts)
        prescribed_traction[:, 2] = self.force_vector[2] / arc_area

        displacements_neumann = torch.stack([
            self.u_model(neumann_pts).squeeze(-1),
            self.v_model(neumann_pts).squeeze(-1),
            self.w_model(neumann_pts).squeeze(-1),
        ], dim=1)
        work_density = torch.sum(prescribed_traction * displacements_neumann, dim=1)  # (N,)
        external_energy = torch.sum(work_density * ds)

        # ---- Internal strain energy ----
        # NOTE(review): densities are not detached, so backprop through this loss also
        # reaches GINN_model's parameters (topology_optimization.compute_sensitivities
        # detaches them) — confirm this is intended.
        densities = self.GINN_model(x)
        densities, _ = self.enforce_density.apply(x, densities, None, self.n_opt_samples, JEB.domain)

        # Move to the device before enabling grad so coords is a leaf tensor there.
        coords = x.detach().to(self.device).clone().requires_grad_(True)
        strain_energy_density = self._strain_energy_density(coords, densities)

        # Domain volume times the mean energy density over the collocation points.
        internal_energy = JEB.domain_volume * strain_energy_density.mean()

        return internal_energy - external_energy

In [None]:
class topology_optimization(Properties):
    """Density-update step following NTopo: sensitivities, sensitivity filter, OC update.

    ``lame_lambda``, ``lame_mu`` and ``device`` are read from ``self`` and presumably
    provided by ``Properties`` — TODO confirm.
    Reference: https://github.com/JonasZehn/ntopo/tree/main/ntopo
    """

    def __init__(self,
                 u_model,
                 v_model,
                 w_model,
                 GINN_model,
                 n_opt_samples,
                 training_hparams,
                 enforce_density):
        super().__init__(test_case=JEB)
        self.u_model = u_model.to(self.device)
        self.v_model = v_model.to(self.device)
        self.w_model = w_model.to(self.device)
        self.GINN_model = GINN_model.to(self.device)
        self.n_opt_samples = n_opt_samples
        self.enforce_density = enforce_density
        self.density_exponent = training_hparams['density_exponent']

    def _strain_energy_density(self, coords, densities):
        """Pointwise strain energy density ``0.5 * sigma : epsilon`` with rho^p-scaled Lamé parameters.

        ``coords`` must require grad and have at least 3 columns; ``densities`` holds one
        value per point.
        """
        u = self.u_model(coords)
        v = self.v_model(coords)
        w = self.w_model(coords)

        grad_u = torch.autograd.grad(u, coords, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)[0]
        grad_v = torch.autograd.grad(v, coords, grad_outputs=torch.ones_like(v), create_graph=True, retain_graph=True)[0]
        grad_w = torch.autograd.grad(w, coords, grad_outputs=torch.ones_like(w), create_graph=True, retain_graph=True)[0]

        epsilon_11 = grad_u[:, 0]  # du/dx
        epsilon_22 = grad_v[:, 1]  # dv/dy
        epsilon_33 = grad_w[:, 2]  # dw/dz
        epsilon_12 = 0.5 * (grad_u[:, 1] + grad_v[:, 0])  # shear xy
        epsilon_13 = 0.5 * (grad_u[:, 2] + grad_w[:, 0])  # shear xz
        epsilon_23 = 0.5 * (grad_v[:, 2] + grad_w[:, 1])  # shear yz
        trace_epsilon = epsilon_11 + epsilon_22 + epsilon_33

        densities = densities.squeeze()
        lame_lambda = self.lame_lambda * torch.ones_like(epsilon_11)
        lame_lambda = densities.clamp(0.0, 1.0).pow(self.density_exponent) * lame_lambda
        lame_mu = self.lame_mu * torch.ones_like(epsilon_11)
        lame_mu = densities.clamp(0.0, 1.0).pow(self.density_exponent) * lame_mu

        sigma_11 = 2 * lame_mu * epsilon_11 + lame_lambda * trace_epsilon
        sigma_22 = 2 * lame_mu * epsilon_22 + lame_lambda * trace_epsilon
        sigma_33 = 2 * lame_mu * epsilon_33 + lame_lambda * trace_epsilon
        sigma_12 = 2 * lame_mu * epsilon_12
        sigma_13 = 2 * lame_mu * epsilon_13
        sigma_23 = 2 * lame_mu * epsilon_23

        return 0.5 * (
            sigma_11 * epsilon_11 +
            sigma_22 * epsilon_22 +
            sigma_33 * epsilon_33 +
            2 * sigma_12 * epsilon_12 +
            2 * sigma_13 * epsilon_13 +
            2 * sigma_23 * epsilon_23
        )

    def compute_sensitivities(self, coords):
        """Densities at ``coords`` and the sensitivity of the strain energy w.r.t. them.

        The displacement fields are treated as fixed. The gradient of the internal energy
        with respect to the (detached) densities is negated, matching the sign convention
        of NTopo's TF custom gradient.

        Side effect: puts all four networks in eval mode (it does not restore train mode).
        The caller's ``coords`` tensor is not modified.

        Returns:
            ``(densities, sensitivities)``, both detached, one value per point.
        """
        self.u_model.eval(); self.v_model.eval(); self.w_model.eval(); self.GINN_model.eval()
        with torch.enable_grad():
            densities = self.GINN_model(coords).detach()

            # Enforce hard density constraints, then differentiate w.r.t. the result.
            densities, _ = self.enforce_density.apply(coords, densities, None, self.n_opt_samples, JEB.domain)
            densities.requires_grad_(True)
            densities = densities.squeeze()

            # Local leaf copy so the caller's tensor does not get requires_grad set.
            coords = coords.detach().requires_grad_(True)

            strain_energy_density = self._strain_energy_density(coords, densities)
            loss = JEB.domain_volume * strain_energy_density.mean()

            (d_rho,) = torch.autograd.grad(loss, densities, retain_graph=False, create_graph=False, allow_unused=False)
            sensitivities = -d_rho  # match TF custom-gradient sign from NTopo implementation

        return densities.detach(), sensitivities.detach()

    @staticmethod
    def pad_constant_3d_count(nx, ny, nz, rep: int, device, dtype):
        """All-ones volume of shape (1, 1, nx, ny, nz) for the filter normalization.

        Convolving it with the kernel (zero padding) gives, per voxel, the sum of kernel
        weights that fall inside the grid. ``rep`` is accepted for signature compatibility
        and not used. Follows https://github.com/JonasZehn/ntopo/tree/main/ntopo
        """
        return torch.ones((1, 1, nx, ny, nz), device=device, dtype=dtype)

    def apply_sensitivity_filter_3d(self,
                                    coords: torch.Tensor,
                                    old_densities: torch.Tensor,
                                    sensitivities: torch.Tensor,
                                    n_samples: Tuple[int, int, int],
                                    domain: np.ndarray,
                                    radius: float) -> torch.Tensor:
        """Density-weighted sensitivity filter on a regular (nx, ny, nz) grid.

        Computes ``sum_i H_ei rho_i s_i / (max(rho_e, 1e-3) * sum_i H_ei)`` with the linear
        hat kernel ``H = max(0, radius * dx - distance)``. ``radius`` is given in cells of
        the x spacing. PyTorch port of ntopo.apply_sensitivity_filter_3d
        (https://github.com/JonasZehn/ntopo/tree/main/ntopo).

        NOTE(review): ``old_densities``/``sensitivities`` are reshaped with ``view(nx, ny, nz)``,
        which assumes the points are ordered with z varying fastest — confirm against the sampler.

        Returns:
            Filtered sensitivities of shape (N, 1).
        """
        gamma = 1e-3
        nx, ny, nz = int(n_samples[0]), int(n_samples[1]), int(n_samples[2])
        N = nx * ny * nz
        assert coords.shape[0] == N

        dx = (domain[1] - domain[0]) / nx
        dy = (domain[3] - domain[2]) / ny
        dz = (domain[5] - domain[4]) / nz

        device = coords.device
        dtype = old_densities.dtype

        fsize = 2 * round(radius) + 1
        rep = fsize // 2
        radius_space = radius * dx

        rng = torch.arange(-rep, rep + 1, device=device, dtype=dtype)
        IX, IY, IZ = torch.meshgrid(rng, rng, rng, indexing='ij')
        D = torch.sqrt((IX * dx) ** 2 + (IY * dy) ** 2 + (IZ * dz) ** 2 + 1e-35)
        K = torch.clamp(radius_space - D, min=0.0).view(1, 1, fsize, fsize, fsize)

        dens = old_densities.view(nx, ny, nz).unsqueeze(0).unsqueeze(0)
        sens = sensitivities.view(nx, ny, nz).unsqueeze(0).unsqueeze(0)

        num = F.conv3d(dens * sens, K, stride=1, padding=rep)
        ones = self.pad_constant_3d_count(nx, ny, nz, rep, device, dtype)
        sum_Hei = F.conv3d(ones, K, stride=1, padding=rep)

        rho_center = torch.clamp(old_densities.view(1, 1, nx, ny, nz), min=gamma)
        div = rho_center * sum_Hei + 1e-35

        return (num / div).view(-1, 1)  # [N,1]

    def apply_sensitivity_filter(self,
                                 coords: torch.Tensor,
                                 old_densities: torch.Tensor,
                                 sensitivities: torch.Tensor,
                                 n_samples,
                                 domain: np.ndarray,
                                 radius: float) -> torch.Tensor:
        """Dispatch to the 3-D sensitivity filter."""
        return self.apply_sensitivity_filter_3d(coords, old_densities, sensitivities, n_samples, domain, radius)

    @torch.no_grad()
    def compute_target_densities(self,
                                 old_densities_list,
                                 sensitivities_list,
                                 sample_volume: float,
                                 target_volume: float,
                                 max_move: float,
                                 damping_parameter: float):
        """Optimality-criteria density update with a bisection on the Lagrange multiplier.

        For a multiplier ``lmbd``, each target is ``rho * (s / (-dv * lmbd))^damping``. It is
        limited to ``rho ± max_move`` and to [0, 1]. ``lmbd`` is bisected in [0, 1e9] until
        ``sample_volume * mean(target)`` matches ``target_volume``. Derived from
        https://github.com/JonasZehn/ntopo/tree/main/ntopo

        Returns:
            List of target-density tensors, one per input tensor.
        """
        total = sum(odi.numel() for odi in old_densities_list)
        dv = sample_volume / float(total)

        lb_list = [torch.clamp(odi - max_move, 0.0, 1.0) for odi in old_densities_list]
        ub_list = [torch.clamp(odi + max_move, 0.0, 1.0) for odi in old_densities_list]

        def targets_for_lambda(lmbd: float):
            targets = []
            flat_all = []
            for odi, s, lb, ub in zip(old_densities_list, sensitivities_list, lb_list, ub_list):
                Bi = s / (-(dv * lmbd) + 1e-20)
                tgt = odi * torch.pow(torch.clamp(Bi, min=1e-20), damping_parameter)
                tgt = torch.maximum(lb, torch.minimum(ub, tgt))
                tgt = torch.clamp(tgt, 0.0, 1.0)
                targets.append(tgt)
                flat_all.append(tgt.reshape(-1))
            vol = sample_volume * torch.mean(torch.cat(flat_all, dim=0))
            return vol, targets

        lam_lo, lam_hi = 0.0, 1e9
        # While lam_lo stays 0 the relative gap is ~1, so the 60-step cap ends the search.
        for _ in range(60):
            if (lam_hi - lam_lo) / (lam_hi + lam_lo + 1e-12) < 1e-3:
                break
            lam_mid = 0.5 * (lam_lo + lam_hi)
            vol_mid, _ = targets_for_lambda(lam_mid)
            if vol_mid > target_volume:
                lam_lo = lam_mid
            else:
                lam_hi = lam_mid

        _, targets = targets_for_lambda(0.5 * (lam_lo + lam_hi))
        return targets

In [None]:
class model_training:
    def __init__(self,
                 u_model,
                 v_model,
                 w_model,
                 density_GINN_model,
                 SDF_GINN_model,
                 training_hparams,
                 topo_hparams,
                 constraint_hparams,
                 GINN_hparams,
                 loss_weight_hparams,
                 test_case,
                 ):
        """Set up device, models, hyper-parameters, ALM weightings, samplers and PH helpers.

        Args:
            u_model, v_model, w_model: networks for the three displacement
                components; moved to ``self.device``.
            density_GINN_model: density network; moved to ``self.device``.
            SDF_GINN_model: signed-distance network; moved to ``self.device``.
            training_hparams: dict with sampling, gradient-clipping and
                ``refine_study`` settings.
            topo_hparams: dict with topology-optimisation settings (learning
                rates, iteration counts, stress/volume controls, evaluation
                options, ``save_path``, ...).
            constraint_hparams: forwarded to ``Density_Constraints``.
            GINN_hparams: dict of geometry-loss settings; its ``*_loss_weight``
                entries also seed the ALM multipliers.
            loss_weight_hparams: dict with the ALM ``alpha``, ``gamma`` and
                ``epsilon``.
            test_case: object providing ``domain``, ``interfaces`` and ``dim``.
        """
        super().__init__()

        # ---------------- Device ----------------
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )

        # ---------------- Models ----------------
        self.u_model = u_model.to(self.device)
        self.v_model = v_model.to(self.device)
        self.w_model = w_model.to(self.device)
        self.density_GINN_model = density_GINN_model.to(self.device)
        self.SDF_GINN_model = SDF_GINN_model.to(self.device)

        # ---------------- Hparams ----------------
        self.topo_hparams = topo_hparams
        self.training_hparams = training_hparams
        self.test_case = test_case
        self.GINN_hparams = GINN_hparams
        self.num_points = training_hparams['total_sample_points']
        self.batch_size = training_hparams['batch_size']
        self.save_path = topo_hparams['save_path']
        self.volume_ratio = topo_hparams.get('volume_ratio', None)
        self.lr_PINN = topo_hparams['lr_PINN']
        self.lr_GINN = topo_hparams['lr_GINN']
        self.save_interval = topo_hparams['save_interval']
        self.filter_radius = topo_hparams['filter_radius']
        self.n_opt_iterations = topo_hparams['n_opt_iterations']
        self.n_sim_iterations = topo_hparams['n_sim_iterations']
        self.n_pre_training_iterations_PINN = topo_hparams['n_pre_training_iterations_PINN']
        self.n_pre_training_iterations_GINN = topo_hparams['n_pre_training_iterations_GINN']
        self.n_pre_training_iterations_density_GINN = topo_hparams['n_pre_training_iterations_density_GINN']
        self.n_opt_batches = topo_hparams['n_opt_batches']
        self.seed = topo_hparams['seed']
        self.refine_study = training_hparams["refine_study"]
        self.num_points_envelope_loss = GINN_hparams['num_points_envelope_loss']
        self.num_points_connectivity_loss = GINN_hparams['num_points_connectivity_loss']
        self.envelope_extension_factor = GINN_hparams['envelope_extension_factor']
        # Density iso-level used for plotting / solid classification.
        # 'rho_threshold' is the correctly spelled key; the legacy misspelled
        # 'rho_treshold' is still honoured so existing configs keep working.
        self.plot_threshold = float(
            topo_hparams.get('rho_threshold', topo_hparams.get('rho_treshold', 0.25))
        )
        self.n_eval_points_plot = int(topo_hparams.get('n_eval_points_plot', 400_000))
        self.max_batch_plot = int(topo_hparams.get('max_batch_plot', 120_000))
        self.csv_ginn_loss_path = os.path.join(self.save_path, "ginn_losses.csv")
        self.csv_topo_loss_path = os.path.join(self.save_path, "topo_density_losses.csv")
        self.csv_opt_metrics_path = os.path.join(self.save_path, "topo_metrics.csv")
        self.csv_opt_loss_path = os.path.join(self.save_path, "optimization_losses.csv")
        self.csv_eval_metrics_path = os.path.join(self.save_path, "evaluation_metrics.csv")
        self.eval_metrics_enable = bool(self.topo_hparams.get("eval_metrics_enable", True))
        self.eval_grid_n = int(self.topo_hparams.get("eval_metrics_grid_n", 64))
        self.eval_batch = int(self.topo_hparams.get("eval_metrics_batch", 200_000))
        self.eval_boundary_n = int(self.topo_hparams.get("eval_metrics_boundary_n", 12_000))
        self.eval_interface_n = int(self.topo_hparams.get("eval_metrics_interface_n", 8_000))
        self.eval_max_boundary_pts = int(self.topo_hparams.get("eval_metrics_max_boundary_pts", 20_000))
        self.eval_cdist_chunk = int(self.topo_hparams.get("eval_metrics_cdist_chunk", 4_096))

        # Built from the ``test_case`` argument (previously the global ``JEB``)
        # so the constraints match the geometry this instance is trained on,
        # like every sampler / PH helper below.
        self.enforce_density = Density_Constraints(
            constraint_hparams,
            test_case,
            test_case.interfaces.is_inside_interface_thickness,
            test_case.interfaces.is_inside_prohibited_region
        )

        # ---------------- ALM setup for first optimization stage ----------------
        self.scalar_keys = [
            'Objective',
            'Envelope Loss',
            'Connectivity Loss',
            'Prescribed Normals Loss',
            'Prescribed Thickness Loss',
            'Prohibited Region Loss',
            'Smoothness Loss',
            'Holes Loss',
        ]

        lambda_init_geom = {
            'Objective': 1.0,
            'Envelope Loss':             GINN_hparams.get('envelope_loss_weight', 1.0),
            'Connectivity Loss':         GINN_hparams.get('connectivity_loss_weight', 100.0),
            'Prescribed Normals Loss':   GINN_hparams.get('prescribed_normals_loss_weight', 1.0),
            'Prescribed Thickness Loss': GINN_hparams.get('prescribed_thickness_loss_weight', 1.0),
            'Prohibited Region Loss':    GINN_hparams.get('prohibited_region_loss_weight', 1.0),
            'Smoothness Loss':           GINN_hparams.get('smoothness_loss_weight', 1e-4),
            'Holes Loss':                GINN_hparams.get('holes_loss_weight', 100.0),
        }

        self.alm = ALM(
            loss_keys=self.scalar_keys,
            objective_key='Objective',
            lambda_dict=lambda_init_geom,
            alpha=loss_weight_hparams['alpha'],
            gamma=loss_weight_hparams['gamma'],
            epsilon=loss_weight_hparams['epsilon'],
            device=self.device
        )

        # ---------------- ALM setup for second optimization stage  ----------------
        self.scalar_keys_rho = [
            'Topo Objective',
            'Density Envelope Loss',
            'Density Connectivity Loss'
        ]

        lambda_init_rho = {
            'Topo Objective':            1.0,
            'Density Envelope Loss':     GINN_hparams.get('envelope_loss_weight_density', 1.0),
            'Density Connectivity Loss': GINN_hparams.get('connectivity_loss_weight_density', 100.0),
        }

        self.alm_rho = ALM(
            loss_keys=self.scalar_keys_rho,
            objective_key='Topo Objective',
            lambda_dict=lambda_init_rho,
            alpha=loss_weight_hparams['alpha'],
            gamma=loss_weight_hparams['gamma'],
            epsilon=loss_weight_hparams['epsilon'],
            device=self.device
        )

        # ---------------- Gradient clipping ----------------
        self.grad_clipper = AutoClip(
            grad_clipping_on=training_hparams["grad_clipping_on"],
            grad_clip=training_hparams["grad_clip"],
            auto_clip_on=training_hparams["auto_clip_on"],
            auto_clip_percentile=training_hparams["auto_clip_percentile"],
            auto_clip_hist_len=training_hparams["auto_clip_hist_len"],
            auto_clip_min_len=training_hparams["auto_clip_min_len"],
        )

        # ---------------- Samplers ----------------
        self.point_sampler = Point_Sampler(
            test_case.domain,
            test_case.interfaces,
            num_points_domain=self.num_points,
            num_points_interface=0
        )

        # PH setup (SDF field, iso-level 0)
        inside_env_sdf = make_inside_envelope_fn_from_domain(
            domain=self.test_case.domain, nx=self.test_case.dim, device=self.device
        )
        self.PH = PH(
            nx=self.test_case.dim,
            bounds=self.test_case.domain,
            model=self.SDF_GINN_model,
            n_grid_points=96,
            iso_level=0.0,
            target_betti=[1, 0, 0],
            maxdim=1,
            is_density=False,
            inside_envelope_fn=inside_env_sdf,
            group_size_fwd_no_grad=32768,
            add_frame=True,
            hole_level=0.06,
            test_case=self.test_case
        )

        # boundary sampler
        interface_points = self.test_case.interfaces.sample_points_from_all_interfaces(
            20000, output_type='torch_tensor'
        )
        self.boundary_sampler = Boundary_Sampler(
            dim=self.test_case.dim,
            bounds=self.test_case.domain,
            model=self.SDF_GINN_model,
            x_interface=interface_points,
            n_points_surface=20000,
            interface_cutoff=0.05
        )

        # PH setup density (iso-level = plot threshold)
        inside_env_density = make_inside_envelope_fn_from_domain(
            domain=self.test_case.domain, nx=self.test_case.dim, device=self.device
        )
        self.PH_density = PH(
            nx=self.test_case.dim,
            bounds=self.test_case.domain,
            model=self.density_GINN_model,
            n_grid_points=96,
            iso_level=float(self.plot_threshold),
            target_betti=[1, 0, 0],
            maxdim=1,
            is_density=True,
            inside_envelope_fn=inside_env_density,
            group_size_fwd_no_grad=32768,
            add_frame=True,
            hole_level=0.06,
            test_case=self.test_case
        )

        # ---------------- Stress / volume control ----------------
        self.sigma_allow = float(topo_hparams.get('sigma_allow', 2000.0))
        self.stress_tol = float(topo_hparams.get('stress_tol', 0.02))
        self.ks_rho = float(topo_hparams.get('ks_rho', 80.0))
        self.ks_size_correction = bool(topo_hparams.get('ks_size_correction', False))
        self.stress_metric = str(topo_hparams.get('stress_metric', 'KS'))
        self.stress_percentile = float(topo_hparams.get('stress_percentile', 0.995))
        self.vol_frac_min = float(topo_hparams.get('vol_frac_min', 0.02))
        self.vol_frac_max = float(topo_hparams.get('vol_frac_max', 0.98))
        self.vol_step_min = float(topo_hparams.get('vol_step_min', 0.005))
        self.vol_step_max = float(topo_hparams.get('vol_step_max', 0.05))
        self.alpha_decrease = float(topo_hparams.get('alpha_decrease', 0.50))
        self.beta_increase = float(topo_hparams.get('beta_increase', 0.25))

        self.sigma_ema = None
        self.sigma_ema_alpha = 0.2

        if getattr(self.density_GINN_model, "volume_ratio", None) is None:
            self.density_GINN_model.volume_ratio = 0.5

        # ---------------- History buffers for curves / logs ----------------
        self.hist_iters = []
        self.hist_compliance = []
        self.hist_volume = []
        self.hist_sigma_metric = []
        self.hist_sigma_max = []
        self.ginn_hist_iters = []
        self.ginn_hist_eik = []
        self.ginn_hist_env = []
        self.ginn_hist_connect = []
        self.ginn_hist_hole = []
        self.ginn_hist_int = []
        self.ginn_hist_norm = []
        self.ginn_hist_thick = []
        self.ginn_hist_prohib = []
        self.ginn_hist_smooth = []
        self.ginn_hist_total = []
        self.topo_hist_iters = []
        self.topo_hist_topo = []
        self.topo_hist_env_rho = []
        self.topo_hist_conn_rho = []
        self.topo_hist_total = []
        self.topo_step_counter = 0




    # ----- helpers -----------------------

    @staticmethod
    def loss_ramp_up(epoch: int, start_epoch: int, ramp_epochs: int) -> float:
        if epoch < start_epoch:
            return 0.0
        return float(min(1.0, (epoch - start_epoch + 1) / max(1, ramp_epochs)))

    # Helper functions defined outside this class (presumably via the star
    # imports of the Functions.* modules at the top of the notebook) attached as
    # class attributes, so they are called as bound methods, e.g.
    # ``self._log_topo_density_losses_csv(...)``.

    # evaluation metrics
    _compute_eval_metrics_3d = compute_eval_metrics_3d
    _maybe_log_eval_metrics = maybe_log_eval_metrics
    _log_eval_metrics_csv = log_eval_metrics_csv

    # CSV loss / metric logging
    _log_ginn_losses_csv = log_ginn_losses_csv
    _log_topo_density_losses_csv = log_topo_density_losses_csv
    _log_optimization_metrics_csv = log_optimization_metrics_csv
    _log_optimization_losses_csv = log_optimization_losses_csv

    # training-curve figures
    _save_training_curves = save_training_curves
    _save_ginn_training_curves = save_ginn_training_curves
    _save_topology_density_curves = save_topology_density_curves


    def _compute_losses(self, coords, epoch):
        """Weighted SDF-GINN geometry losses for one mini-batch, keyed for ``self.alm``.

        Args:
            coords: sample points; the caller passes a tensor with
                ``requires_grad`` so point-wise derivative losses can be formed.
            epoch: current pre-training iteration; drives the loss ramps.

        Returns:
            Dict mapping each key of ``self.scalar_keys`` to a ``relu``-clamped
            scalar. ``'Objective'`` is the eikonal plus interface loss.

        The persistent-homology terms (connectivity, holes) are ramped in
        linearly from epoch 0 over ``GINN_hparams['ph_ramp_epochs']`` iterations
        (default 550, the previously hard-coded value). The smoothness term is
        ramped over ``curv_ramp_epochs`` starting at ``curv_start_epoch``.

        NOTE(review): the same ``*_loss_weight`` entries of ``GINN_hparams`` are
        also used as the initial ALM multipliers in ``__init__``. Whether that
        effectively double-weights a term depends on ``ALM.build`` — confirm.
        """
        geometry = GINN_Losses(
            self.SDF_GINN_model,
            self.test_case,
            self.GINN_hparams,
            self.PH,
            self.boundary_sampler,
            self.enforce_density
        )

        w = self.GINN_hparams
        ph_ramp_epochs = int(w.get('ph_ramp_epochs', 550))
        ph_ramp = self.loss_ramp_up(epoch, start_epoch=0, ramp_epochs=ph_ramp_epochs)

        # weighted components
        eik     = geometry.eikonal_loss(coords)              * w['eikonal_loss_weight']
        intr    = geometry.interface_loss()                  * w['interface_loss_weight']
        env     = geometry.design_envelope_loss()            * w['envelope_loss_weight']
        conn    = geometry.connectivity_loss()               * (w['connectivity_loss_weight'] * ph_ramp)
        norm    = geometry.surface_normal_loss()             * w['prescribed_normals_loss_weight']
        thick   = geometry.prescribed_thickness_loss(coords) * w['prescribed_thickness_loss_weight']
        prohib  = geometry.prohibited_region_loss(coords)    * w['prohibited_region_loss_weight']

        smooth0 = geometry.smoothness_loss(epoch)            * w['smoothness_loss_weight']
        smooth  = self.loss_ramp_up(epoch, w['curv_start_epoch'], w['curv_ramp_epochs']) * smooth0

        holes   = geometry.holes_loss()                      * (w['holes_loss_weight'] * ph_ramp)

        objective = eik + intr

        return {
            'Objective':                 torch.relu(objective),
            'Envelope Loss':             torch.relu(env),
            'Connectivity Loss':         torch.relu(conn),
            'Prescribed Normals Loss':   torch.relu(norm),
            'Prescribed Thickness Loss': torch.relu(thick),
            'Prohibited Region Loss':    torch.relu(prohib),
            'Smoothness Loss':           torch.relu(smooth),
            'Holes Loss':                torch.relu(holes),
        }

    def _compute_density_topology_ALM_losses(self, coords, target_densities):
        """Density-stage losses for ``self.alm_rho``.

        Moves ``coords`` and ``target_densities`` to ``self.device`` and
        evaluates three losses:
        - the MSE between the density network and the targets;
        - the density envelope loss at ``self.plot_threshold``;
        - the density connectivity loss at ``self.plot_threshold``.

        Returns:
            A 4-tuple ``(losses_rho, topo_obj, env_rho, conn_rho)``:
            ``losses_rho`` maps the ``self.scalar_keys_rho`` names to
            ``relu``-clamped losses that still carry their autograd graphs;
            the other three are detached copies of the raw losses.
        """
        pts = coords.to(self.device)
        targets = target_densities.to(self.device)
        model = self.density_GINN_model
        iso = float(self.plot_threshold)

        # Fit of the density field to the target densities.
        topo_obj = F.mse_loss(model(pts), targets)

        # Reset the density PH helper's cache before the PH-based losses.
        self.PH_density.invalidate_cache()
        geometry_rho = GINN_Losses(
            model,
            self.test_case,
            self.GINN_hparams,
            self.PH_density,
            self.boundary_sampler,
            self.enforce_density
        )
        env_rho = geometry_rho.design_envelope_loss_from_density(density_model=model, iso_level=iso)
        conn_rho = geometry_rho.connectivity_loss_from_density(density_model=model, iso_level=iso)

        raw_losses = (topo_obj, env_rho, conn_rho)
        loss_names = ('Topo Objective', 'Density Envelope Loss', 'Density Connectivity Loss')
        losses_rho = {name: torch.relu(value) for name, value in zip(loss_names, raw_losses)}
        return (losses_rho,) + tuple(value.detach() for value in raw_losses)


    @staticmethod
    def _safe_clip_ratio(v, lo, hi, eps=1e-3):
        lo = max(lo, eps)
        hi = min(hi, 1.0 - eps)
        return float(np.clip(v, lo, hi))


    def _estimate_compliance_internal_energy_3d(self, n_opt_cells):
        """Estimate compliance as the SIMP-penalised internal strain energy on a grid.

        Evaluates the density (after ``enforce_density.apply``) and the three
        displacement networks at the cell centres of ``JEB.domain``, with
        ``n_opt_cells = (nx, ny, nz)``. From these it forms the small-strain
        tensor and the isotropic stress, using Lamé parameters scaled by
        ``rho ** training_hparams['density_exponent']``.

        Args:
            n_opt_cells: ``(nx, ny, nz)`` grid resolution.

        Returns:
            ``JEB.domain_volume * mean(strain energy density)`` as a Python float.

        The train/eval mode of every model touched is restored on exit
        (previously all four were left in eval mode). Gradients are enabled
        locally, so the method also works when called inside ``torch.no_grad()``.
        """
        models = (self.u_model, self.v_model, self.w_model, self.density_GINN_model)
        prev_modes = [m.training for m in models]
        for m in models:
            m.eval()
        try:
            # autograd.grad below needs a graph even if the caller disabled grad.
            with torch.enable_grad():
                nx, ny, nz = n_opt_cells
                xs = get_grid_centers(JEB.domain, [nx, ny, nz]).astype(np.float32)
                xt = torch.tensor(xs, dtype=torch.float32, device=self.device, requires_grad=True)

                rho = self.density_GINN_model(xt)
                rho, _ = self.enforce_density.apply(xt, rho, None, n_opt_cells, JEB.domain)
                r = rho.view(-1).clamp(0.0, 1.0)

                u = self.u_model(xt)
                v = self.v_model(xt)
                w = self.w_model(xt)

                gu = torch.autograd.grad(u, xt, grad_outputs=torch.ones_like(u), create_graph=False, retain_graph=True)[0]
                gv = torch.autograd.grad(v, xt, grad_outputs=torch.ones_like(v), create_graph=False, retain_graph=True)[0]
                gw = torch.autograd.grad(w, xt, grad_outputs=torch.ones_like(w), create_graph=False, retain_graph=False)[0]

                # small-strain tensor (symmetric part of the displacement gradient)
                e11 = gu[:, 0]
                e22 = gv[:, 1]
                e33 = gw[:, 2]
                e12 = 0.5 * (gu[:, 1] + gv[:, 0])
                e13 = 0.5 * (gu[:, 2] + gw[:, 0])
                e23 = 0.5 * (gv[:, 2] + gw[:, 1])
                tr = e11 + e22 + e33

                # SIMP: Lamé parameters scaled by rho ** p. Properties is built
                # once, and its tensors are moved to the model device as the
                # stress-measurement methods do.
                p = float(self.training_hparams['density_exponent'])
                props = Properties(JEB)
                penal = r.pow(p)
                lam = penal * props.lame_lambda.to(self.device)
                mu = penal * props.lame_mu.to(self.device)

                s11 = 2.0 * mu * e11 + lam * tr
                s22 = 2.0 * mu * e22 + lam * tr
                s33 = 2.0 * mu * e33 + lam * tr
                s12 = 2.0 * mu * e12
                s13 = 2.0 * mu * e13
                s23 = 2.0 * mu * e23

                sed = 0.5 * (s11 * e11 + s22 * e22 + s33 * e33 + 2.0 * s12 * e12 + 2.0 * s13 * e13 + 2.0 * s23 * e23)
                internal_energy = float(JEB.domain_volume) * sed.mean()
                return float(internal_energy.detach().cpu().item())
        finally:
            for m, was_training in zip(models, prev_modes):
                m.train(was_training)

    # ---- stress/volume measurement ------
    def _von_mises_chunk_3d(self, x, lam, mu, scale):
        """Von Mises stress (detached) at leaf points ``x`` (``requires_grad``) from the displacement networks."""
        u = self.u_model(x) * scale
        v = self.v_model(x) * scale
        w = self.w_model(x) * scale

        gu = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=False, retain_graph=False)[0]
        gv = torch.autograd.grad(v, x, grad_outputs=torch.ones_like(v), create_graph=False, retain_graph=False)[0]
        gw = torch.autograd.grad(w, x, grad_outputs=torch.ones_like(w), create_graph=False, retain_graph=False)[0]

        # small-strain tensor
        e11 = gu[:, 0]
        e22 = gv[:, 1]
        e33 = gw[:, 2]
        e12 = 0.5 * (gu[:, 1] + gv[:, 0])
        e13 = 0.5 * (gu[:, 2] + gw[:, 0])
        e23 = 0.5 * (gv[:, 2] + gw[:, 1])
        tr = e11 + e22 + e33

        # isotropic linear-elastic stress, unpenalised (binary geometry)
        s11 = 2 * mu * e11 + lam * tr
        s22 = 2 * mu * e22 + lam * tr
        s33 = 2 * mu * e33 + lam * tr
        s12 = 2 * mu * e12
        s13 = 2 * mu * e13
        s23 = 2 * mu * e23

        vm2 = 0.5 * ((s11 - s22) ** 2 + (s22 - s33) ** 2 + (s33 - s11) ** 2) + 3 * (s12 ** 2 + s13 ** 2 + s23 ** 2)
        return (torch.sqrt(torch.clamp(vm2, min=1e-32)) * scale).detach()

    def _solid_von_mises_3d(self, fallback_quantile=None):
        """Sample the domain, threshold the density, and return von Mises stresses at solid points.

        Args:
            fallback_quantile: if given and no sample reaches
                ``self.plot_threshold``, points with
                ``rho >= quantile(rho, fallback_quantile)`` are treated as solid.

        Returns:
            ``(s, vol_abs)``. ``s`` is a float64 NumPy array of von Mises
            stresses at the solid points; it is empty when there are none.
            ``vol_abs`` is ``JEB.domain_volume`` times the solid fraction of
            the samples.

        Gradients are evaluated in chunks of ``self.max_batch_plot`` points to
        bound peak memory. The displacement networks' train/eval modes are
        restored afterwards, as is the density network's.
        """
        ps = Point_Sampler(
            JEB.domain, JEB.interfaces,
            num_points_domain=self.n_eval_points_plot, num_points_interface=0
        )
        pts_full = next(ps).to(self.device)

        prev = self.density_GINN_model.training
        self.density_GINN_model.eval()
        with torch.no_grad():
            rho = self.density_GINN_model(pts_full).view(-1)
        if prev:
            self.density_GINN_model.train()

        solid_mask = (rho >= self.plot_threshold)
        if fallback_quantile is not None and not torch.any(solid_mask):
            solid_mask = (rho >= torch.quantile(rho, fallback_quantile))

        n_total = pts_full.shape[0]
        n_solid = int(solid_mask.sum().item())
        vol_abs = float(JEB.domain_volume) * (n_solid / max(1, n_total))
        if n_solid == 0:
            return np.zeros(0, dtype=np.float64), vol_abs

        x_solid = pts_full[solid_mask]
        chunk = int(getattr(self, "max_batch_plot", 0) or 0)
        if chunk <= 0:
            chunk = n_solid

        props = Properties(JEB)
        lam = props.lame_lambda.to(self.device)
        mu = props.lame_mu.to(self.device)
        scale = JEB.domain_scaling_factor

        disp_models = (self.u_model, self.v_model, self.w_model)
        prev_modes = [m.training for m in disp_models]
        pieces = []
        try:
            for m in disp_models:
                m.eval()
            with torch.enable_grad():
                for start in range(0, n_solid, chunk):
                    x = x_solid[start:start + chunk].clone().detach().requires_grad_(True)
                    pieces.append(self._von_mises_chunk_3d(x, lam, mu, scale))
        finally:
            for m, was_training in zip(disp_models, prev_modes):
                m.train(was_training)

        s = torch.cat(pieces).cpu().numpy().astype(np.float64)
        return s, vol_abs

    def _measure_stress_volume_KS_3d(self):
        """
        Binary-geometry evaluation with a Kreisselmeier-Steinhauser stress aggregate.

        Solid points are the samples with ``rho >= self.plot_threshold``.

        Returns:
            ``(sigma_metric, sigma_max, vol_abs)``:
            - ``sigma_metric``: ``sigma_allow * (1 + KS(s / sigma_allow - 1))``,
              with aggregation parameter ``self.ks_rho``. When
              ``self.ks_size_correction`` is set, ``log(N) / ks_rho`` is
              subtracted.
            - ``sigma_max``: the maximum von Mises stress.
            - ``vol_abs``: the solid volume.
            The two stress values are 0.0 when no point is solid.
        """
        s, vol_abs = self._solid_von_mises_3d()
        if s.size == 0:
            return 0.0, 0.0, vol_abs

        g = s / float(self.sigma_allow) - 1.0
        gmax = float(np.max(g))
        # numerically stable log-sum-exp
        lse = gmax + (1.0 / self.ks_rho) * float(np.log(np.sum(np.exp(self.ks_rho * (g - gmax))) + 1e-300))
        if self.ks_size_correction:
            lse -= (1.0 / self.ks_rho) * float(np.log(s.size))

        sigma_metric = float(self.sigma_allow) * (1.0 + lse)
        sigma_max = float(np.max(s))
        return sigma_metric, sigma_max, vol_abs

    def _measure_stress_volume_percentile_3d(self):
        """Binary-geometry evaluation using a stress percentile.

        Solid points are the samples with ``rho >= self.plot_threshold``. If
        there are none, the samples at or above the 90th percentile of ``rho``
        are used instead.

        Returns:
            ``(sigma_metric, sigma_max, vol_abs)``. ``sigma_metric`` is the
            ``self.stress_percentile`` quantile (clipped to [0, 1]) of the von
            Mises stress. The two stress values are 0.0 when no point is solid.
        """
        s, vol_abs = self._solid_von_mises_3d(fallback_quantile=0.9)
        if s.size == 0:
            return 0.0, 0.0, vol_abs

        q = float(np.clip(self.stress_percentile, 0.0, 1.0))
        sigma_metric = float(np.quantile(s, q))
        sigma_max = float(np.max(s))
        return sigma_metric, sigma_max, vol_abs

    def _measure_stress_volume_sigma_max_3d(self):
        """Like the percentile variant, but reports the maximum stress as the metric."""
        _, sigma_max, vol_abs = self._measure_stress_volume_percentile_3d()
        sigma_metric = sigma_max
        return sigma_metric, sigma_max, vol_abs

    def _measure_stress_volume_3d(self):
        """Dispatch to the stress measurement selected by ``self.stress_metric`` (case-insensitive)."""
        metric = str(getattr(self, "stress_metric", "ks")).lower()
        if metric == "ks":
            return self._measure_stress_volume_KS_3d()
        elif metric == "percentile":
            return self._measure_stress_volume_percentile_3d()
        elif metric == "sigma_max":
            return self._measure_stress_volume_sigma_max_3d()
        else:
            raise ValueError(
                f"Unknown stress metric: {metric}. "
                f"Choose between 'ks', 'percentile', and 'sigma_max'."
            )

    # ----------------------------------------------------------------------
    # training steps
    # ----------------------------------------------------------------------

    def shape_generation_step(self, optimizer, epoch):
        """One SDF-GINN pre-training epoch, or a checkpoint restore in refine mode.

        Refine mode (``self.refine_study == True``):
            The SDF network weights are loaded on the first call from
            ``training_hparams['refine_checkpoint_path']``. The default is the
            previously hard-coded Drive path. Later calls do nothing, because
            this method never modifies the SDF model in refine mode; the old
            code re-read the checkpoint on every epoch.

        Otherwise:
            A fresh batch is drawn from ``self.point_sampler`` and split into
            shuffled mini-batches of ``self.batch_size``. Each mini-batch gets
            one optimizer step on the ALM-combined geometry losses, with
            optional gradient-norm clipping, followed by an ALM update.

        Args:
            optimizer: optimizer over ``self.SDF_GINN_model`` parameters (unused
                in refine mode).
            epoch: current pre-training iteration, forwarded to ``_compute_losses``.
        """
        if self.refine_study == True:  # initialize from pretrained model
            if not getattr(self, "_refine_checkpoint_loaded", False):
                model_file = self.training_hparams.get(
                    "refine_checkpoint_path",
                    "/content/gdrive/Othercomputers/Code/JEB_topo/model-000150.pt",
                )
                # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                checkpoint = torch.load(model_file, map_location=self.device)
                self.SDF_GINN_model.load_state_dict(checkpoint['model_state_dict'])
                self._refine_checkpoint_loaded = True
            return

        # normal training setup
        batch_points = next(self.point_sampler)
        loader = DataLoader(TensorDataset(batch_points), batch_size=self.batch_size, shuffle=True)

        for (pts,) in loader:
            optimizer.zero_grad(set_to_none=True)
            points = pts.to(self.device)
            # the SDF changed after the previous step, so drop cached PH results
            self.PH.invalidate_cache()

            losses = self._compute_losses(points.clone().detach().requires_grad_(True), epoch)

            total_loss = self.alm.build(losses)
            total_loss.backward()

            if getattr(self, "grad_clipper", None) is not None and self.grad_clipper.grad_clip_enabled:
                # Global grad norm over finite gradients only, fed to the auto-clipper history.
                total_sq = 0.0
                for p in self.SDF_GINN_model.parameters():
                    if p.grad is not None:
                        g = p.grad.data
                        if torch.isfinite(g).all():
                            total_sq += g.float().pow(2).sum().item()
                grad_norm = math.sqrt(total_sq)
                self.grad_clipper.update_gradient_norm_history(grad_norm)
                clip_val = self.grad_clipper.get_clip_value()
                if np.isfinite(clip_val):
                    torch.nn.utils.clip_grad_norm_(self.SDF_GINN_model.parameters(), clip_val)

            optimizer.step()
            self.alm.update(losses)

    def density_PINN_initialization_step(self, coords, density_GINN_optimizer):
        """One optimizer step fitting the density network to an SDF-derived initial density.

        The target is ``clamp(sigmoid(-1000 * sdf), 0, 0.5)``: close to 0.5
        where the SDF network is negative and close to 0 where it is positive.
        The SDF network only provides the target, so it is evaluated under
        ``torch.no_grad()``. The old code built a full autograd graph and
        then detached it.

        Args:
            coords: sample points.
            density_GINN_optimizer: optimizer over the density network parameters.
        """
        self.density_GINN_model.train()
        self.SDF_GINN_model.eval()

        with torch.no_grad():
            sdf = self.SDF_GINN_model(coords)
            topo_initial_density = torch.clamp(torch.sigmoid(-1000 * sdf), 0.0, 0.5)

        rho = self.density_GINN_model(coords)
        initial_density_loss = F.mse_loss(rho, topo_initial_density)

        density_GINN_optimizer.zero_grad(set_to_none=True)
        initial_density_loss.backward()
        density_GINN_optimizer.step()

    def PINN_update_step(self, coords, PINN_optimizer, n_opt_samples):
        """One optimizer step on the displacement networks using ``PINN_losses.ritz_loss``.

        Puts the u/v/w networks in train mode and the density network in eval
        mode. Then it minimises the Ritz loss evaluated at ``coords``.

        Args:
            coords: collocation points passed to ``ritz_loss``.
            PINN_optimizer: optimizer over the u/v/w network parameters.
            n_opt_samples: forwarded to ``PINN_losses``.

        Returns:
            None.
        """
        for displacement_net in (self.u_model, self.v_model, self.w_model):
            displacement_net.train()
        self.density_GINN_model.eval()

        ritz = PINN_losses(
            self.u_model,
            self.v_model,
            self.w_model,
            self.density_GINN_model,
            n_opt_samples,
            self.training_hparams,
            self.enforce_density,
        ).ritz_loss

        PINN_optimizer.zero_grad(set_to_none=True)
        ritz(coords).backward()
        PINN_optimizer.step()
        return None



    def topology_optimization_step(self, density_GINN_optimizer, coords, target_densities):
        """One optimizer step fitting the density network to the target densities.

        The loss depends on ``GINN_hparams['envelope_loss_weight_density']``.
        If it is 0, only the MSE objective is minimised. Otherwise the ALM
        combination ``self.alm_rho.build`` of all density losses is used. A
        missing key counts as 1.0, the same default ``__init__`` uses; the old
        code raised ``KeyError``.

        The raw (unweighted) loss values are appended to the ``topo_hist_*``
        lists and written with ``_log_topo_density_losses_csv``. Optional
        gradient-norm clipping is applied before the step, and the ALM
        multipliers are updated afterwards.

        Args:
            density_GINN_optimizer: optimizer over the density network parameters.
            coords: sample points.
            target_densities: target densities at ``coords``.

        Returns:
            None.
        """
        self.density_GINN_model.train()
        density_GINN_optimizer.zero_grad(set_to_none=True)

        losses_rho, topo_obj_raw, env_raw, conn_raw = self._compute_density_topology_ALM_losses(
            coords, target_densities
        )

        if self.GINN_hparams.get('envelope_loss_weight_density', 1.0) == 0:
            # Plain fit to the targets. Reuses the forward pass already done for
            # the objective (relu(mse) == mse) instead of re-running the network
            # on coords / targets that were never moved to self.device.
            total_loss = losses_rho['Topo Objective']
        else:
            total_loss = self.alm_rho.build(losses_rho)

        self.topo_step_counter += 1
        self.topo_hist_iters.append(self.topo_step_counter)

        topo_val = float(topo_obj_raw.item())
        env_val = float(env_raw.item())
        conn_val = float(conn_raw.item())

        self.topo_hist_topo.append(topo_val)
        self.topo_hist_env_rho.append(env_val)
        self.topo_hist_conn_rho.append(conn_val)

        total_raw = topo_val + env_val + conn_val
        self.topo_hist_total.append(total_raw)

        self._log_topo_density_losses_csv(
            step=self.topo_step_counter,
            topo_obj=topo_val,
            env_rho=env_val,
            conn_rho=conn_val,
            total_raw=total_raw
        )

        total_loss.backward()

        if getattr(self, "grad_clipper", None) is not None and self.grad_clipper.grad_clip_enabled:
            # Global grad norm over finite gradients only, fed to the auto-clipper history.
            total_sq = 0.0
            for p in self.density_GINN_model.parameters():
                if p.grad is not None:
                    g = p.grad.data
                    if torch.isfinite(g).all():
                        total_sq += g.float().pow(2).sum().item()
            grad_norm = math.sqrt(total_sq)
            self.grad_clipper.update_gradient_norm_history(grad_norm)
            clip_val = self.grad_clipper.get_clip_value()
            if np.isfinite(clip_val):
                torch.nn.utils.clip_grad_norm_(self.density_GINN_model.parameters(), clip_val)

        density_GINN_optimizer.step()

        self.alm_rho.update(losses_rho)
        return None

    # ----------------------------------------------------------------------
    # main driver
    # ----------------------------------------------------------------------
    def generate_geometry(self):

        set_random_seed(self.seed)
        os.makedirs(self.save_path, exist_ok=True)

        DEFAULT_N_SIM_SAMPLES_2D = 300000
        DEFAULT_N_OPT_SAMPLES_2D = 300000

        n_sim_samples = get_default_sample_counts(JEB.domain, DEFAULT_N_SIM_SAMPLES_2D)
        n_opt_samples = get_default_sample_counts(JEB.domain, DEFAULT_N_OPT_SAMPLES_2D)

        PINN_model_params_list = (
            list(self.u_model.parameters()) +
            list(self.v_model.parameters()) +
            list(self.w_model.parameters())
        )
        PINN_optimizer = torch.optim.Adam(PINN_model_params_list, lr=self.lr_PINN, betas=(0.9, 0.99))
        density_GINN_optimizer = torch.optim.Adam(self.density_GINN_model.parameters(), lr=self.lr_GINN, betas=(0.8, 0.9))
        SDF_GINN_optimizer = torch.optim.Adam(self.SDF_GINN_model.parameters(), lr=self.lr_GINN)

        PINN_point_sampler = gen_samples(JEB.domain, n_sim_samples)
        GINN_point_sampler = gen_samples(JEB.domain, n_opt_samples)

        # ===================== (1) SDF-GINN pre-training =====================
        print("=== Generating Geometry (SDF pre-training) ===")

        ginn_log_interval = int(self.topo_hparams.get("ginn_log_interval", 50))

        for it in tqdm(range(self.n_pre_training_iterations_GINN), desc="GINN Training"):
            _ = next(GINN_point_sampler)  
            self.shape_generation_step(optimizer=SDF_GINN_optimizer, epoch=it)


            if (it % ginn_log_interval) == 0:
                try:
                    point_sampler = Point_Sampler(JEB.domain, JEB.interfaces, num_points_domain=400000, num_points_interface=0)
                    coords_dbg = next(point_sampler).to(self.device)
                    self.SDF_GINN_model.eval()
                    with torch.no_grad():
                        SDF = self.SDF_GINN_model(coords_dbg).squeeze(-1)

                    coords_cpu = coords_dbg.detach().cpu()
                    SDF_cpu = SDF.detach().cpu()

                    tolerance = 1e-6
                    mask = (SDF_cpu <= tolerance)
                    if mask.sum().item() == 0:
                        idx = torch.argmin(torch.abs(SDF_cpu))
                        mask[idx] = True

                    pts_in = coords_cpu[mask]
                    sdf_in = SDF_cpu[mask]

                    scatter_in = go.Scatter3d(
                        x=pts_in[:, 0].numpy(),
                        y=pts_in[:, 1].numpy(),
                        z=pts_in[:, 2].numpy(),
                        mode='markers',
                        marker=dict(
                            size=2,
                            color=sdf_in.numpy(),
                            colorscale='Reds',
                            cmin=float(SDF_cpu.min()),
                            cmax=0.0,
                            opacity=1.0
                        ),
                        name='SDF ≤ 0'
                    )

                    fig = go.Figure(data=[scatter_in])
                    fig.update_layout(
                        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z', aspectmode='data'),
                        margin=dict(l=0, r=0, b=0, t=30),
                        title="Inside points by SDF (SDF ≤ 0)"
                    )
                    fig.show()
                except Exception:
                    pass

                save_sdf_outputs_3d(
                    self.SDF_GINN_model,
                    JEB.domain,
                    n_opt_samples,
                    self.save_path,
                    save_prefix='',
                    save_postfix=f'-ginn-{it:06d}',
                    iso_level=0.0,
                    device=self.device,
                    enforce_density=None,
                    tag='sdf-raw'
                )
                save_sdf_outputs_3d(
                    self.SDF_GINN_model,
                    JEB.domain,
                    n_opt_samples,
                    self.save_path,
                    save_prefix='',
                    save_postfix=f'-ginn-{it:06d}',
                    iso_level=0.0,
                    device=self.device,
                    enforce_density=self.enforce_density,
                    tag='sdf-constrained'
                )

                geometry_losses = GINN_Losses(
                    self.SDF_GINN_model,
                    self.test_case,
                    self.GINN_hparams,
                    self.PH,
                    self.boundary_sampler,
                    self.enforce_density
                )

                coords_eval = next(self.point_sampler).to(self.device)

                eik_loss = geometry_losses.eikonal_loss(coords_eval)
                env_loss = geometry_losses.design_envelope_loss()
                connect_loss = geometry_losses.connectivity_loss()
                hole_loss = geometry_losses.holes_loss()
                int_loss = geometry_losses.interface_loss()
                norm_loss = geometry_losses.surface_normal_loss()
                thick_loss = geometry_losses.prescribed_thickness_loss(coords_eval)
                prohib_loss = geometry_losses.prohibited_region_loss(coords_eval)
                smooth_raw = geometry_losses.smoothness_loss(it)

                smooth_loss_scaled = self.loss_ramp_up(
                    it,
                    self.GINN_hparams['curv_start_epoch'],
                    self.GINN_hparams['curv_ramp_epochs']
                ) * smooth_raw

                total_loss = (
                    eik_loss + env_loss + connect_loss + hole_loss + int_loss +
                    norm_loss + thick_loss + prohib_loss + smooth_loss_scaled
                )

                self.ginn_hist_iters.append(it)
                self.ginn_hist_eik.append(eik_loss.item())
                self.ginn_hist_env.append(env_loss.item())
                self.ginn_hist_connect.append(connect_loss.item())
                self.ginn_hist_hole.append(hole_loss.item())
                self.ginn_hist_int.append(int_loss.item())
                self.ginn_hist_norm.append(norm_loss.item())
                self.ginn_hist_thick.append(thick_loss.item())
                self.ginn_hist_prohib.append(prohib_loss.item())
                self.ginn_hist_smooth.append(smooth_loss_scaled.item())
                self.ginn_hist_total.append(total_loss.item())

                self._log_ginn_losses_csv(
                    it,
                    eik_loss.item(),
                    env_loss.item(),
                    connect_loss.item(),
                    hole_loss.item(),
                    int_loss.item(),
                    norm_loss.item(),
                    thick_loss.item(),
                    prohib_loss.item(),
                    smooth_loss_scaled.item(),
                    total_loss.item()
                )

                self._save_ginn_training_curves(it)
                self._maybe_log_eval_metrics(phase="GINN", it=it, model_kind="sdf")

        final_it = int(self.n_pre_training_iterations_GINN - 1)
        if len(self.ginn_hist_iters) == 0 or int(self.ginn_hist_iters[-1]) != final_it:
            geometry_losses = GINN_Losses(
                self.SDF_GINN_model,
                self.test_case,
                self.GINN_hparams,
                self.PH,
                self.boundary_sampler,
                self.enforce_density
            )
            coords_eval = next(self.point_sampler).to(self.device)

            eik_loss = geometry_losses.eikonal_loss(coords_eval)
            env_loss = geometry_losses.design_envelope_loss()
            connect_loss = geometry_losses.connectivity_loss()
            hole_loss = geometry_losses.holes_loss()
            int_loss = geometry_losses.interface_loss()
            norm_loss = geometry_losses.surface_normal_loss()
            thick_loss = geometry_losses.prescribed_thickness_loss(coords_eval)
            prohib_loss = geometry_losses.prohibited_region_loss(coords_eval)
            smooth_raw = geometry_losses.smoothness_loss(final_it)

            smooth_loss_scaled = self.loss_ramp_up(
                final_it,
                self.GINN_hparams['curv_start_epoch'],
                self.GINN_hparams['curv_ramp_epochs']
            ) * smooth_raw

            total_loss = (
                eik_loss + env_loss + connect_loss + hole_loss + int_loss +
                norm_loss + thick_loss + prohib_loss + smooth_loss_scaled
            )

            self._log_ginn_losses_csv(
                final_it,
                eik_loss.item(),
                env_loss.item(),
                connect_loss.item(),
                hole_loss.item(),
                int_loss.item(),
                norm_loss.item(),
                thick_loss.item(),
                prohib_loss.item(),
                smooth_loss_scaled.item(),
                total_loss.item()
            )

            self._maybe_log_eval_metrics(phase="GINN", it=final_it, model_kind="sdf")

        print("\n\n")

        # ---- Initialize volume_ratio ------
        grid_xyz = get_grid_centers(JEB.domain, n_opt_samples).astype(np.float32)
        xt = torch.tensor(grid_xyz, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            sdf_field = self.SDF_GINN_model(xt).squeeze()
            occ = torch.sigmoid(-1000 * sdf_field).clamp(0, 1)
            init_vol_frac = float(occ.mean().item())

        start_ratio = self.volume_ratio if (self.volume_ratio is not None) else init_vol_frac
        self.volume_ratio = self._safe_clip_ratio(start_ratio, self.vol_frac_min, self.vol_frac_max)
        self.density_GINN_model.volume_ratio = float(self.volume_ratio)
        target_volume = float(self.volume_ratio) * float(JEB.domain_volume)

        # ===================== (2) density-GINN init =====================
        print("=== Pre-Training density GINN ===")
        for _ in tqdm(range(self.n_pre_training_iterations_density_GINN), desc="density-GINN Init"):
            coords = next(GINN_point_sampler)
            self.density_PINN_initialization_step(coords, density_GINN_optimizer)
        print()

        # ===================== (3) PINN warm-up =====================
        print("=== Pre-Training PINN ===")
        for _ in tqdm(range(self.n_pre_training_iterations_PINN), desc="PINN Pre-Training"):
            coords = next(PINN_point_sampler)
            self.PINN_update_step(coords, PINN_optimizer, n_sim_samples)
        print()

        # ---- Initial geometry saves (OBJ + iso) ----
        save_density_outputs_3d(
            self.density_GINN_model, JEB.domain, n_opt_samples,
            self.enforce_density, self.save_path, '',
            '-initial', iso_level=self.plot_threshold
        )
        save_density_outputs_3d_unconstrained(
            self.density_GINN_model, JEB.domain, n_opt_samples,
            self.save_path, '', '-initial', iso_level=self.plot_threshold
        )

        print('Saved initial geometry to', self.save_path)

        pts0, u0p, v0p, w0p, vm0p = predict_uvw_sigma_points_3d(
            self.u_model, self.v_model, self.w_model, self.density_GINN_model,
            rho_threshold_plot=self.plot_threshold,
            n_eval_points=self.n_eval_points_plot,
            max_batch=self.max_batch_plot,
            device=self.device,
            test_case=self.test_case
        )
        png0 = os.path.join(self.save_path, f"uvwvm-{0:06d}.png")
        save_uvw_sigma_to_file_3d(pts0, u0p, v0p, w0p, vm0p, png0, title="Iteration 0")

        sigma_metric0, sigma_max0, vol0 = self._measure_stress_volume_3d()
        comp0 = self._estimate_compliance_internal_energy_3d(n_opt_samples)

        self.hist_iters.append(0)
        self.hist_sigma_metric.append(sigma_metric0)
        self.hist_sigma_max.append(sigma_max0)
        self.hist_volume.append(vol0)
        self.hist_compliance.append(comp0)

        self._save_training_curves(0)
        self._log_optimization_metrics_csv(0, sigma_metric0, sigma_max0, vol0, comp0)
        self._maybe_log_eval_metrics(phase="TOPO", it=0, model_kind="density")
        self.sigma_ema = sigma_metric0

        # ===================== Second optimization stage =====================
        for it in range(1, self.n_opt_iterations + 1):
            print(f"\n=== Optimization iteration {it}/{self.n_opt_iterations} ===")

            # ---- Volume controller ---------
            r = self.sigma_ema / (self.sigma_allow + 1e-20)

            if r < 1.0 - self.stress_tol:
                gap = 1.0 - r
                step = float(np.clip(self.alpha_decrease * gap, self.vol_step_min, self.vol_step_max))
                new_ratio = float(self.volume_ratio) * (1.0 - step)
            elif r > 1.0 + self.stress_tol:
                gap = r - 1.0
                step = float(np.clip(self.beta_increase * gap, self.vol_step_min, self.vol_step_max))
                new_ratio = float(self.volume_ratio) * (1.0 + step)
            else:
                new_ratio = float(self.volume_ratio)

            self.volume_ratio = self._safe_clip_ratio(new_ratio, self.vol_frac_min, self.vol_frac_max)
            self.density_GINN_model.volume_ratio = float(self.volume_ratio)
            target_volume = float(self.volume_ratio) * float(JEB.domain_volume)


            old_densities_list, sensitivities_list, coords_list = [], [], []

            topo_funcs = topology_optimization(
                self.u_model,
                self.v_model,
                self.w_model,
                self.density_GINN_model,
                n_opt_samples,
                self.training_hparams,
                self.enforce_density
            )

            for _ in tqdm(range(self.n_opt_batches), desc="Computing Target Densities"):
                coords = next(GINN_point_sampler)
                densities, sensitivities = topo_funcs.compute_sensitivities(coords)

                densities, sensitivities = self.enforce_density.apply(
                    coords, densities, sensitivities, n_opt_samples, JEB.domain
                )

                filtered_sensitivities = topo_funcs.apply_sensitivity_filter(
                    coords, densities, sensitivities,
                    n_samples=n_opt_samples, domain=JEB.domain, radius=self.filter_radius
                )

                old_densities_list.append(densities)
                sensitivities_list.append(filtered_sensitivities)
                coords_list.append(coords)

            target_density_list = topo_funcs.compute_target_densities(
                old_densities_list, sensitivities_list,
                sample_volume=JEB.domain_volume,
                target_volume=target_volume,
                max_move=0.2,
                damping_parameter=0.5
            )

            for coords, target_density in tqdm(
                list(zip(coords_list, target_density_list)),
                desc="Optimizing density field",
                total=len(coords_list)
            ):
                self.topology_optimization_step(density_GINN_optimizer, coords, target_density)

            # ---- Update PINN ----
            for _ in tqdm(range(self.n_sim_iterations), desc="Updating PINN"):
                coords = next(PINN_point_sampler)
                self.PINN_update_step(coords, PINN_optimizer, n_sim_samples)

            # ---- Stress & volume on binary geometry ----
            sigma_metric, sigma_max, current_volume = self._measure_stress_volume_3d()

            # ---- EMA to fight oscillations ----
            if self.sigma_ema is None:
                self.sigma_ema = sigma_metric
            else:
                a = self.sigma_ema_alpha
                self.sigma_ema = (1.0 - a) * self.sigma_ema + a * sigma_metric

            # ---- Compliance on current SIMP design ----
            comp_now = self._estimate_compliance_internal_energy_3d(n_opt_samples)

            self.hist_iters.append(it)
            self.hist_sigma_metric.append(sigma_metric)
            self.hist_sigma_max.append(sigma_max)
            self.hist_volume.append(current_volume)
            self.hist_compliance.append(comp_now)

            self._log_optimization_metrics_csv(it, sigma_metric, sigma_max, current_volume, comp_now)


            if (it % self.save_interval) == 0:
                self._save_training_curves(it)

            print(
                f"[iter {it}] C={comp_now:.4g},  V={current_volume:.4g},  "
                f"σ_metric={sigma_metric:.4g} (σ_allow={self.sigma_allow:.4g})  "
                f"| vol_ratio→{self.volume_ratio:.4f}  "
                f"(target={target_volume/float(JEB.domain_volume):.4f}·V_domain)  "
                f"| σ_max={sigma_max:.4g}"
            )

            if (it % self.save_interval) == 0:
                self._maybe_log_eval_metrics(phase="TOPO", it=it, model_kind="density")
                print('Saving geometry & models to', self.save_path)

                save_density_outputs_3d(
                    self.density_GINN_model, JEB.domain, n_opt_samples,
                    self.enforce_density, self.save_path, '',
                    f'-{it:06d}', iso_level=self.plot_threshold
                )

                save_density_outputs_3d_unconstrained(
                    self.density_GINN_model, JEB.domain, n_opt_samples,
                    self.save_path, '', f'-{it:06d}', iso_level=self.plot_threshold
                )

                pts_it, u_it, v_it, w_it, vm_it = predict_uvw_sigma_points_3d(
                    self.u_model, self.v_model, self.w_model, self.density_GINN_model,
                    rho_threshold_plot=self.plot_threshold,
                    n_eval_points=self.n_eval_points_plot,
                    max_batch=self.max_batch_plot,
                    device=self.device,
                    test_case=self.test_case
                )
                png_it = os.path.join(self.save_path, f"uvwvm-{it:06d}.png")
                save_uvw_sigma_to_file_3d(pts_it, u_it, v_it, w_it, vm_it, png_it, title=f"Iteration {it}")

                torch.save(self.density_GINN_model.state_dict(), os.path.join(self.save_path, f"density_GINN-{it:06d}.pth"))
                torch.save(self.u_model.state_dict(), os.path.join(self.save_path, f"u_model-{it:06d}.pth"))
                torch.save(self.v_model.state_dict(), os.path.join(self.save_path, f"v_model-{it:06d}.pth"))
                torch.save(self.w_model.state_dict(), os.path.join(self.save_path, f"w_model-{it:06d}.pth"))

                self._save_topology_density_curves(it)
                self._log_optimization_losses_csv(it, comp_now, sigma_metric, sigma_max, current_volume)

        # ---- Final snapshot ----
        save_density_outputs_3d(
            self.density_GINN_model, JEB.domain, n_opt_samples,
            self.enforce_density, self.save_path, '',
            '-final', iso_level=self.plot_threshold
        )

        if len(self.hist_iters) > 0:
            self._save_training_curves(self.hist_iters[-1])

        torch.save(self.density_GINN_model.state_dict(), os.path.join(self.save_path, "density_GINN-final.pth"))
        torch.save(self.u_model.state_dict(), os.path.join(self.save_path, "u_model-final.pth"))
        torch.save(self.v_model.state_dict(), os.path.join(self.save_path, "v_model-final.pth"))
        torch.save(self.w_model.state_dict(), os.path.join(self.save_path, "w_model-final.pth"))

        print("Done. Results in:", self.save_path)


In [None]:

# Network architecture settings.
# In the model-construction cell, 'SIREN_hparams' configures the u/v/w PINNs and
# 'SIREN_SDF_hparams' configures both the density GINN and the SDF GINN.
# 'WIRE_hparams' and 'MLP_hparams' (and the three top-level scalar keys) are not
# referenced in this section; presumably kept for alternative backbones — confirm.
hparams_model = dict(
    Model_type='SIREN',
    num_hidden_layers=3,
    num_hidden_neurons=280,

    SIREN_hparams=dict(
        Model_type='SIREN',
        layers=[180] * 4,
        dimensionality=3,
        w0_initial=60,
        w0=1,
        skip_connection=True,
    ),
    SIREN_SDF_hparams=dict(
        Model_type='SIREN',
        layers=[180] * 4,
        dimensionality=3,
        w0_initial=15,
        w0=2,
        skip_connection=True,
    ),
    WIRE_hparams=dict(
        Model_type='WIRE',
        layers=[180] * 4,
        dimensionality=3,
        w0_initial=15,
        w0=2,
        sigma0=1,
        sigma0_initial=1,
        layer_type='real_gabor',
        trainable=False,
        skip_connection=True,
    ),
    MLP_hparams=dict(
        Model_type='MLP',
        layers=[180] * 4,
        dimensionality=3,
        activation_function='relu',
        use_bias=True,
        use_batch_norm=False,
        use_dropout=False,
        dropout_rate=0.1,
        skip_connection=True,
    ),
)

# Input feature-expansion settings (keys contain spaces, so built from pairs).
# Note: 'Feature Type' holds the *string* 'None', not Python None — presumably
# the consumer compares against that string; confirm before changing.
hparams_feature_expansion = dict([
    ('Feature Type', 'None'),
    ('Num Frequencies', 3),
    ('Max Frequency', 100),
])


# Density-constraint settings, presumably consumed by the object exposed as
# self.enforce_density in model_training — confirm.
# 'priority' presumably resolves keep-vs-prohibit overlaps — verify against that class.
density_constraint_hparams = dict(
    enabled=True,
    priority="keep_over_prohibit",
)


# Topology-optimisation settings. Keys are presumably copied onto model_training
# attributes of the same name (self.n_opt_iterations, self.sigma_allow, ...);
# that mapping is not visible in this section.
topo_hparams = dict(
    # Output directory for meshes, plots, CSV logs and checkpoints.
    save_path="./JEB_new_density",
    # None => the training loop starts from the occupied fraction of the
    # pre-trained SDF instead of a fixed ratio.
    volume_ratio=None,
    lr_PINN=3e-4,
    lr_GINN=3e-4,
    # Snapshots/curves are written every `save_interval` optimisation iterations.
    save_interval=5,
    # Radius passed to the sensitivity filter.
    filter_radius=2.0,

    # Iteration budgets: main loop, PINN updates per loop, and pre-training stages.
    n_opt_iterations=100,
    n_sim_iterations=1000,
    n_pre_training_iterations_PINN=2500,
    n_pre_training_iterations_density_GINN=2500,
    n_pre_training_iterations_GINN=300,
    n_opt_batches=50,
    seed=43,
    # Key spelling ('treshold') kept: downstream code presumably reads this exact key.
    rho_treshold=0.25,

    # Stress measurement; ks_* presumably only matter for a KS-type stress_metric.
    stress_metric="percentile",
    sigma_allow=400.0,
    stress_tol=0.02,
    stress_percentile=0.9995,
    ks_rho=80.0,
    ks_size_correction=False,

    # Volume controller: the ratio shrinks when smoothed stress < sigma_allow*(1-stress_tol)
    # and grows when > sigma_allow*(1+stress_tol); relative steps are clipped to
    # [vol_step_min, vol_step_max], the ratio itself to [vol_frac_min, vol_frac_max].
    vol_frac_min=0.02,
    vol_frac_max=0.98,
    vol_step_min=0.005,
    vol_step_max=0.05,
    alpha_decrease=0.50,
    beta_increase=0.25,

    # Displacement/stress plot sampling.
    n_eval_points_plot=400_000,
    max_batch_plot=120_000,
)


# PINN / density training settings.
# 'density_alpha' is passed to density_GINN and 'mollifier_alpha' to each PINN
# in the model-construction cell; the remaining keys are read elsewhere.
training_hparams = dict(
    # Sampling budgets.
    total_sample_points=80000,
    batch_size=20000,
    num_neumann_points=10000,
    dirichlet_pts=5000,
    # Model/material shaping parameters.
    mollifier_alpha=1,
    density_alpha=2,
    density_exponent=3,
    # Gradient clipping (fixed and percentile-based automatic variant).
    grad_clipping_on=True,
    grad_clip=0.5,
    auto_clip_on=True,
    auto_clip_percentile=0.9,
    auto_clip_hist_len=100,
    auto_clip_min_len=10,
    refine_study=False,
)


# Geometry (GINN) loss weights and sampling sizes.
# 'curv_start_epoch' / 'curv_ramp_epochs' feed self.loss_ramp_up, which scales the
# smoothness loss during SDF-GINN pre-training. NOTE(review): 200 + 300 exceeds
# n_pre_training_iterations_GINN (300), so if loss_ramp_up is a linear ramp
# (unverified) the smoothness term is only partly ramped in by the last
# iteration — confirm this is intended.
GINN_hparams = dict(
    # Loss weights.
    eikonal_loss_weight=0.1,
    envelope_loss_weight=10,
    connectivity_loss_weight=50,
    holes_loss_weight=50,
    interface_loss_weight=1,
    prescribed_normals_loss_weight=1,
    prescribed_thickness_loss_weight=1,
    prohibited_region_loss_weight=1,
    smoothness_loss_weight=1e-2,
    # Point counts per loss term.
    num_points_envelope_loss=12000,
    num_points_connectivity_loss=60000,
    envelope_extension_factor=0.2,
    num_points_interface_loss=4096,
    num_points_normals_loss=4096,
    # Clipping / curvature settings.
    clip_max_value=1.0e+6,
    clip_min_value=0,
    max_curv=0,
    curv_start_epoch=200,
    curv_ramp_epochs=300,
    # *_density weights are 0, presumably disabling these terms for the density GINN.
    envelope_loss_weight_density=0,
    connectivity_loss_weight_density=0,
)

# Adaptive loss-weighting settings. 'objective_losses' presumably names the loss
# term(s) treated as the objective rather than as constraints — confirm against
# the ALM / adaptive-weighting code.
adaptive_weighting_hparams = dict(
    use_adaptive_weighting=True,
    alpha=0.90,
    gamma=1e-2,
    epsilon=1e-8,
    objective_function=True,
    objective_losses=['Smoothness Loss'],
)


In [None]:
# Build the networks and the training driver.
# Instances use `*_model` names so the class names `density_GINN` / `SDF_GINN`
# are NOT rebound: the previous `density_GINN = density_GINN(...)` replaced the
# class with an instance, so re-running this cell (Restart & Run All, or a second
# run) failed. Construction order (density -> SDF -> u -> v -> w) is unchanged
# in case weight initialisation draws from a shared RNG.
density_GINN_model = density_GINN(hparams_model['SIREN_SDF_hparams'],
                                  hparams_feature_expansion,
                                  density_alpha=training_hparams['density_alpha'],
                                  volume_ratio=topo_hparams['volume_ratio'])

SDF_GINN_model = SDF_GINN(hparams_model['SIREN_SDF_hparams'],
                          hparams_feature_expansion)

# One PINN per displacement component (u, v, w), all with the same configuration;
# the generator creates three independent instances, in u, v, w order.
u_model, v_model, w_model = (
    PINN(hparams_model['SIREN_hparams'],
         hparams_feature_expansion,
         training_hparams['mollifier_alpha'])
    for _ in range(3)
)


Shape_Generator = model_training(u_model,
                                 v_model,
                                 w_model,
                                 density_GINN_model,
                                 SDF_GINN_model,
                                 training_hparams,
                                 topo_hparams,
                                 density_constraint_hparams,
                                 GINN_hparams,
                                 adaptive_weighting_hparams,
                                 JEB)

In [None]:
# Run the full pipeline. This is presumably the method whose tail appears above:
# SDF-GINN pre-training, density-GINN init, PINN warm-up, then the
# stress-constrained optimisation loop. Results go to self.save_path, presumably
# topo_hparams['save_path'] — confirm. Expensive: consider guarding it before Run All.
Shape_Generator.generate_geometry()