In [1]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path
import pyvista as pv
import pint
from typing import List

In [2]:
# Project root; assumes the notebook is run from a direct sub-folder of it (TODO confirm).
ROOT = Path.cwd().parent
PARAMETER_SPACE = "10"   # sub-folder of ROOT / "data" holding this parameter space
DATA_TYPE = "Training"   # selects the f"{DATA_TYPE}Mapped" / f"{DATA_TYPE}Original" data folders
time_step_idx = -1       # index applied after the snapshot index when reading entropy values (-1 = last entry)
# NOTE(review): `ureg` is not used by any of the visible cells.
ureg = pint.get_application_registry()

In [10]:
# Folder holding one sub-folder per mapped (control) mesh.
mapped_root = ROOT / "data" / PARAMETER_SPACE / f"{DATA_TYPE}Mapped"
assert mapped_root.is_dir(), f"Mapped data folder not found: {mapped_root}"
# iterdir() yields entries in a filesystem-dependent order; sort so the processing order
# (and the insertion order of the dict built from it) is reproducible across machines.
mapped_folders: List[Path] = sorted(folder for folder in mapped_root.iterdir() if folder.is_dir())

In [11]:
def _load_mapped_case(folder: Path):
    """Load the entropy export and mesh statistics of one mapped-mesh folder.

    Parameters
    ----------
    folder : Path
        A sub-folder of ``mapped_root`` (one per mapped/control mesh).

    Returns
    -------
    dict or None
        ``{"entropy_num", "n_cells", "cell_size", "n_points"}`` for the folder, or ``None``
        (after printing a notice) when the folder has no entropy export.

    Raises
    ------
    FileNotFoundError
        If the folder contains no ``*.vt*`` mesh file.
    AssertionError
        If the mesh cells do not all share one volume (a single cell size would be meaningless).
    """
    # NOTE(review): assumes export files are prefixed with DATA_TYPE, like the data folders.
    export_path = folder / "Exports" / f"{DATA_TYPE}_entropy_gen_number_therm.npy"
    if not export_path.exists():
        print(f"No entropy export found in {folder}")
        return None

    # Sorted so the chosen mesh file does not depend on filesystem iteration order.
    vtk_files = sorted(folder.rglob("*.vt*"))
    if not vtk_files:
        raise FileNotFoundError(f"No VTK mesh file (*.vt*) found in {folder}")
    mesh = pv.read(vtk_files[0])

    volumes = np.asarray(mesh.compute_cell_sizes().cell_data["Volume"])
    # Uniformity check with a tight tolerance instead of exact float uniqueness,
    # which breaks on round-off differences between otherwise identical cells.
    assert volumes.size > 0 and np.allclose(volumes, volumes[0], rtol=1e-9, atol=0.0), (
        f"Cells of {vtk_files[0]} do not share a single uniform volume"
    )
    return {
        "entropy_num": np.load(export_path),
        "n_cells": mesh.n_cells,
        # cbrt(volume) is the edge length only for cubic cells, as assumed here.
        "cell_size": np.cbrt(volumes[0]),
        "n_points": mesh.n_points,
    }


mapped_dic = {}
for folder in mapped_folders:
    case = _load_mapped_case(folder)
    if case is not None:
        # `.name`, not `.stem`: a directory name containing a dot must not be truncated.
        mapped_dic[folder.name] = case

In [5]:
# Reference results on the original Comsol meshes (one .vtu file per snapshot).
original_root = ROOT / "data" / PARAMETER_SPACE / f"{DATA_TYPE}Original"
assert original_root.is_dir(), f"Original data folder not found: {original_root}"
# NOTE(review): assumes the export file is prefixed with DATA_TYPE, like the data folder.
entropy_num_org = np.load(original_root / f"{DATA_TYPE}_entropy_gen_number_therm.npy")

# Sorted so that snapshot i is paired with row i of entropy_num_org
# (assumes the file names sort in snapshot order — TODO confirm).
vtu_files = sorted(original_root.rglob("*.vtu"))
N_SNAPS = len(vtu_files)
assert N_SNAPS > 0, f"No .vtu snapshots found under {original_root}"

# Cell/point counts are integers, so store them as such.
n_cells_org = np.zeros(N_SNAPS, dtype=int)
n_points_org = np.zeros_like(n_cells_org)
for idx, file in enumerate(vtu_files):
    mesh = pv.read(file)
    n_cells_org[idx] = mesh.n_cells
    n_points_org[idx] = mesh.n_points

## Plot

In [6]:
# One line colour per snapshot, sampled evenly from the "turbo" scale on [0, 1).
sample_points = [snap / N_SNAPS for snap in range(N_SNAPS)]
colors = px.colors.sample_colorscale("turbo", sample_points)
print(N_SNAPS)

200


In [12]:
# Per-snapshot comparison of every mapped mesh against that snapshot's original Comsol mesh.
assert N_SNAPS > 0, "No snapshots to compare (N_SNAPS == 0)"
fig = go.Figure()
plot_metric = "entropy_num"  # key in mapped_dic[...] holding the per-snapshot values

