In [None]:
import os
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
google_drive_path = ''
os.chdir(google_drive_path)
print("Current working directory:", os.getcwd())
!ls

!python -m pip install trimesh
!python -m pip install rtree
!python -m pip install cripser==0.0.15 #scikit-image trimesh plotly

In [None]:
import os
import sys
import time
import math
import json
import random
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import trimesh
from tqdm import trange
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Put the parent of the working directory on the import path so that the
# project packages imported in the next cell can be resolved.
sys.path.append(
    os.path.abspath(os.path.join(os.getcwd(), os.pardir))
)


In [None]:
from Functions.Point_Sampling.point_sampler import Point_Sampler
from Test_Cases.Jet_Engine_Bracket.JEB_Master_object import JEB_Master_object
from Test_Cases.Bridge_around_object.BRIDGE_Master_object import BRIDGE_Master_Object
from Functions.Training.Properties import Properties 
from Functions.Plotting_functions.training_curves.JEB_PINN_curves import *
from Functions.Computations.L2_error import compute_L2_errors_3d 
from Functions.logging.JEB_PINN_logging import save_metrics_csv_3d 
from Functions.Plotting_functions.JEB_PINN_results import plot_results 
from Functions.utils import * 

from Functions.Data_preprocessing_functions.interpolate_from_PointCloud import (
    interpolate_from_point_cloud,
)
from Models.GINN_Models.GINN import GINN

from File_Paths.file_paths import (
    mesh_path,
    point_cloud_path,
    data_path,
)

# Instantiate both test-case objects (Normalize=True, Symmetry=False) and let
# each of them build its interfaces.
BRIDGE = BRIDGE_Master_Object(Normalize=True, Symmetry=False)
BRIDGE.create_interfaces()

JEB = JEB_Master_object(Normalize=True, Symmetry=False)
JEB.create_interfaces()

# NOTE(review): these properties are built from BRIDGE, whereas PINN_Loss is
# initialised with test_case=JEB - confirm that this is intended.
material_properties = Properties(test_case=BRIDGE)

# Compute device: first CUDA GPU if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:

# Normalisation used throughout the notebook: x_hat = (x_mm - center_mm) * s.
# Both values are read from the JEB test-case object.
s = float(JEB.domain_scaling_factor)                
center_mm = np.asarray(JEB.domain_center, dtype=float)


# Load and normalize the 411 geometry (surface mesh + point cloud)
simple_geometry_mesh_filename = "411.obj"
simple_geometry_PC_filename   = "411_PC_data.csv"

# Surface mesh; normals are fixed and holes filled before the mesh volume is read below.
simple_geometry_mesh = trimesh.load(os.path.join(mesh_path, simple_geometry_mesh_filename))
simple_geometry_mesh.fix_normals()
simple_geometry_mesh.fill_holes()

# Point cloud as a plain array: columns 0-2 are treated as coordinates below;
# column 3 is presumably the SDF value (it is queried later with quantity="SDF") - TODO confirm.
simple_geometry_point_cloud = pd.read_csv(
    os.path.join(point_cloud_path, simple_geometry_PC_filename)
).to_numpy()

# Normalize with the JEB transform: translate by -center_mm, then scale by s
simple_geometry_mesh.apply_translation(-center_mm)
simple_geometry_mesh.apply_scale(s)

# Apply the same transform to the point-cloud coordinates; column 3 is only scaled by s.
simple_geometry_point_cloud[:, :3] = (simple_geometry_point_cloud[:, :3] - center_mm[np.newaxis, :]) * s
simple_geometry_point_cloud[:, 3] *= s  

# Volume of the scaled mesh, and the same volume with the scaling undone (divide by s^3).
solid_volume_hat = float(simple_geometry_mesh.volume)
solid_volume_mm3 = solid_volume_hat / (s**3)



