In [1]:
import numpy as np
import torch
import random 
import plotly.express as px
from pathlib import Path
import pyvista as pv
import sys
import pandas as pd
import pint

In [2]:
# Make the project root importable so the `scr` package can be found.
# Guarded so that re-running this cell does not keep appending duplicate entries.
_project_root = str(Path.cwd().parent)
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [3]:
from scr.offline_stage import NirbModule
from scr.comsol_module.comsol_classes import COMSOL_VTU
from scr.comsol_module.helper import calculate_normal
from scr.utils import load_pint_data, inverse_min_max_scaler, standardize, safe_parse_quantity

In [4]:
# Parameter space to evaluate (sub-folder of ../data).
PARAMETER_SPACE = "01"
# Tag formatted as "{:.1e}" into the basis-function and NN-log file names below;
# presumably the accuracy/tolerance used in the offline stage — TODO confirm.
ACCURACY = 1e-4
ROOT = Path.cwd().parent.joinpath("data", PARAMETER_SPACE)
ROOT.exists()

True

In [5]:
# Reference mesh (first training snapshot); its mesh is reused for every plot below.
comsol_data = COMSOL_VTU(ROOT / "Training" / "Training_000.vtu")

### Load checkpoints from different training runs

In [6]:
def _run_sort_key(path):
    """Sort key for a checkpoint path: its run directory (two levels up), numerically aware.

    A trailing ``_<number>`` in the run directory name is compared as an integer, so
    ``version_10`` sorts after ``version_2`` (plain string sorting would not). Names
    without such a numeric suffix fall back to plain string order.
    """
    run_name = path.parent.parent.stem
    prefix, _, suffix = run_name.rpartition("_")
    return (prefix, int(suffix)) if suffix.isdecimal() else (run_name, -1)


nn_log_dir = ROOT / f"nn_logs_{ACCURACY:.1e}"
chk_pt_paths = sorted(nn_log_dir.rglob("*.ckpt"), key=_run_sort_key)
# Fail here with a clear message instead of an IndexError when a checkpoint is selected.
assert chk_pt_paths, f"No checkpoints found under {nn_log_dir}"

# `idx` is kept as the loop name: later cells in this notebook read it.
for idx, path in enumerate(chk_pt_paths):
    print(f"{idx:02d} ({path.parent.parent.stem}): {path.name}")

00 (version_0): epoch=149999-step=300000.ckpt


In [7]:
# Optional debugging aid: inspect the raw contents (state-dict keys/shapes and
# hyper-parameters) of a checkpoint file.
# NOTE(review): the listing above shows a single checkpoint (index 00), so
# chk_pt_paths[10] would raise IndexError — adjust the index before uncommenting.
# ckpt = torch.load(chk_pt_paths[10], map_location="cpu")
# for key in ckpt['state_dict'].keys():
#     print(key)
# for k, v in ckpt['state_dict'].items():
#     print(k, v.shape)
# print(ckpt.get("hyper_parameters", {}))
# print(ckpt.get("hyper_parameters", {}))

### Load checkpoint

In [8]:
# Select checkpoint by its position in the sorted listing printed above.
CHECKPOINT_INDEX = 0
chk_pt_path = chk_pt_paths[CHECKPOINT_INDEX]

In [9]:
# map_location="cpu" lets a checkpoint saved from a GPU run be restored on a CPU-only
# machine (without it, deserialization fails before .to() is ever reached);
# .to("cpu") additionally guarantees the module ends up on the CPU.
trained_model = NirbModule.load_from_checkpoint(chk_pt_path, map_location="cpu")
trained_model = trained_model.to("cpu")
trained_model.eval()

NirbModule(
  (loss): MSELoss()
  (activation): Sigmoid()
  (model): NIRB_NN(
    (layers): Sequential(
      (0): Linear(in_features=2, out_features=24, bias=True)
      (1): Sigmoid()
      (2): Linear(in_features=24, out_features=193, bias=True)
      (3): Sigmoid()
      (4): Linear(in_features=193, out_features=193, bias=True)
      (5): Sigmoid()
      (6): Linear(in_features=193, out_features=101, bias=True)
      (7): Sigmoid()
      (8): Linear(in_features=101, out_features=96, bias=True)
      (9): Sigmoid()
      (10): Linear(in_features=96, out_features=11, bias=True)
    )
  )
  (msa_metric): MeanAbsoluteError()
)