# Mesh metadata is identical for all snapshots: sort the mapped meshes by cell count once.
mapped_names = list(mapped_dic)
mapped_n_cells = np.array([mapped_dic[name]["n_cells"] for name in mapped_names])
mapped_cell_sizes = np.array([mapped_dic[name]["cell_size"] for name in mapped_names])
sort_order = np.argsort(mapped_n_cells)
mapped_names = [mapped_names[i] for i in sort_order]
mapped_n_cells = mapped_n_cells[sort_order]
mapped_cell_sizes = mapped_cell_sizes[sort_order]

# Squared error of each mapped mesh w.r.t. the original, in the fixed `mapped_names` column order.
mapped_squared_errors = np.zeros((N_SNAPS, len(mapped_names)))

for idx_snap in range(N_SNAPS):
    mapped_entropies = np.array(
        [mapped_dic[name][plot_metric][idx_snap][time_step_idx] for name in mapped_names]
    )
    original_n_cells = n_cells_org[idx_snap]
    original_entropy_num = entropy_num_org[idx_snap][time_step_idx]
    mapped_squared_errors[idx_snap, :] = (mapped_entropies - original_entropy_num) ** 2

    # Insert the original mesh at its cell-count position for this snapshot. That position can
    # change from snapshot to snapshot because each snapshot has its own original mesh.
    idx_org = np.searchsorted(mapped_n_cells, original_n_cells)
    n_cells = np.insert(mapped_n_cells, idx_org, original_n_cells)
    cell_sizes = np.insert(mapped_cell_sizes, idx_org, np.nan)
    entropies = np.insert(mapped_entropies, idx_org, original_entropy_num)
    mapped_meshes_sorted = np.array(
        mapped_names[:idx_org] + ["Original Comsol Mesh"] + mapped_names[idx_org:], dtype=object
    )

    rel_error = (entropies - original_entropy_num) / original_entropy_num
    idx_snap_array = np.ones_like(rel_error) * idx_snap
    # customdata columns: 0 = mesh name, 1 = relative error [%], 2 = snapshot index -> shape (n_meshes, 3)
    custom_data = np.column_stack((mapped_meshes_sorted, rel_error * 100, idx_snap_array))

    fig.add_trace(go.Scatter(
        x=n_cells,
        y=entropies,
        mode='markers+lines',
        name=f"{idx_snap:03d}",
        opacity=0.4,
        customdata=custom_data,
        hovertemplate=(
            'Cells: %{x}<br>'
            'Entropy number: %{y}<br>'
            'Snap Idx: %{customdata[2]:03d}<br>'
            'Mapped: %{customdata[0]}<br>'
            'Rel Error: %{customdata[1]:.3f}%<extra></extra>'
        ),
        line=dict(color=colors[idx_snap]),
    ))

# The MSE cells below label the columns of `squarred_errors` with the last snapshot's
# `mapped_meshes_sorted` / `idx_org`. Mapped-mesh columns therefore keep a fixed order,
# and the original-mesh column (error zero by construction) is inserted at that last
# `idx_org`. This stops errors of different meshes from being mixed in one column.
squarred_errors = np.insert(mapped_squared_errors, idx_org, 0.0, axis=1)

fig.update_layout(
    showlegend=False,
    title=f"Entropy generation number vs n_cells - {DATA_TYPE} Snapshots - Parameter Space {PARAMETER_SPACE}",
    xaxis_title="Total cells (min. cell size)",
    yaxis_title=plot_metric,
    xaxis=dict(
        type='log',
        tickformat='.1e',  # scientific notation with one decimal
        # Ticks only at the mapped meshes, labelled "<n_cells> (<cell size> m)".
        tickvals=mapped_n_cells,
        ticktext=[f"{n_cell:.2e} ({cell_size:.0f} m)"
                  for n_cell, cell_size in zip(mapped_n_cells, mapped_cell_sizes)],
    ),
    yaxis=dict(
        tickformat='.2f',  # two decimal places
    ),
)

fig.write_image(mapped_root / f"{DATA_TYPE}_Diff_{plot_metric}.png")
fig.write_html(mapped_root / f"{DATA_TYPE}_Diff_{plot_metric}.html")
fig.show()

### Errors

In [13]:
# Mean over snapshots of the squared error; one entry per column of `squarred_errors`.
mses = squarred_errors.mean(axis=0)
meses_dic = dict(zip(mapped_meshes_sorted, mses))
for mesh_name, mse in meses_dic.items():
    print(f"{mesh_name} - MSE : {mse:>10.1e}")

s100_100_100_b0_4000_0_5000_-4000_0 - MSE :    1.2e-03
Original Comsol Mesh - MSE :    0.0e+00
s50_50_50_b0_4000_0_5000_-4000_0 - MSE :    4.0e-04


In [14]:
# MSE of each mapped (control) mesh versus its cell size. The original-mesh column is
# dropped because its error is zero by construction.
fig = go.Figure()
fig.add_trace(go.Scatter(
    x=np.delete(cell_sizes, idx_org),
    y=np.delete(mses, idx_org),
    mode="lines+markers",  # markers make the individual control meshes visible
))

fig.update_layout(
    title=f"PS {PARAMETER_SPACE} - Entropy Generation Number MSE {DATA_TYPE}",
    xaxis_title="Control mesh cell size [m]",
    yaxis_title="Mean squared error",
    yaxis_type="log",
    yaxis_tickformat=".1e",
)

fig.write_image(mapped_root / "ControlMesh_MSEs.png")
fig.show()