# Load FEM reference from Excel 
def load_JEB_FEM_reference():
    """Read the FEM validation results for geometry 411 from the Excel export.

    Column headers are matched case-insensitively and with all whitespace
    removed; several spelling variants are accepted for the displacement and
    von-Mises stress columns.

    Returns:
        dict | None: ``None`` (after printing a warning) when the workbook does
        not exist under ``data_path``. Otherwise a dict with the keys
        ``coords_mm``, ``coords_scaled`` (``(coords_mm - center_mm) * s``),
        ``x_disp``, ``y_disp``, ``z_disp`` and ``sigma_vm``, all float arrays.

    Raises:
        KeyError: if a required column is not present under any accepted name.
    """
    workbook = Path(data_path) / "JEB_411_validation_FEM_mine.xlsx"
    if not workbook.exists():
        print(f"[WARN] FEM file not found: {workbook}")
        return None

    table = pd.read_excel(workbook)
    table.columns = table.columns.str.strip()

    def _squash(label: str) -> str:
        # Lower-case and drop every whitespace character so header variants compare equal.
        return "".join(label.lower().split())

    header_by_key = {_squash(header): header for header in table.columns}

    def _resolve(aliases):
        # First alias (in the given order) that matches an existing header wins.
        matches = (
            header_by_key[_squash(alias)]
            for alias in aliases
            if _squash(alias) in header_by_key
        )
        found = next(matches, None)
        if found is None:
            raise KeyError(
                f"None of {aliases} found.\nAvailable columns: {list(table.columns)}"
            )
        return found

    # Node coordinates in mm, plus their normalized counterparts.
    coord_headers = [
        _resolve(["X Location (mm)"]),
        _resolve(["Y Location (mm)"]),
        _resolve(["Z Location (mm)"]),
    ]
    coords_mm = table[coord_headers].to_numpy(dtype=float)
    coords_scaled = (coords_mm - center_mm) * s

    # Displacement components: resolve all three headers first, then convert.
    disp_headers = [
        _resolve([f"{axis} Disp (mm)", f"{axis} disp (mm)", f"{axis} Displacement (mm)"])
        for axis in ("X", "Y", "Z")
    ]
    x_disp, y_disp, z_disp = (
        table[header].to_numpy(dtype=float) for header in disp_headers
    )

    stress_header = _resolve(
        [
            "Equivalent (von-Mises) Stress (MPa)",
            "Equivalent von-Mises Stress (MPa)",
            "Equivalent (von Mises) Stress (MPa)",
        ]
    )
    sigma_vm = table[stress_header].to_numpy(dtype=float)

    print("Loaded FEM:", workbook)
    print("FEM nodes:", coords_mm.shape[0])

    return {
        "coords_mm": coords_mm,
        "coords_scaled": coords_scaled,
        "x_disp": x_disp,
        "y_disp": y_disp,
        "z_disp": z_disp,
        "sigma_vm": sigma_vm,
    }



def fem_inside_mask_from_sdf(fem_ref, sdf_inside_is_negative=True):
    """Flag which FEM nodes lie inside the solid, based on the interpolated SDF.

    Args:
        fem_ref: dict holding at least ``"coords_scaled"``, or ``None``.
        sdf_inside_is_negative: when True a node counts as inside if its SDF is
            strictly negative; when False, if it is strictly positive.

    Returns:
        The same ``fem_ref`` object with ``"inside_mask"`` (boolean array) and
        ``"sdf_hat"`` (flattened SDF values) added, or ``None`` if ``fem_ref``
        is ``None``.
    """
    if fem_ref is None:
        return None

    # SDF at the normalized FEM node positions, flattened to one dimension.
    sdf_values = np.asarray(
        interpolate_from_point_cloud(
            simple_geometry_point_cloud, fem_ref["coords_scaled"], quantity="SDF"
        )
    ).reshape(-1)

    is_inside = (sdf_values < 0.0) if sdf_inside_is_negative else (sdf_values > 0.0)

    print(f"inside fraction by SDF: {is_inside.mean():.4f}")
    fem_ref["inside_mask"] = is_inside
    fem_ref["sdf_hat"] = sdf_values
    return fem_ref