In [10]:
# Offline-stage artefacts: the basis matrix and the min/max needed to undo the
# snapshot scaling (see inverse_min_max_scaler in the plotting cell).
basis_dir = ROOT / "BasisFunctions"
basis_functions = np.load(basis_dir / f"basis_fts_matrix_{ACCURACY:.1e}.npy")
min_max = np.load(basis_dir / "min_max.npy")

# Full-order snapshots and the parameters they were computed for.
training_snapshots_npy = np.load(ROOT / "Exports" / "Training_temperatures.npy")
training_parameters = load_pint_data(ROOT / "training_samples.csv", is_numpy=True)
test_snapshots_npy = np.load(ROOT / "Test" / "Test_temperatures.npy")
test_parameters = load_pint_data(ROOT / "test_samples.csv", is_numpy=True)
# Same test parameters again, keeping their units (used for the plot annotations).
test_parameters_pint = load_pint_data(ROOT / "test_samples.csv")

In [11]:
# Per-sample parameter CSVs of the test set (file stem contains "test", case-insensitive).
# NOTE(review): later cells index this list with the same index as the test snapshots,
# which assumes the sorted file names follow snapshot order — confirm.
param_folder = ROOT / "Exports"
param_files = sorted(
    csv_file
    for csv_file in param_folder.rglob("*.csv")
    if "test" in csv_file.stem.lower()
)
assert param_files

In [12]:
# Prepare data
# Standardize the test inputs with the statistics of the *training* parameters: the
# network must see inputs in the same coordinates it was trained on, whereas mean/var
# estimated from the test set itself would shift and rescale them differently.
# NOTE(review): assumes the offline stage standardized with the training mean/var — confirm.
mean = np.mean(training_parameters, axis=0)
var = np.var(training_parameters, axis=0)
test_parameters_scaled = standardize(test_parameters, mean, var)
training_snapshots = training_snapshots_npy[:, -1, :]  # last time step
test_snapshots = test_snapshots_npy[:, -1, :]  # last time step

### Prepare Plot

In [13]:
N = 5  # number of test samples to visualize
random.seed(3123)  # reproducible choice of samples
samples = random.sample(range(len(test_snapshots)), N)
# One colour per sample, evenly spaced on the "thermal" scale. max(..., 1) avoids a
# ZeroDivisionError for N == 1; the values for N >= 2 are unchanged.
colors = px.colors.sample_colorscale("thermal", [n / max(N - 1, 1) for n in range(N)])

### Plot NIRB prediction vs. original (test samples)

In [14]:
# Initialize the figure
n_cols = 2  # number of columns: NIRB prediction | original snapshot
plotter = pv.Plotter(shape=(N, n_cols), window_size=(1300, 1200))
counter = 0  # running panel number, makes every scalar-bar title unique
differences = np.zeros((N, len(comsol_data.mesh.points)))
# A single registry for the whole cell (pint recommends one registry per application).
ureg = pint.UnitRegistry()


def _clip_normal(param_file, ureg):
    """Return the clip-plane normal computed from the 'dip'/'strike' rows of ``param_file``.

    The CSV is read with its first column as index; its last column holds quantity
    strings that are parsed with ``safe_parse_quantity`` into pint quantities.
    """
    param_df = pd.read_csv(param_file, index_col=0)
    quantities = param_df[param_df.columns[-1]].apply(lambda x: safe_parse_quantity(x, ureg))
    return calculate_normal(
        quantities.loc["dip"].to("deg").magnitude,
        quantities.loc["strike"].to("deg").magnitude,
    )


def _add_panel(plotter, mesh, scalars, normal, title, param_string, bar_title):
    """Clip ``mesh`` through its centre (plane normal ``-normal``) and draw ``scalars``
    in the plotter's currently active subplot, with title, parameter text and grid."""
    clipped = mesh.clip(normal=-np.array(normal), origin=mesh.center)
    plotter.add_mesh(
        clipped,
        scalars=scalars,
        cmap="jet",
        scalar_bar_args={"title": bar_title, "label_font_size": 10, "title_font_size": 8},
    )
    plotter.add_text(title, font_size=13)
    plotter.add_text(param_string, position="left_edge", font_size=9)
    plotter.add_axes(line_width=1.0)
    plotter.add_bounding_box()
    plotter.show_grid(
        font_size=6,
        n_xlabels=3,  # number of labels (ticks) on x-axis
        n_ylabels=3,  # number of labels (ticks) on y-axis
        n_zlabels=3,  # number of labels (ticks) on z-axis
        color="gray",
        xtitle="",
        ytitle="",
        ztitle="",
    )


