# Testing the regression model on an unseen dataset

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from models.utils import identify_device
from models.regression_utils import test_regression_model, load_regression_dataset, print_metrics_table, denormalize_material_params, normalize_material_params
from models.Regression_models import CNN1D_Regressor
from models.utils import display_model
from torch.nn import MSELoss

ImportError: cannot import name 'CNN1D_Regressor' from 'models.Regression_models' (c:\Users\VECSEL\Documents\Theo\ML\THz-TD-CNN\models\Regression_models.py)

In [None]:
# Configuration: which trained weights to load and which held-out dataset to evaluate on.
device = torch.device('cpu')
model = CNN1D_Regressor()
trained_model_name = 'regression_model_train_on_0k_clean_lr0_001_epochs_200'
unseen_dataset_name = 'unseen_3_layer_nonoise_n1to8'
training_dataset = 'train_3_layer_nonoise_n1to8'  # NOTE(review): not used in the visible cells; kept in case other cells reference it


dataset = load_regression_dataset(f'regression_data/{unseen_dataset_name}.pt')  # loads data and normalizes targets in dataset
unseen_loader = DataLoader(dataset, batch_size=1024)

# Map the checkpoint tensors onto `device` rather than a second hard-coded CPU device,
# so editing `device` above is the only change needed to switch hardware.
state_dict = torch.load(f'trained_models/{trained_model_name}.pth', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Switch dropout/batch-norm layers (if any) to inference behaviour now: the metrics
# cell runs before the explicit model.eval() in the single-example cell.
model.eval()

display_model(model, device=device)

In [None]:
# Run the model over the whole unseen loader, then tabulate the metrics
# for a 3-layer structure.
num_layers = 3
evaluation = test_regression_model(model, unseen_loader, device)
metrics, results = evaluation
print_metrics_table(metrics, num_layers=num_layers)

In [None]:
# Inspect one sample from the unseen dataset: compare true vs. predicted parameters.
index = 1
sample = dataset[index]  # fetch once; element 0 is the input pulse, element 1 the targets
test_example_pulse = sample[0]
test_example_true_vals = sample[1]

model.eval()
with torch.no_grad():
    # unsqueeze(0) adds a batch dimension of size 1; [0] below strips it from the output.
    # NOTE(review): assumes the model accepts a single sample's shape with a leading batch dim — confirm
    pulse = test_example_pulse.unsqueeze(0).to(device)
    pred = model(pulse)[0].cpu()  # back on CPU for printing / denormalization

# Targets are normalized by load_regression_dataset; the model output is presumably on
# the same normalized scale, so both are shown raw and then denormalized the same way.
print("True (normalized):        ", test_example_true_vals)
print("Predicted (normalized):   ", pred)
print("True (denormalized):      ", denormalize_material_params(test_example_true_vals))
print("Predicted (denormalized): ", denormalize_material_params(pred))

In [None]:
# Parity plots: predicted vs. target value for every material parameter.
# Which to plot: change keys here
preds = results['preds_unscaled'].clone()
targets = results['targets_unscaled'].clone()

# Parameter labels (one per output column) and types
param_labels = [
    r"$n_1$", r"$k_1$", r"$d_1$ [$\mu$m]",
    r"$n_2$", r"$k_2$", r"$d_2$ [$\mu$m]",
    r"$n_3$", r"$k_3$", r"$d_3$ [$\mu$m]",
]
thickness_indices = [2, 5, 8]  # positions of d1, d2, d3

# Plot from detached CPU NumPy arrays so CUDA / grad-tracking tensors also work.
# .numpy() shares memory with the tensor, so .copy() keeps the unit conversion
# below from silently rescaling `preds` / `targets` as well.
preds_plot = preds.detach().cpu().numpy().copy()
targets_plot = targets.detach().cpu().numpy().copy()

n_params = len(param_labels)
for name, arr in (("preds", preds_plot), ("targets", targets_plot)):
    if arr.ndim != 2 or arr.shape[1] != n_params:
        raise ValueError(
            f"Expected {name} of shape (N, {n_params}), got {tuple(arr.shape)}"
        )

# Convert thickness values from m → μm for plotting
preds_plot[:, thickness_indices] *= 1e6
targets_plot[:, thickness_indices] *= 1e6

# Create figure: one panel per parameter on a 3x3 grid
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
axes = axes.flatten()

for i, (ax, label) in enumerate(zip(axes, param_labels)):
    x = targets_plot[:, i]
    y = preds_plot[:, i]
    sns.scatterplot(
        x=x,
        y=y,
        ax=ax,
        s=5, color='blue', alpha=0.6, edgecolor=None,
    )

    # 45° line marks perfect prediction across the joint data range
    lo = min(x.min(), y.min())
    hi = max(x.max(), y.max())
    ax.plot([lo, hi], [lo, hi], 'r--', lw=1)

    ax.set_xlabel(f"Target {label}")
    ax.set_ylabel(f"Predicted {label}")
    ax.set_title(label)
    ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()