In [None]:
class PINN(torch.nn.Module):
    """One displacement-component network with a hard Dirichlet constraint.

    Wraps a GINN model built for the JEB test case and multiplies its output by
    a mollifier that vanishes on the four bolt circles (see
    ``enforce_dirichlet_BC``).
    """

    def __init__(self, hparams_model, hparams_feature_expansion, mollifier_alpha):
        """
        Args:
            hparams_model: forwarded to GINN as ``Model_hyperparameters``.
            hparams_feature_expansion: forwarded to GINN as ``feature_expansion``.
            mollifier_alpha: factor inside the tanh of the Dirichlet mollifier.
        """
        super().__init__()
        self.mollifier_alpha = mollifier_alpha
        self.model = GINN(
            JEB,
            feature_expansion=hparams_feature_expansion,
            Model_hyperparameters=hparams_model,
        )

    @staticmethod
    def enforce_dirichlet_BC(alpha, u, coords):
        """Scale ``u`` so that it is zero on every bolt circle in the x-y plane.

        For each of the four JEB bolt centroids the factor
        ``tanh(alpha * |dist_xy - radius|)`` is computed, where ``dist_xy`` is
        the in-plane distance of a point to that centroid and ``radius`` is
        ``JEB.bolt_interface_radius``. The four factors are multiplied together.
        The z coordinate is not used.

        Args:
            alpha: steepness of the tanh ramp away from the bolt circle.
            u: network output; indexed as ``(N, C)`` via the final broadcast.
            coords: tensor whose first two columns are the x and y coordinates.

        Returns:
            ``u`` multiplied row-wise by the combined mollifier.
        """
        device_local = u.device
        radius = JEB.bolt_interface_radius

        # In-plane (x, y) position of each bolt centre.
        centroids = torch.tensor(
            [
                centroid[0:2]
                for centroid in (
                    JEB.bolt1_centroid,
                    JEB.bolt2_centroid,
                    JEB.bolt3_centroid,
                    JEB.bolt4_centroid,
                )
            ],
            dtype=torch.float32,
            device=device_local,
        )

        xy_coords = coords[:, :2]
        multiplier_total = torch.ones(coords.shape[0], device=device_local)

        for centroid in centroids:
            dist = torch.norm(xy_coords - centroid, dim=1)
            # The same expression holds inside and outside the circle, so no
            # masks are needed; the product is built out-of-place so that no
            # tensor in the autograd graph is modified in place.
            multiplier_total = multiplier_total * torch.tanh(alpha * torch.abs(dist - radius))

        return u * multiplier_total.unsqueeze(1)

    def forward(self, coords):
        """Evaluate the wrapped model and apply the Dirichlet mollifier."""
        u = self.model(coords)
        return self.enforce_dirichlet_BC(self.mollifier_alpha, u, coords)