# Loop through the samples and plot
for i, sample_idx in enumerate(samples):
    # Read *this* sample's parameter file. Indexing with the leftover `idx` from the
    # checkpoint-listing cell clipped every row with the first file's plane.
    normal = _clip_normal(param_files[sample_idx], ureg)
    param_string = "\n".join(
        f"{col} = {para.magnitude:.2e} {para.units:~P}"
        for col, para in test_parameters_pint.loc[sample_idx].items()
    )

    # Predict: NN output (coefficients) times the basis matrix, then undo the scaling.
    with torch.no_grad():
        param_t = torch.from_numpy(test_parameters_scaled[sample_idx].astype(np.float32)).view(1, -1)
        res_np = trained_model(param_t).numpy()
    my_sol = np.matmul(res_np.flatten(), basis_functions)
    my_sol = inverse_min_max_scaler(my_sol, min_max[0], min_max[1])

    # Left column: NIRB reconstruction.
    plotter.subplot(i, 0)
    counter += 1
    field_name = f"Test Sample {sample_idx}"
    comsol_data.mesh.point_data[field_name] = my_sol
    _add_panel(
        plotter,
        comsol_data.mesh,
        field_name,
        normal,
        f"NIRB (Sample {sample_idx:03d})",
        param_string,
        f"Temperature [K] ({counter})",
    )

    # Right column: full-order test snapshot.
    plotter.subplot(i, 1)
    counter += 1
    comsol_data.mesh.point_data[f"Test{sample_idx}"] = test_snapshots[sample_idx, :]
    _add_panel(
        plotter,
        comsol_data.mesh,
        f"Test{sample_idx}",
        normal,
        f"Original (Sample {sample_idx:03d})",
        param_string,
        f"Temperature [K] ({counter})",
    )

    differences[i, :] = test_snapshots[sample_idx, :] - my_sol

plotter.show(screenshot=ROOT / f"ComparisonNirbOriginal_{ACCURACY:.1e}_PS{PARAMETER_SPACE}_{chk_pt_path.parents[1].stem}.png")


Widget(value='<iframe src="http://localhost:62516/index.html?ui=P_0x3124bad80_0&reconnect=auto" class="pyvista…

### Difference plots (original minus NIRB prediction)

In [17]:
# Initialize the figure
plotter = pv.Plotter(shape=(N, 1), window_size=(1300, 1200))
counter = 0  # running panel number, makes every scalar-bar title unique
ureg = pint.UnitRegistry()
# Loop through the samples and plot
for i, sample_idx in enumerate(samples):
    # Clip with this sample's own plane. Previously the `normal` left over from the
    # last iteration of the comparison cell was reused for every row.
    param_df = pd.read_csv(param_files[sample_idx], index_col=0)
    quantities = param_df[param_df.columns[-1]].apply(lambda x: safe_parse_quantity(x, ureg))
    normal = calculate_normal(
        quantities.loc["dip"].to("deg").magnitude,
        quantities.loc["strike"].to("deg").magnitude,
    )

    plotter.subplot(i, 0)
    counter += 1
    param_string = "\n".join(
        f"{col} = {para.magnitude:.2e} {para.units:~P}"
        for col, para in test_parameters_pint.loc[sample_idx].items()
    )

    # differences[i] = original snapshot - NIRB reconstruction (computed in the comparison cell).
    comsol_data.mesh.point_data['diff'] = differences[i, :]
    clipped = comsol_data.mesh.clip(normal=-np.array(normal), origin=comsol_data.mesh.center)

    plotter.add_mesh(
        clipped,
        scalars='diff',
        cmap='jet',
        scalar_bar_args={'title': f'Diff Temperature [K] ({counter})',
                         'label_font_size': 10,
                         'title_font_size': 8},
    )
    plotter.add_text(f"Difference (Sample {sample_idx:03d})", font_size=13)
    plotter.add_text(param_string, position="left_edge", font_size=9)
    plotter.add_axes(line_width=1.)
    plotter.add_bounding_box()
    plotter.show_grid(
        font_size=6,
        n_xlabels=3,  # number of labels (ticks) on x-axis
        n_ylabels=3,  # number of labels (ticks) on y-axis
        n_zlabels=3,  # number of labels (ticks) on z-axis
        color='gray',
        xtitle='',
        ytitle='',
        ztitle='',
    )

plotter.show(screenshot=ROOT / f"ComparisonNirbOriginal_Diff_{ACCURACY:.1e}_PS_{PARAMETER_SPACE}_{chk_pt_path.parents[1].stem}.png")


Widget(value='<iframe src="http://localhost:61243/index.html?ui=P_0x3ae2b7b30_3&reconnect=auto" class="pyvista…