In [3]:
import numpy as np
from pathlib import Path
from helpers import inverse_min_max_scaler
import pyvista as pv

In [4]:
import sys
# Make the directory two levels above the current working directory importable
# before importing from ComsolClasses.
# NOTE(review): presumably that directory contains the ComsolClasses package — confirm.
sys.path.append(str(Path().cwd().parent.parent))
from ComsolClasses.comsol_classes import COMSOL_VTU
from ComsolClasses.helper import calculate_normal

In [5]:
# Identifier of the snapshot set; it is part of the data path and of the plot title.
VERSION = "01"
# Data folder: <parent of cwd>/Snapshots/<VERSION>/BasisFunctions
ROOT = Path.cwd().parent.joinpath("Snapshots", VERSION, "BasisFunctions")
# Last expression of the cell: displays whether the folder is present on disk.
ROOT.exists()

True

### Import

In [6]:
# Read the three arrays stored in the ROOT folder.
# min_max: its first two entries are handed to inverse_min_max_scaler further down.
min_max = np.load(ROOT.joinpath("min_max.npy"))
# basis_functions: indexed by basis function along the first axis in the plot cell.
basis_functions = np.load(ROOT.joinpath("basis_fts_matrix.npy"))
# information_content: one value per basis function, shown as a percentage in the plot labels.
information_content = np.load(ROOT.joinpath("information_content.npy"))

In [7]:
# Map the basis functions back through the inverse min-max scaling.
# NOTE(review): min_max[0] / min_max[1] are presumably the stored minimum and
# maximum — confirm against helpers.inverse_min_max_scaler.
basis_functions_scaled = inverse_min_max_scaler(basis_functions, min_max[0], min_max[1])

In [8]:
# Load one training snapshot; its mesh is the geometry the basis functions are
# attached to and plotted on.
# The path is derived from ROOT (which already contains VERSION) instead of a
# hard-coded absolute user path, so the notebook runs on other machines and
# follows VERSION automatically. Passed as str, like the original literal.
comsol_data = COMSOL_VTU(str(ROOT.parent / "Training" / "Training_000.vtu"))

In [9]:
# Vector whose negation is used as the clipping-plane normal in the plot cell.
# NOTE(review): the meaning of the arguments 60 and 90 (presumably two angles in
# degrees) is not visible here — confirm in ComsolClasses.helper.calculate_normal.
normal = calculate_normal(60, 90)

### Plot

In [10]:
# Draw every basis function on a clipped copy of the COMSOL mesh, one subplot per
# basis function, and save a screenshot of the full grid next to the input data.
n_cols = 2  # number of subplot columns
n_basis = basis_functions.shape[0]
n_rows = int(np.ceil(n_basis / n_cols))  # rows needed to fit every basis function

plotter = pv.Plotter(
    shape=(n_rows, n_cols),
    window_size=(1300, 1200),
    title=f"Basis functions - Parameter Space {VERSION}",
)

# Loop-invariant clip settings: same plane for every subplot.
clip_normal = -np.array(normal)
clip_origin = comsol_data.mesh.center

for counter in range(n_basis):
    # Fill the grid row by row; a trailing cell stays empty when n_basis is odd.
    i, j = divmod(counter, n_cols)
    plotter.subplot(i, j)

    # One uniquely named field per basis function. Naming by the row index would
    # make both columns of a row overwrite the same point-data array.
    field_name = f"basis_function{counter}"
    comsol_data.mesh.point_data[field_name] = basis_functions_scaled[counter, :]
    clipped = comsol_data.mesh.clip(normal=clip_normal, origin=clip_origin)

    plotter.add_mesh(
        clipped,
        scalars=field_name,
        cmap='jet',
        # The counter keeps each scalar-bar title distinct across subplots.
        scalar_bar_args={
            'title': f'Temperature [K] ({counter})',
            'label_font_size': 10,
            'title_font_size': 8,
        },
    )

    text_string = f"Basis Function {counter + 1}"
    try:
        text_string = text_string + f"({information_content[counter] * 100:.2f} %)"
    except (NameError, IndexError):
        # Best effort: omit the percentage when information_content was not
        # loaded or holds fewer entries than there are basis functions.
        pass
    plotter.add_text(text_string)

    plotter.add_axes(line_width=1.)
    plotter.add_bounding_box()
    plotter.show_grid(
        font_size=6,
        n_xlabels=3,  # number of labels (ticks) on x-axis
        n_ylabels=3,  # number of labels (ticks) on y-axis
        n_zlabels=3,  # number of labels (ticks) on z-axis
        color='gray',
        xtitle='',
        ytitle='',
        ztitle='',
    )

plotter.show(screenshot=ROOT / "basis_functions.png")

Widget(value='<iframe src="http://localhost:61220/index.html?ui=P_0x1236b2ed0_0&reconnect=auto" class="pyvista…