In [None]:
class PINN_Loss(Properties):
    """Ritz (potential-energy) loss for the three displacement networks.

    Attributes such as ``interfaces``, ``force_vector``, ``lame_lambda`` and
    ``lame_mu`` are read from ``self``; they are presumably set up by
    ``Properties.__init__`` (not visible in this notebook).
    """

    def __init__(self, u_model, v_model, w_model, training_hparams):
        """
        Args:
            u_model, v_model, w_model: networks for the x, y and z displacement.
            training_hparams: dict providing ``num_neumann_points``,
                ``density_exponent`` and ``rho_min``.
        """
        super().__init__(test_case=JEB)

        # All three component networks live on the global compute device.
        self.u_model, self.v_model, self.w_model = (
            net.to(device) for net in (u_model, v_model, w_model)
        )

        self.num_neumann_pts = int(training_hparams["num_neumann_points"])
        self.p = float(training_hparams["density_exponent"])
        self.rho_min = float(training_hparams["rho_min"])

        # Volume of the solid in mm^3, used to turn the mean energy density into an energy.
        self.V_solid_mm3 = float(solid_volume_mm3)

    def ritz_loss(self, coords_hat, rho):
        """Return internal strain energy minus external work for one batch.

        Args:
            coords_hat: batch of normalized collocation points, one row per point.
            rho: density per point; flattened and clamped to [0, 1] before use.

        Returns:
            Scalar tensor ``internal_energy - external_energy``.
        """
        # ---- external work on the Neumann boundary ----
        boundary_pts = self.interfaces.sample_points_on_neumann_boundary(
            self.num_neumann_pts, "vertical", "torch_tensor"
        ).to(device)
        boundary_pts.requires_grad_(True)

        radius_hat = float(JEB.pinn_interface_radius)
        band_width_hat = float(JEB.pinn_interface_width)
        loaded_area_hat = 2.0 * math.pi * radius_hat * band_width_hat
        loaded_area_mm2 = loaded_area_hat / (s**2)  # undo the normalisation of lengths

        area_per_point_mm2 = loaded_area_mm2 / boundary_pts.shape[0]
        traction_z = self.force_vector[2] / (2*loaded_area_mm2)

        w_on_boundary = self.w_model(boundary_pts).squeeze(-1)
        external_energy = torch.sum(traction_z * w_on_boundary * area_per_point_mm2)

        # ---- internal strain energy ----
        x = coords_hat.detach().clone().requires_grad_(True)

        u = self.u_model(x)
        v = self.v_model(x)
        w = self.w_model(x)

        def _physical_gradient(field):
            # Gradient w.r.t. normalized coordinates, rescaled by s to physical units.
            grad_hat = torch.autograd.grad(
                field,
                x,
                grad_outputs=torch.ones_like(field),
                create_graph=True,
                retain_graph=True,
            )[0]
            return s * grad_hat

        gu = _physical_gradient(u)
        gv = _physical_gradient(v)
        gw = _physical_gradient(w)

        # Small-strain tensor components.
        e11 = gu[:, 0]
        e22 = gv[:, 1]
        e33 = gw[:, 2]
        e12 = 0.5 * (gu[:, 1] + gv[:, 0])
        e13 = 0.5 * (gu[:, 2] + gw[:, 0])
        e23 = 0.5 * (gv[:, 2] + gw[:, 1])
        tr = e11 + e22 + e33

        # Density-dependent stiffness scaling: rho^p, optionally blended with rho_min.
        rho = rho.view(-1).clamp(0.0, 1.0)
        if self.p != 0:
            stiffness_scale = rho.pow(self.p)
        else:
            stiffness_scale = torch.ones_like(rho)
        if self.rho_min > 0.0:
            stiffness_scale = self.rho_min + (1.0 - self.rho_min) * stiffness_scale

        lam = self.lame_lambda * stiffness_scale
        mu = self.lame_mu * stiffness_scale

        # Isotropic linear-elastic stresses.
        s11 = 2 * mu * e11 + lam * tr
        s22 = 2 * mu * e22 + lam * tr
        s33 = 2 * mu * e33 + lam * tr
        s12 = 2 * mu * e12
        s13 = 2 * mu * e13
        s23 = 2 * mu * e23

        # Strain-energy density: 0.5 * sigma : epsilon (shear terms counted twice).
        sigma_dot_eps = (
            s11 * e11 + s22 * e22 + s33 * e33
            + 2 * s12 * e12 + 2 * s13 * e13 + 2 * s23 * e23
        )
        psi = 0.5 * sigma_dot_eps

        internal_energy = self.V_solid_mm3 * psi.mean()

        return internal_energy - external_energy

In [None]:

class FixedGeometryDataset(Dataset):
    """Collocation points drawn in ``JEB.domain`` with a binary density each.

    The density is 1 where the SDF interpolated from the normalized point
    cloud marks a point as solid, and 0 elsewhere.
    """

    def __init__(self, num_points: int, training_hparams: dict):
        """
        Args:
            num_points: number of domain points requested from the sampler.
            training_hparams: dict providing ``sdf_inside_is_negative`` and
                ``solid_only_sampling``.
        """
        super().__init__()
        self.device = device

        sampler = Point_Sampler(
            JEB.domain,
            JEB.interfaces,
            num_points_domain=num_points,
            num_points_interface=0,
        )
        sampled_pts = next(sampler)

        # Interpolate the SDF on the host, then move the values to the compute device.
        sdf_values = interpolate_from_point_cloud(
            simple_geometry_point_cloud,
            sampled_pts.detach().cpu().numpy(),
            quantity="SDF",
        )
        sdf = torch.tensor(sdf_values, dtype=torch.float32, device=self.device).view(-1)

        # Binary density from the sign of the SDF.
        if bool(training_hparams["sdf_inside_is_negative"]):
            is_solid = sdf < 0.0
        else:
            is_solid = sdf > 0.0
        rho = is_solid.float()

        if bool(training_hparams["solid_only_sampling"]):
            # Keep only the points classified as solid.
            keep = rho > 0.5
            self.points = sampled_pts.to(self.device)[keep]
            self.density = rho[keep]
        else:
            self.points = sampled_pts.to(self.device)
            self.density = rho

        with torch.no_grad():
            self.stats = {
                "solid_frac_in_domain": float((rho > 0.5).float().mean().item()),
                "points_used": int(self.points.shape[0]),
            }

    def __len__(self):
        """Number of points kept in the dataset."""
        return self.points.shape[0]

    def __getitem__(self, idx):
        """Return the ``(point, density)`` pair stored at ``idx``."""
        return self.points[idx], self.density[idx]





In [None]:

# Network hyper-parameters. The "SIREN_hparams" entry is what the training
# function passes to PINN, which forwards it to GINN as Model_hyperparameters.
hparams_model = {
    "SIREN_hparams": {
        "Model_type": "SIREN",
        "layers": [180, 180, 180, 180, 180],
        "dimensionality": 3,
        "w0_initial": 60,
        "w0": 10,
        "skip_connection": True,
    },
}

# Forwarded to GINN as feature_expansion.
hparams_feature_expansion = {
    "Feature Type": "None",
    "Num Frequencies": 3,
    "Max Frequency": 100,
}

# Settings read by the dataset, the loss and the training loop.
training_hparams = {
    "total_points": 900_000,         # domain points requested from the sampler for each dataset
    "batch_size": 300_000,           # DataLoader batch size
    "num_epochs": 5002,              # loop bound; StepLR step size is num_epochs / 4
    "learning_rate": 1e-3,           # Adam learning rate
    "weight_decay": 1e-4,            # Adam weight decay
    "gamma": 0.75,                   # StepLR decay factor
    "num_neumann_points": 70_000,    # Neumann-boundary samples drawn per loss evaluation
    "mollifier_alpha": 1.0,          # factor inside the tanh of the Dirichlet mollifier
    "density_exponent": 10,          # exponent p applied to the density in the loss
    "rho_min": 0.0,                  # lower blend value for the density scaling (used only if > 0)
    "density_mode": "binary_sdf",    # NOTE(review): not read anywhere in the visible code
    "sdf_inside_is_negative": True,  # sign convention used when turning the SDF into a density
    "solid_only_sampling": True,     # keep only the points classified as solid
    "plot_interval": 200,            # epochs between log lines / plots / CSV dumps
    "save_path": "./JEB_PINN_val_final",  # output directory for checkpoints and figures
    "seed": 43,                      # passed to set_random_seed

    "resume": False,                 # set False to force a fresh run
    "save_every_epochs": 200,        # checkpoint cadence (in epochs)
}


In [None]:

# Load the FEM reference (None if the workbook is missing).
JEB_FEM_ref = load_JEB_FEM_reference()
# Take the SDF sign convention from the training configuration so the FEM
# inside-mask and the training density cannot silently disagree.
SDF_INSIDE_IS_NEGATIVE = bool(training_hparams["sdf_inside_is_negative"])
JEB_FEM_ref = fem_inside_mask_from_sdf(JEB_FEM_ref, sdf_inside_is_negative=SDF_INSIDE_IS_NEGATIVE)

# Training loop 
def train_pinn_for_fixed_geometry():
    """Train the three displacement PINNs (u, v, w) on the fixed geometry.

    All settings come from the module-level ``training_hparams``,
    ``hparams_model`` and ``hparams_feature_expansion`` dictionaries. A new
    ``FixedGeometryDataset`` is sampled every epoch. After each epoch the L2
    errors against ``JEB_FEM_ref`` are recorded, the best model (lowest mean
    epoch loss) is saved, and a rolling checkpoint is written every
    ``save_every_epochs`` / ``plot_interval`` epochs. If a rolling checkpoint
    exists and ``resume`` is truthy, training continues from it.

    Returns:
        tuple: ``(best_model_path, history)`` where ``history`` is a dict of
        per-epoch lists with keys ``epoch``, ``loss``, ``L2_u``, ``L2_v``,
        ``L2_w`` and ``L2_s``.

    Raises:
        KeyboardInterrupt: re-raised after the rolling checkpoint was written.
    """
    set_random_seed(training_hparams["seed"])

    save_dir = training_hparams["save_path"]
    os.makedirs(save_dir, exist_ok=True)

    # --- checkpoint paths ---
    latest_ckpt_path = os.path.join(save_dir, "checkpoint_latest.pth")
    best_model_path = os.path.join(save_dir, "best_model_optionA_fixed.pth")

    # --- models ---
    def _build_displacement_model():
        """Create the PINN for one displacement component on the compute device."""
        return PINN(
            hparams_model["SIREN_hparams"],
            hparams_feature_expansion,
            training_hparams["mollifier_alpha"],
        ).to(device)

    u_model = _build_displacement_model()
    v_model = _build_displacement_model()
    w_model = _build_displacement_model()

    params = list(u_model.parameters()) + list(v_model.parameters()) + list(w_model.parameters())

    optimizer = torch.optim.Adam(
        params=params,
        lr=training_hparams["learning_rate"],
        weight_decay=training_hparams["weight_decay"],
    )

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        # StepLR needs step_size >= 1; int(num_epochs / 4) would be 0 for runs
        # shorter than four epochs.
        step_size=max(1, int(training_hparams["num_epochs"] / 4)),
        gamma=training_hparams["gamma"],
    )

    loss_class_disp = PINN_Loss(u_model, v_model, w_model, training_hparams)

    # --- resume training logic ---
    start_epoch = 0
    best_loss = float("inf")
    history = {"epoch": [], "loss": [], "L2_u": [], "L2_v": [], "L2_w": [], "L2_s": []}

    if training_hparams.get("resume", True) and os.path.exists(latest_ckpt_path):
        print(f"[RESUME] Loading checkpoint: {latest_ckpt_path}")
        last_epoch, best_loss, history = load_checkpoint(
            latest_ckpt_path,
            u_model, v_model, w_model,
            optimizer, scheduler,
            map_location=device
        )
        start_epoch = last_epoch + 1
        print(f"[RESUME] Resuming from epoch {start_epoch} | best_loss={best_loss:.6e}")
    else:
        print("[RESUME] Starting fresh training run.")

    def _save_latest(epoch_to_record):
        """Write the rolling checkpoint using the current best_loss and history."""
        save_checkpoint(
            latest_ckpt_path,
            epoch=epoch_to_record,
            u_model=u_model,
            v_model=v_model,
            w_model=w_model,
            optimizer=optimizer,
            scheduler=scheduler,
            best_loss=best_loss,
            history=history,
        )

    save_every = int(training_hparams.get("save_every_epochs", 10))
    plot_interval = int(training_hparams["plot_interval"])

    # Anomaly detection is a process-wide autograd switch that slows execution
    # (the PyTorch docs recommend it for debugging only). It stays enabled
    # during training as before, but the previous state is restored afterwards
    # so that later cells are not affected.
    anomaly_was_enabled = torch.is_anomaly_enabled()
    torch.autograd.set_detect_anomaly(True)

    try:
        for epoch in trange(start_epoch, training_hparams["num_epochs"]):
            # Fresh collocation points every epoch.
            dataset = FixedGeometryDataset(training_hparams["total_points"], training_hparams)

            if (epoch % plot_interval) == 0:
                print("[dataset]", dataset.stats)

            dataloader = DataLoader(
                dataset,
                batch_size=training_hparams["batch_size"],
                shuffle=True,
                drop_last=False,
            )

            epoch_loss = 0.0

            for coords_batch, density_batch in dataloader:
                optimizer.zero_grad(set_to_none=True)
                loss_ritz = loss_class_disp.ritz_loss(coords_batch, density_batch)
                loss_ritz.backward()
                optimizer.step()
                epoch_loss += loss_ritz.item()

            scheduler.step()
            # Mean loss over the batches; the max() guards against an empty loader.
            epoch_loss /= max(1, len(dataloader))

            # --- metrics ---
            L2_u, L2_v, L2_w, L2_s = compute_L2_errors_3d(
                u_model, v_model, w_model, material_properties, JEB_FEM_ref, JEB
            )

            history["epoch"].append(epoch)
            history["loss"].append(epoch_loss)
            # Missing metrics (None) are recorded as 0.0.
            for key, value in (("L2_u", L2_u), ("L2_v", L2_v), ("L2_w", L2_w), ("L2_s", L2_s)):
                history[key].append(0.0 if value is None else value)

            # --- save best model ---
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                atomic_torch_save(
                    {
                        "epoch": epoch,
                        "loss": best_loss,
                        "u_model_state_dict": u_model.state_dict(),
                        "v_model_state_dict": v_model.state_dict(),
                        "w_model_state_dict": w_model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                    },
                    best_model_path,
                )

            if (epoch % save_every) == 0 or (epoch % plot_interval) == 0:
                _save_latest(epoch)

            if (epoch % plot_interval) == 0:
                print(
                    f"Epoch {epoch}, loss={epoch_loss:.6e}, "
                    f"lr={scheduler.optimizer.param_groups[0]['lr']:.3e}, "
                    f"best_loss={best_loss:.6e}, "
                    f"L2_u={history['L2_u'][-1]:.3e}, "
                    f"L2_v={history['L2_v'][-1]:.3e}, "
                    f"L2_w={history['L2_w'][-1]:.3e}, "
                    f"L2_s={history['L2_s'][-1]:.3e}"
                )

                plot_results(material_properties, u_model, v_model, w_model, JEB_FEM_ref, save_dir, epoch,JEB)

                fname_curves_lin = os.path.join(save_dir, f"training_curves_3d_{epoch:06d}.png")
                fname_curves_log = os.path.join(save_dir, f"training_curves_3d_log_{epoch:06d}.png")
                save_training_curves_3d(history, fname_curves_lin)
                save_training_curves_log_3d(history, fname_curves_log)
                save_metrics_csv_3d(history, save_dir)

        # final save
        _save_latest(training_hparams["num_epochs"] - 1)

    except KeyboardInterrupt:
        print("\n[INTERRUPT] Caught KeyboardInterrupt — saving latest checkpoint before exit.")
        # Record the last fully logged epoch so a resume restarts right after it.
        safe_epoch = history["epoch"][-1] if len(history["epoch"]) else (start_epoch - 1)
        _save_latest(safe_epoch)
        raise

    finally:
        torch.autograd.set_detect_anomaly(anomaly_was_enabled)

    print("Training finished.")
    print("Best loss:", best_loss)
    print("Best model path:", best_model_path)
    print("Latest checkpoint:", latest_ckpt_path)

    return best_model_path, history


In [None]:

# Per-run overrides of the training configuration, then launch training.
training_hparams["resume"] = True             # set False to force a fresh run
training_hparams["save_every_epochs"] = 200   # checkpoint cadence (in epochs)

best_model_path, history = train_pinn_for_fixed_geometry()