# 📊 Notebook 3: Error Analysis; Visualizing and Comparing Predicted Excited-State Surfaces

In this notebook, we analyze the machine learning model predictions and compare them against **quantum chemical (QM) reference data**. The focus is on understanding how well the models reproduce key features of the excited-state landscape of $\mathrm{CH_2NH_2^+}$.

### 📐 The Geometry Grid

The molecular geometries used here lie on a **two-dimensional grid** that systematically varies:

- The **C–N bond length** (elongation),
- The **torsion angle around the C–N bond** (rotation).

This scan is designed to probe a chemically relevant subspace of configurations where important photophysical and photochemical phenomena—such as **state crossings** and **conical intersections**—are expected to occur.
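As a minimal sketch of how a point on this scan can be addressed, the snippet below assumes the layout used by the later cells of this notebook: 101 bond lengths times 91 torsion angles, with flat arrays ordered like an array of shape `(101, 91)`. The helper name `grid_point` is hypothetical and only serves the illustration.

```python
import numpy as np

BOHR_TO_ANGSTROM = 0.529177249

# Grid axes as used later in this notebook: 101 bond lengths x 91 torsion angles.
bond_values = np.linspace(2.4321, 4.4321, 101) * BOHR_TO_ANGSTROM  # in Å
angle_values = np.linspace(0, 90, 91)                               # in degrees


def grid_point(flat_index: int) -> tuple[float, float]:
    """Map a flat index to (bond length, torsion angle).

    Assumption: flat arrays are stored in row-major order, i.e. like an
    array of shape (n_bond, n_angle) where the torsion angle varies fastest.
    """
    shape = (bond_values.size, angle_values.size)
    i_bond, j_angle = np.unravel_index(flat_index, shape)
    return float(bond_values[i_bond]), float(angle_values[j_angle])


# 51st bond length, last torsion angle
bond, angle = grid_point(91 * 50 + 90)
print(f"bond = {bond:.4f} Å, torsion = {angle:.0f}°")
```

The same mapping is what `reshape(101, 91)` applies implicitly when the flat model predictions are turned back into a 2D surface.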

### 📈 What We Do in This Notebook

- Load the **predicted energies** from both ML models (**SchNet** and **PaiNN**) and the **QM ground truth**,
- Render the **low-lying adiabatic states** (typically $S_0$, $S_1$, $S_2$) as 3D surfaces over the 2D scan,
- Visualize and compare how the models perform in capturing key features such as:
  - Relative energy ordering,
  - Topography of excited-state PESs,
  - Presence of avoided crossings or near-degeneracies.
- Visualize the test-set statistics for the two models.

This type of comparison provides crucial insight into **how faithfully machine-learned models approximate excited-state electronic structure**, particularly in regions where the surfaces are strongly coupled or non-trivial.

These visualizations also give us a first intuitive sense of **where a conical intersection might be located**—typically near regions where the $S_0$ and $S_1$ surfaces come very close in energy across a distorted geometry.



### Imports and Constants

In [None]:
import ase
from ase.db import connect
import nglview
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import ImageGrid

from spainn.interface import NacCalculator
from schnetpack.transform import MatScipyNeighborList

from tqdm import tqdm

from sklearn.metrics import r2_score
import scipy.stats  # explicit submodule import; needed for scipy.stats.pearsonr

In [None]:
BOHR_TO_ANGSTROM = 0.529177249
HARTREE_TO_meV = 27.2114 * 1e3

properties = {
    "energy" : {
        "vmin" : 0.0,
        "vmax" : 1000.0,
        "name" : "Energy",
        "label": ["$_1$","$_2$","$_3$"],
        "title": "Energies"
    },
    "nacs" : {
        "vmin" : -2.0,
        "vmax" : 2.0,
        "name" : "NACV",
        "label": ["$_{01}$","$_{02}$","$_{12}$"],
        "title": "NACs"
    },
}

In [None]:
def make_partity_plot(targ_dict: dict, 
                      pred_dict: dict):

    properties = ['E', 'F', 'NAC']
    nac_labels = ["$_{01}$","$_{02}$","$_{12}$"]
    
    labelsize     = 14
    titelsize     = labelsize * 1.2
    pad           = labelsize / 3
    tickwidth     = 3
    maj_tick_size = 6
    min_tick_size = 3
    
    units = {"E": r"meV / atom",
             "F": r"meV / Å",
             "NAC": r"meV / Å"}
    
    min_max = {"E" : np.array([1.e6, -1.e6]),
               "F": np.array([1.e6, -1.e6]),
               "NAC": np.array([1.e6, -1.e6])}
    
    fig, axs = plt.subplots(3, 3, figsize=(10,10), layout="constrained")
    for idx, prop in enumerate(properties):
        for state in range(3):
    
            if prop == 'E':
                conv_fac = HARTREE_TO_meV
                targ_min = targ_dict[prop][:, 0].min()
                targ = targ_dict[prop][:, state].flatten() 
                targ -= targ_min
                targ *= conv_fac / targ_dict['F'].shape[1]
                pred = pred_dict[prop][:, state].flatten() 
                pred -= targ_min
                pred *= conv_fac / targ_dict['F'].shape[1]
            else:
                conv_fac = HARTREE_TO_meV / BOHR_TO_ANGSTROM
                targ = targ_dict[prop][:, :, state, :] * conv_fac
                pred = pred_dict[prop][:, :, state, :] * conv_fac
                if prop == 'NAC':
                    # The overall NAC phase is arbitrary: choose one sign per
                    # configuration (not per component) that brings the predicted
                    # vector closest to the target.
                    flip = (np.linalg.norm(targ + pred, axis=(1, 2))
                            < np.linalg.norm(targ - pred, axis=(1, 2)))
                    pred = np.where(flip[:, None, None], -pred, pred)
                targ = targ.flatten()
                pred = pred.flatten()
    
            min_max[prop][0] = min(min_max[prop][0], targ.min(), pred.min())
            min_max[prop][1] = max(min_max[prop][1], targ.max(), pred.max())
    
            # get stats
            pearson_r, p = scipy.stats.pearsonr(targ, pred)
            R2 = r2_score(targ, pred)
            mae = abs(pred-targ).mean()
            rmse = np.sqrt(np.power(pred-targ, 2).mean())
    
            ax = axs[idx, state]
            # do plotting
            hb = ax.hexbin(targ, pred,
                               cmap='Blues',
                               gridsize=25,
                               mincnt=1,
                               bins="log",
                               edgecolors=None,
                               linewidths=(0.2,),
                               xscale="linear",
                               yscale="linear",
                               )
            
            ax.text(0.05, 0.9, 'MAE: %.2f' % (mae),
                   transform=ax.transAxes, fontsize=labelsize*0.9,
                   zorder=10)
            ax.text(0.05, 0.8, 'RMSE: %.2f' % (rmse),
                   transform=ax.transAxes, fontsize=labelsize*0.9,
                   zorder=10)
            ax.text(0.6, 0.2, r'$R^2$: %.2f' % (R2),
               transform=ax.transAxes, fontsize=labelsize*0.9,
               zorder=10)
            ax.text(0.6, 0.1, r'$\rho$: %.2f' % (pearson_r),
                   transform=ax.transAxes, fontsize=labelsize*0.9,
                   zorder=10)
    
        for jj,ax in enumerate(axs[idx]):
            if idx != 2:
                ax.set_xlim(min_max[prop]*1.1)
                ax.set_ylim(min_max[prop]*1.1)
                #ax.set_aspect('equal')
            ax.spines['bottom'].set_linewidth(tickwidth)
            ax.spines['top'].set_linewidth(tickwidth)
            ax.spines['left'].set_linewidth(tickwidth)
            ax.spines['right'].set_linewidth(tickwidth)
            ax.tick_params(axis='x', length=maj_tick_size, width=tickwidth,
                           labelsize=labelsize, pad=pad,
                           direction='in')
            if jj >0:
                ax.set_yticklabels([])
            ax.tick_params(axis='y', length=maj_tick_size, width=tickwidth,
                   labelsize=labelsize, pad=pad,
                   direction='in')
    
    # titles (set once, after all panels have been drawn)
    for col, ax in enumerate(axs[0]):
        ax.set_title("$E_{S_%i}$" % col, fontsize=titelsize, pad=2*pad)
    for col, ax in enumerate(axs[1]):
        ax.set_title(r"$\partial E_{S_%i} / \partial \mathbf{R}$" % col, fontsize=titelsize, pad=2*pad)
    for col, ax in enumerate(axs[2]):
        ax.set_title("NAC"+nac_labels[col], fontsize=titelsize, pad=2*pad)

    # axis labels
    for col, ax in enumerate(axs[0]):
        ax.set_xlabel(f"Target [{units['E']}]", fontsize=labelsize)
        if col == 0:
            ax.set_ylabel(f"Prediction [{units['E']}]", fontsize=labelsize)
    for col, ax in enumerate(axs[1]):
        ax.set_xlabel(f"Target [{units['F']}]", fontsize=labelsize)
        if col == 0:
            ax.set_ylabel(f"Prediction [{units['F']}]", fontsize=labelsize)
    for col, ax in enumerate(axs[2]):
        ax.set_xlabel(f"Target [{units['NAC']}]", fontsize=labelsize)
        if col == 0:
            ax.set_ylabel(f"Prediction [{units['NAC']}]", fontsize=labelsize)
    
    plt.show()

### ML vs. ML vs. QM for the 2D grid

In [None]:
bond_values = np.linspace(2.4321, 4.4321, 101) * BOHR_TO_ANGSTROM
angle_values = np.linspace(0, 90, 91)
X, Y = np.meshgrid(angle_values, bond_values)

In [None]:
preds_painn = np.load("Predictions_Painn.npz")
preds_schnet = np.load("Predictions_Schnet.npz")
targets     = np.load("groundtruth_grid.npz")
mask = targets['energy'][:,:,0] < -1 # True where the SCF converged; failed geometries hold a placeholder energy

In [None]:
fig, axs = plt.subplots(1, 3, subplot_kw={"projection": "3d"},
                        figsize=(15, 6), layout='constrained')

for state in range(3):
    Z = preds_painn['energy'][:, state]
    surf = axs[0].plot_trisurf(X.flatten(), Y.flatten(), Z.flatten(), cmap=cm.coolwarm,
                           linewidth=0, antialiased=False)

    Z = preds_schnet['energy'][:, state]
    surf = axs[1].plot_trisurf(X.flatten(), Y.flatten(), Z.flatten(), cmap=cm.coolwarm,
                           linewidth=0, antialiased=False)

    Z = targets['energy'][:, :, state]
    surf = axs[2].plot_trisurf(X[mask].flatten(), Y[mask].flatten(), Z[mask].flatten(), cmap=cm.coolwarm,
                           linewidth=0, antialiased=False)

axs[0].set_title("PaiNN", fontsize=20)
axs[1].set_title("SchNet", fontsize=20)
axs[2].set_title("QM", fontsize=20)

for ax in axs:
    ax.set_ylabel("C-N bond length [Å]")
    ax.set_xlabel("C-N rotation [°]")
    ax.set_zlabel("Energy / Ha")

plt.show()

## Zooming in on the interaction between the two lowest states

---

### 🔍 Energy Gap Analysis: $S_1 - S_0$

After visualizing the full energy landscapes, we zoom in on the **energy gap between the ground state ($S_0$) and the first excited state ($S_1$)**. This gap is particularly important because it governs the likelihood of nonadiabatic transitions and determines where surface hopping or internal conversion may occur.
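A minimal sketch of how the smallest gap on a scan can be located, using toy numbers rather than the notebook's data. It assumes energies in Hartree with shape `(n_bond, n_angle, n_states)` and a masked array whose masked entries mark failed reference points; the function name `smallest_gap` is hypothetical.

```python
import numpy as np

HARTREE_TO_meV = 27.2114 * 1e3


def smallest_gap(energies, lower=0, upper=1):
    """Return (gap in meV, (i_bond, j_angle)) for the smallest gap on the grid.

    `energies` has shape (n_bond, n_angle, n_states) in Hartree.
    Masked entries (e.g. unconverged reference points) are ignored.
    """
    gap = HARTREE_TO_meV * (energies[..., upper] - energies[..., lower])
    gap = np.ma.filled(gap, np.inf)  # masked points can never be the minimum
    i_bond, j_angle = np.unravel_index(np.argmin(gap), gap.shape)
    return float(gap[i_bond, j_angle]), (int(i_bond), int(j_angle))


# Toy 3 x 4 grid with two states.
e0 = np.zeros((3, 4))
e1 = np.full((3, 4), 0.10)
e1[1, 3] = 0.01                   # smallest genuine gap
e1[2, 0] = 0.0                    # failed point that would show a fake zero gap
data = np.stack([e0, e1], axis=-1)
invalid = np.zeros(data.shape, dtype=bool)
invalid[2, 0, :] = True           # mask the failed point for all states

gap_meV, location = smallest_gap(np.ma.masked_array(data, mask=invalid))
print(gap_meV, location)
```

Without the mask, the failed point would be reported as the gap minimum, which is why unconverged reference geometries have to be excluded before searching for near-degeneracies.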

#### ⚠️ A Note on the QM Reference

It’s worth noting that for **very stretched C–N bonds**, the QM reference data shows **non-smooth behavior** in the energy curves—likely due to convergence issues or the multireference character of the wavefunction becoming more pronounced. As a result, while the QM data is generally reliable, some caution is needed when interpreting the absolute values in these stretched regions.

---

### 🧩 Comparing ML and QM Gaps

We then plot the **$S_1 - S_0$ gap**:

- As a **2D surface** over the grid (C–N bond length vs. torsion),
- As **1D slices** along the C–N bond at fixed torsion angle: the planar geometry (0°) and the strongly twisted geometry (90°).

#### ✅ What Works Well

Both SchNet and PaiNN models correctly reproduce the **overall topology** of the gap landscape. In particular:

- The location and magnitude of the gap minimum are broadly consistent with QM,
- The increasing gap with C–N bond compression or extension is captured.

#### ❌ Limitations and Artifacts

In the 1D slice near **90° torsion**, both ML models display **incorrect avoided crossings** that are **not present** in the QM reference. This is a known challenge in modeling excited states with ML—getting the correct **state ordering and topological features** in regions of strong coupling is difficult, especially without explicitly enforcing the physics of conical intersections.

Such artifacts remind us that while ML models can generalize well in smooth regions of PESs, they may struggle in:

- Regions with near-degeneracies,
- Topologically nontrivial crossings (e.g., true or avoided),
- Cases where the ordering of electronic states changes abruptly.

Understanding these limitations is essential before using ML models in dynamics or control applications.
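One simple way to find such problematic regions in a set of predictions is to flag geometries where neighbouring states come very close or appear in the wrong order. The sketch below uses toy numbers; the function name `flag_suspicious_points` and the threshold value are assumptions chosen for illustration, not part of the original workflow.

```python
import numpy as np


def flag_suspicious_points(energies, threshold):
    """Flag geometries with misordered or nearly degenerate adjacent states.

    energies : array of shape (n_points, n_states), same energy unit as `threshold`
    Returns two boolean arrays of length n_points.
    """
    gaps = np.diff(energies, axis=1)              # S1-S0, S2-S1, ...
    misordered = (gaps < 0).any(axis=1)           # a higher state lies below a lower one
    near_degenerate = (np.abs(gaps) < threshold).any(axis=1)
    return misordered, near_degenerate


toy = np.array([
    [0.00, 0.100, 0.20],   # well separated
    [0.00, 0.001, 0.20],   # S0 and S1 nearly degenerate
    [0.05, 0.040, 0.20],   # S1 predicted below S0
])
misordered, near_degenerate = flag_suspicious_points(toy, threshold=0.005)
print(misordered, near_degenerate)
```

Points flagged this way are natural candidates for closer inspection in the 1D slices.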

### 2D energy gap surface

In [None]:
fig, axs = plt.subplots(1, 3, layout="constrained", 
                        sharex=True, sharey=True,
                        figsize=(10, 5))

levels = [i*500.0 for i in range(20)]

# painn
axs[0].contourf(X, Y, HARTREE_TO_meV*(preds_painn['energy'][:,1] - preds_painn['energy'][:,0]).reshape(101, 91), 
                        cmap='Blues', 
                       levels=levels)
axs[0].set_title('PaiNN', fontsize=20)

# schnet
axs[1].contourf(X, Y, HARTREE_TO_meV*(preds_schnet['energy'][:,1] - preds_schnet['energy'][:,0]).reshape(101, 91), 
                        cmap='Blues', 
                       levels=levels)
axs[1].set_title('SchNet', fontsize=20)

# ground truth
S0_energy = np.ma.masked_where(~mask, targets['energy'][:,:,0])
S1_energy = np.ma.masked_where(~mask, targets['energy'][:,:,1])
S2_energy = np.ma.masked_where(~mask, targets['energy'][:,:,2])
# pass the masked array itself so that unconverged points stay blank
cont3 = axs[2].contourf(X, Y, HARTREE_TO_meV*(S1_energy - S0_energy), 
                        cmap='Blues', 
                        levels=levels)
axs[2].set_title('QM', fontsize=20)

cbar = fig.colorbar(cont3, ax=axs[2])
cbar.set_label("S$_1$-S$_0$ / meV", fontsize=15)
cbar.ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')
for spine in cbar.ax.spines:
    cbar.ax.spines[spine].set_linewidth(2)


for ax in axs:
    ax.set_xlabel("C-N rotation [°]", fontsize=15)
    for spine in ax.spines:
        ax.spines[spine].set_linewidth(2)
    ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')

axs[0].set_ylabel("C-N bond length [Å]", fontsize=15)

plt.show()

---
## 1D slices

#### Flat geometry

In [None]:
fig, axs = plt.subplots(1, 3, layout="constrained", 
                        sharex=True, sharey=True,
                        figsize=(10, 5))

min_energy = S0_energy.min()

# painn
axs[0].plot(bond_values, HARTREE_TO_meV*(preds_painn['energy'][:, 0].reshape(101, 91)[:, 0] - min_energy),
           label="$S_0$", linewidth=2, color=plt.get_cmap("plasma")(0))
axs[0].plot(bond_values, HARTREE_TO_meV*(preds_painn['energy'][:, 1].reshape(101, 91)[:, 0]- min_energy),
           label="$S_1$", linewidth=2, color=plt.get_cmap("plasma")(100))
axs[0].plot(bond_values, HARTREE_TO_meV*(preds_painn['energy'][:, 2].reshape(101, 91)[:, 0]- min_energy),
           label="$S_2$", linewidth=2, color=plt.get_cmap("plasma")(200))
axs[0].set_title('PaiNN', fontsize=20)


# schnet
axs[1].plot(bond_values, HARTREE_TO_meV*(preds_schnet['energy'][:, 0].reshape(101, 91)[:, 0] - min_energy),
           label="$S_0$", linewidth=2, color=plt.get_cmap("plasma")(0))
axs[1].plot(bond_values, HARTREE_TO_meV*(preds_schnet['energy'][:, 1].reshape(101, 91)[:, 0] - min_energy),
           label="$S_1$", linewidth=2, color=plt.get_cmap("plasma")(100))
axs[1].plot(bond_values, HARTREE_TO_meV*(preds_schnet['energy'][:, 2].reshape(101, 91)[:, 0] - min_energy),
           label="$S_2$", linewidth=2, color=plt.get_cmap("plasma")(200))
axs[1].set_title('SchNet', fontsize=20)


# QM
axs[2].plot(bond_values, HARTREE_TO_meV*(S0_energy[:, 0] - min_energy),
           label="$S_0$", linewidth=2, color=plt.get_cmap("plasma")(0))
axs[2].plot(bond_values, HARTREE_TO_meV*(S1_energy[:, 0]- min_energy),
           label="$S_1$", linewidth=2, color=plt.get_cmap("plasma")(100))
axs[2].plot(bond_values, HARTREE_TO_meV*(S2_energy[:, 0]- min_energy),
           label="$S_2$", linewidth=2, color=plt.get_cmap("plasma")(200))
axs[2].set_title('QM', fontsize=20)

for ax in axs:
    ax.spines['bottom'].set_linewidth(3)
    ax.spines['top'].set_linewidth(3)
    ax.spines['left'].set_linewidth(3)
    ax.spines['right'].set_linewidth(3)
    ax.tick_params(length=6, width=3,
                   labelsize=12, pad=6,
                   direction='in')
    ax.set_xlabel("C-N bond length [Å]", fontsize=15)
    ax.legend(loc='best', frameon=False, fontsize=15)
    ax.axvline(1.37, linewidth=1, color='gray', alpha=0.5)

axs[0].set_ylabel("Relative Energy / meV", fontsize=15)
axs[0].set_xlim([1.25, 1.55])
axs[0].set_ylim([7000, 10000])

plt.show()

#### Twisted geometry

In [None]:
fig, axs = plt.subplots(1, 3, layout="constrained", 
                        sharex=True, sharey=True,
                        figsize=(10, 5))

min_energy = S0_energy.min()

# Painn
axs[0].plot(bond_values, HARTREE_TO_meV*(preds_painn['energy'][:, 0].reshape(101, 91)[:, -1] - min_energy),
           label="$S_0$", linewidth=2, color=plt.get_cmap("plasma")(0))
axs[0].plot(bond_values, HARTREE_TO_meV*(preds_painn['energy'][:, 1].reshape(101, 91)[:, -1]- min_energy),
           label="$S_1$", linewidth=2, color=plt.get_cmap("plasma")(100))
axs[0].set_title('PaiNN', fontsize=20)


# Schnet
axs[1].plot(bond_values, HARTREE_TO_meV*(preds_schnet['energy'][:, 0].reshape(101, 91)[:, -1] - min_energy),
           label="$S_0$", linewidth=2, color=plt.get_cmap("plasma")(0))
axs[1].plot(bond_values, HARTREE_TO_meV*(preds_schnet['energy'][:, 1].reshape(101, 91)[:, -1] - min_energy),
           label="$S_1$", linewidth=2, color=plt.get_cmap("plasma")(100))
axs[1].set_title('SchNet', fontsize=20)


# QM
axs[2].plot(bond_values, HARTREE_TO_meV*(S0_energy[:, -1] - min_energy),
           label="$S_0$", linewidth=2, color=plt.get_cmap("plasma")(0))
axs[2].plot(bond_values, HARTREE_TO_meV*(S1_energy[:, -1]- min_energy),
           label="$S_1$", linewidth=2, color=plt.get_cmap("plasma")(100))
axs[2].set_title('QM', fontsize=20)

for ax in axs:
    ax.spines['bottom'].set_linewidth(3)
    ax.spines['top'].set_linewidth(3)
    ax.spines['left'].set_linewidth(3)
    ax.spines['right'].set_linewidth(3)
    ax.tick_params(length=6, width=3,
                   labelsize=12, pad=6,
                   direction='in')
    ax.set_xlabel("C-N bond length [Å]", fontsize=15)
    ax.legend(loc='best', frameon=False, fontsize=15)
    ax.axvline(1.39, linewidth=1, color='gray', alpha=0.5)

axs[0].set_ylabel("Relative Energy / meV", fontsize=15)
axs[0].set_xlim([1.25, 1.75])
axs[0].set_ylim([3000, 6000])

plt.show()

---

### 🔁 Non-Adiabatic Couplings: Visualizing NAC Intensity

In addition to energies, both the QM reference and the ML models predict **non-adiabatic coupling vectors (NACs)** between electronic states. These are key quantities that drive nonradiative transitions in excited-state dynamics.

However, NACs are high-dimensional: for each pair of electronic states, the NAC is an **$N_\text{atoms} \times 3$** array, that is, one 3D vector per atom. This makes direct visualization over the 2D grid difficult.

#### 📉 Strategy: Norm-Based NAC Summarization

To visualize NACs across the scan, we compute a **single scalar value** per geometry by:

1. Taking the **norm of each atomic 3-vector** (i.e., over x, y, z),
2. Summing these norms **across all atoms**.

This gives a **scalar measure of the overall NAC strength** between a pair of states at each geometry. We apply this procedure to the **$S_0 \leftrightarrow S_1$ coupling**, which is typically the most relevant for early-time nonadiabatic transitions.
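The two steps above can be sketched in a few lines of NumPy. The function name `nac_strength` is hypothetical, and the input is assumed to hold one state pair with the atom and Cartesian axes last.

```python
import numpy as np


def nac_strength(nac):
    """Sum over atoms of the Euclidean norm of each atomic NAC vector.

    nac : array of shape (..., n_atoms, 3) for a single pair of states
    Returns an array of shape (...), one scalar per geometry.
    """
    per_atom_norm = np.linalg.norm(nac, axis=-1)  # step 1: norm over x, y, z
    return per_atom_norm.sum(axis=-1)             # step 2: sum over atoms


# Toy example with two atoms: norms are 5 and 2.
toy_nac = np.array([[3.0, 4.0, 0.0],
                    [0.0, 0.0, 2.0]])
strength = nac_strength(toy_nac)
print(strength)
print(nac_strength(-toy_nac))  # the scalar does not depend on the overall sign
```

Because the norm discards the sign, this summary is insensitive to the arbitrary phase of the NAC, which makes it convenient for plotting but blind to orientation errors.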

---

### 🔬 Observations

- In the **QM reference**, the NAC norm is **sharply peaked** in a specific region of the scan:
  - **Large torsion angles** (near 90°),
  - **Short C–N bond lengths**.
  - This localization is consistent with a nearby **conical intersection**, where strong state mixing occurs.

- In contrast, both **ML models** (SchNet and PaiNN/SpaiNN) show:
  - A **broader, less localized NAC intensity**,
  - No sharp peak in the same region,
  - A tendency to **over-distribute** coupling strength across the scan.

This behavior reflects a known challenge: while ML models can be trained to approximate NACs, they often lack the **sharp localization** and **topological accuracy** seen in QM calculations—especially when trained only on pointwise quantities and without explicit conical intersection constraints.

---

### 🧠 Interpretation

The NAC norm plots are a useful diagnostic: they show **where the models believe nonadiabatic transitions are likely to occur**. Discrepancies in peak location and shape should be carefully considered when using these models for surface hopping or photochemical predictions.



In [None]:
fig, axs = plt.subplots(1, 3, layout="constrained", 
                        sharex=True, sharey=True,
                        figsize=(10, 5))

levels = [i*1.0 for i in range(21)]


axs[0].contourf(X, Y, preds_painn['nacs'][:,0].reshape(101, 91), 
                        cmap='Blues', levels=levels)
axs[0].set_title('PaiNN', fontsize=20)


axs[1].contourf(X, Y, preds_schnet['nacs'][:,0].reshape(101, 91), 
                        cmap='Blues', levels=levels)
axs[1].set_title('SchNet', fontsize=20)


NAC_01 = np.ma.array(targets['nacs'][:,:,0], mask=targets['nacs'][:,:,0] > 1000)

cont1 = axs[2].contourf(X, Y, NAC_01, 
                        cmap='Blues', 
                        levels=levels)
cbar = fig.colorbar(cont1, ax=axs[2])
cbar.set_label("sum(norm(NAC$_{01}$))", fontsize=12)
axs[2].set_title('QM', fontsize=20)

cbar.ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')
for spine in cbar.ax.spines:
    cbar.ax.spines[spine].set_linewidth(2)


for ax in axs:
    ax.set_xlabel("C-N rotation [°]", fontsize=15)
    for spine in ax.spines:
        ax.spines[spine].set_linewidth(2)
    ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')

axs[0].set_ylabel("C-N bond length [Å]", fontsize=15)

plt.show()

---
### 📉 Comparing SchNet and PaiNN Prediction Errors

To assess the performance of the two ML models, we compare their predictions to the quantum chemical (QM) ground truth.

#### 🔺 Energy Errors

We compute the **absolute deviation** between the predicted and reference energies for each state at every geometry. As expected, the **prediction error increases with electronic state index**. This is a common trend — higher excited states tend to be:

- **Harder to model**, due to more complex electronic structure,
- **Less reliable** in the training data itself, since many QM methods lose accuracy with increasing excitation level.
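To put a number on this trend, the absolute deviations can be averaged per state over all valid grid points. This is a minimal sketch with toy data; `per_state_mae` is a hypothetical helper, and the array layout `(n_bond, n_angle, n_states)` is assumed to match the grid arrays used in this notebook.

```python
import numpy as np

HARTREE_TO_meV = 27.2114 * 1e3


def per_state_mae(pred, ref, valid):
    """Mean absolute energy error per state in meV.

    pred, ref : arrays of shape (n_bond, n_angle, n_states) in Hartree
    valid     : boolean array of shape (n_bond, n_angle), True for usable points
    """
    abs_err = HARTREE_TO_meV * np.abs(pred - ref)
    return abs_err[valid].mean(axis=0)  # average over valid geometries only


ref = np.zeros((2, 2, 3))
pred = ref + np.array([0.001, 0.002, 0.003])  # error grows with state index
pred[0, 0] = 5.0                               # nonsense at a failed reference point
valid = np.ones((2, 2), dtype=bool)
valid[0, 0] = False                            # exclude the failed point

mae = per_state_mae(pred, ref, valid)
print(mae)
```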

#### 🔁 NAC Errors

For the **nonadiabatic coupling vectors (NACs)**, we again reduce each vector field to a **single scalar value** per geometry by:

1. Computing the **norm over the x, y, z components** for each atom,
2. Summing these values across all atoms.

This gives a rough estimate of **NAC magnitude**, but does **not capture orientation or phase**, which are also physically important.

To compare models, we plot the **logarithm of the absolute deviation** between the ML-predicted NAC norms and the QM reference. The logarithmic scale is necessary because NAC magnitudes can vary over several orders of magnitude, especially near conical intersections or avoided crossings.
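A minimal sketch of the logarithmic deviation follows. The small `floor` value is an assumption added here so that exact agreement does not produce `-inf`; the cell in this notebook takes the logarithm directly.

```python
import numpy as np


def log_abs_deviation(pred, ref, floor=1e-12):
    """log10 of |pred - ref|, with deviations clipped from below at `floor`."""
    deviation = np.abs(np.asarray(pred) - np.asarray(ref))
    return np.log10(np.maximum(deviation, floor))


ref = np.array([1.0, 10.0, 0.5])
pred = np.array([1.1, 0.0, 0.5])
log_dev = log_abs_deviation(pred, ref)
print(log_dev)
```

On this scale a value of 1 means the summed NAC norms differ by ten units, and a value of -2 means they agree to within a hundredth.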


#### Energies

In [None]:
prop = "energy"

prop_targ = targets[prop]
prop_targ1 = np.ma.masked_where(targets[prop] == -1, prop_targ)
prop_targ2 = np.ma.masked_where(targets[prop] == -1, prop_targ)
prop_targ1 -= preds_painn[prop].reshape(101, 91, 3)
prop_targ2 -= preds_schnet[prop].reshape(101, 91, 3)
prop_targ1 = HARTREE_TO_meV*np.abs(prop_targ1)
prop_targ2 = HARTREE_TO_meV*np.abs(prop_targ2)

fig, axs = plt.subplots(2, 3, layout="constrained", 
                        sharex=True, sharey=True,
                        figsize=(10, 6))

cmap = cm.Reds
targ_idx = [0,1,2]

for column in range(3):
    ax = axs[0, column]
    im = ax.imshow(prop_targ1[::-1,:,targ_idx[column]], cmap=cmap, aspect="auto", extent=[0, 90, 1.2870, 2.3453], 
                   vmin=properties[prop]["vmin"], vmax=properties[prop]["vmax"])
    ax.set_title("PaiNN - "+r"$\Delta$"+properties[prop]["name"]+properties[prop]["label"][column], fontsize=15)

    if column==2:
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(r"$\Delta$E / meV", fontsize=12)
        cbar.ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')
        for spine in cbar.ax.spines:
            cbar.ax.spines[spine].set_linewidth(2)

    ax = axs[1, column]
    im = ax.imshow(prop_targ2[::-1,:,targ_idx[column]], cmap=cmap, aspect="auto", extent=[0, 90, 1.2870, 2.3453], 
                   vmin=properties[prop]["vmin"], vmax=properties[prop]["vmax"])
    ax.set_title("SchNet - "+r"$\Delta$"+properties[prop]["name"]+properties[prop]["label"][column], fontsize=15)

    if column==2:
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(r"$\Delta$E / meV", fontsize=12)
        cbar.ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')
        for spine in cbar.ax.spines:
            cbar.ax.spines[spine].set_linewidth(2)

for ax in axs.flatten():
    for spine in ax.spines:
        ax.spines[spine].set_linewidth(2)
    ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')

for ax in axs[-1]:
    ax.set_xlabel("C-N rotation [°]", fontsize=12)
for ax in axs[:,0]:
    ax.set_ylabel("C-N bond length [Å]", fontsize=12)
    
plt.show()

#### NACs

In [None]:
prop = "nacs"

prop_targ = targets[prop]
prop_targ1 = np.ma.masked_where(targets[prop] == -1, prop_targ)
prop_targ2 = np.ma.masked_where(targets[prop] == -1, prop_targ)
prop_targ1 -= preds_painn[prop].reshape(101, 91, 3)
prop_targ2 -= preds_schnet[prop].reshape(101, 91, 3)
prop_targ1 = np.log10(np.abs(prop_targ1))
prop_targ2 = np.log10(np.abs(prop_targ2))

fig, axs = plt.subplots(2, 3, layout="constrained", 
                        sharex=True, sharey=True,
                        figsize=(10, 6))

cmap = cm.Reds
targ_idx = [0,1,2]

for column in range(3):
    ax = axs[0, column]
    im = ax.imshow(prop_targ1[::-1,:,targ_idx[column]], cmap=cmap, aspect="auto", extent=[0, 90, 1.2870, 2.3453], 
                   vmin=properties[prop]["vmin"], vmax=properties[prop]["vmax"])
    ax.set_title("PaiNN - "+r"$\Delta$"+properties[prop]["name"]+properties[prop]["label"][column])

    if column==2:
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')
        for spine in cbar.ax.spines:
            cbar.ax.spines[spine].set_linewidth(2)

    ax = axs[1, column]
    im = ax.imshow(prop_targ2[::-1,:,targ_idx[column]], cmap=cmap, aspect="auto", extent=[0, 90, 1.2870, 2.3453], 
                   vmin=properties[prop]["vmin"], vmax=properties[prop]["vmax"])
    ax.set_title("SchNet - "+r"$\Delta$"+properties[prop]["name"]+properties[prop]["label"][column])

    if column==2:
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')
        for spine in cbar.ax.spines:
            cbar.ax.spines[spine].set_linewidth(2)


for ax in axs.flatten():
    for spine in ax.spines:
        ax.spines[spine].set_linewidth(2)
    ax.tick_params(length=6, width=2,
                   labelsize=12, pad=6,
                   direction='in')
    
for ax in axs[-1]:
    ax.set_xlabel("C-N rotation [°]", fontsize=12)
for ax in axs[:,0]:
    ax.set_ylabel("C-N bond length [Å]", fontsize=12)

plt.show()

---

### 📊 Parity Plots: Final Test Set Evaluation

As a final comparison, we evaluate the performance of the two ML models—**SchNet** and **PaiNN**—on their respective **test sets**, using **parity plots** for all predicted quantities.

#### 🧪 Setup

To do this, we:

- Load the full database of training data,
- Use the stored train/validation/test split indices to **identify the test set**,
- Run predictions **only on the test configurations**, since evaluating the training or validation data adds no new insight.
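The index handling can be sketched as follows. An in-memory file stands in for a split file such as `Painn_model/train_val_test_indices.npz`; only the `test_idx` key is taken from this notebook, while the `train_idx` and `val_idx` keys and all index values are assumptions made for the example.

```python
import io

import numpy as np

# Build a stand-in split file in memory (values are made up).
buffer = io.BytesIO()
np.savez(buffer,
         train_idx=np.array([0, 3, 4]),   # assumed key name
         val_idx=np.array([1]),           # assumed key name
         test_idx=np.array([2, 5]))
buffer.seek(0)

split = np.load(buffer)
test_idx = split["test_idx"]

# Split indices are 0-based, while ASE database row ids start at 1.
db_ids = [int(i) + 1 for i in test_idx]

# Sanity check: the test set must not overlap with the training data.
overlap = np.intersect1d(test_idx, split["train_idx"])
print(db_ids, overlap.size)
```

The off-by-one shift is the reason the prediction loops in this notebook call `db.get(int(idx)+1)`.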

#### ⚖️ What Parity Plots Show

A **parity plot** compares predicted values (on the y-axis) to ground truth values (on the x-axis). Perfect predictions lie on the diagonal.
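The statistics usually quoted alongside a parity plot can be computed with plain NumPy. The sketch below is an illustration with toy numbers; `parity_metrics` is a hypothetical helper and not the plotting function defined earlier in this notebook.

```python
import numpy as np


def parity_metrics(target, prediction):
    """Return MAE, RMSE, R^2 and Pearson correlation for flattened arrays."""
    target = np.asarray(target, dtype=float).ravel()
    prediction = np.asarray(prediction, dtype=float).ravel()
    err = prediction - target
    ss_res = np.sum(err ** 2)                          # residual sum of squares
    ss_tot = np.sum((target - target.mean()) ** 2)     # total sum of squares
    return {
        "mae": np.abs(err).mean(),
        "rmse": np.sqrt(np.mean(err ** 2)),
        "r2": 1.0 - ss_res / ss_tot,
        "pearson": np.corrcoef(target, prediction)[0, 1],
    }


# Predictions that are perfectly correlated with the target but off by a factor of 2.
target = np.array([0.0, 1.0, 2.0, 3.0])
metrics = parity_metrics(target, 2.0 * target)
print(metrics)
```

The toy case shows why both numbers are worth reporting: a prediction can be perfectly correlated with the target and still have a negative $R^2$ when its scale is wrong.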

We generate parity plots for:

- **Energies** (per electronic state),
- **Forces** (on all atoms),
- **NAC components** (this time for every atom and Cartesian direction, rather than the summed norm).

#### 📌 Expectations

- **Energies**: Both models should show **decent agreement** with QM values, especially for the ground and first excited states, with errors increasing for higher states.
  
- **Forces**: Prediction quality should be reasonable, though with slightly more scatter than for energies. This is acceptable for many dynamics applications.

- **NACs**: We expect the **PaiNN model** to capture the overall trend and magnitudes reasonably well and **SchNet to perform worse**, with larger deviations and weaker correlation, because NACs are **vector-valued, orientation-sensitive quantities**.

This expectation rests on model architecture: PaiNN, being **equivariant**, is better suited for vector quantities like forces and NACs, whereas SchNet lacks the built-in geometric structure needed for reliable NAC predictions.



In [None]:
test_indices_painn = np.load("Painn_model/train_val_test_indices.npz")['test_idx']
test_indices_schnet = np.load("Schnet_model/train_val_test_indices.npz")['test_idx']
n_test = len(test_indices_painn)

In [None]:
db = connect("methylenimmonium.db")

In [None]:
# NOTE: forces and nacs have the shape (Natoms, Nstates, xyz) -> here (6, 3, 3)
pred_painn = {'E': np.zeros(shape=(n_test, 3)),
              'F': np.zeros(shape=(n_test, 6, 3, 3)),
              'NAC': np.zeros(shape=(n_test, 6, 3, 3))}
pred_schnet = {'E': np.zeros(shape=(n_test, 3)),
               'F': np.zeros(shape=(n_test, 6, 3, 3)),
               'NAC': np.zeros(shape=(n_test, 6, 3, 3))}

In [None]:
targ_painn = {'E': np.zeros(shape=(n_test, 3)),
              'F': np.zeros(shape=(n_test, 6, 3, 3)),
              'NAC': np.zeros(shape=(n_test, 6, 3, 3))}
targ_schnet = {'E': np.zeros(shape=(n_test, 3)),
               'F': np.zeros(shape=(n_test, 6, 3, 3)),
               'NAC': np.zeros(shape=(n_test, 6, 3, 3))}

#### Making the predictions and storing them

In [None]:
# PaiNN

calc = NacCalculator(model_file="Painn_model/best_model", neighbor_list=MatScipyNeighborList(cutoff=10.0))
atom = ase.Atoms(symbols="CNHHHH")
atom.calc = calc
for ii, idx in enumerate(tqdm(test_indices_painn)):  # wrap the array so tqdm knows the total
    row = db.get(int(idx)+1)
    atom.set_positions(row.positions)
    props = atom.get_properties(['energy', 'smooth_nacs', 'forces'])
    pred_painn['E'][ii] = props['energy']
    pred_painn['NAC'][ii] = props['smooth_nacs']
    pred_painn['F'][ii] = props['forces']

    targ_painn['E'][ii] = row.data['energy']
    targ_painn['NAC'][ii] = row.data['smooth_nacs']
    targ_painn['F'][ii] = row.data['forces']

In [None]:
# SchNet

calc = NacCalculator(model_file="Schnet_model/best_model", neighbor_list=MatScipyNeighborList(cutoff=10.0))
atom = ase.Atoms(symbols="CNHHHH")
atom.calc = calc
for ii, idx in enumerate(tqdm(test_indices_schnet)):  # wrap the array so tqdm knows the total
    row = db.get(int(idx)+1)
    atom.set_positions(row.positions)
    props = atom.get_properties(['energy', 'smooth_nacs', 'forces'])
    pred_schnet['E'][ii] = props['energy']
    pred_schnet['NAC'][ii] = props['smooth_nacs']
    pred_schnet['F'][ii] = props['forces']

    targ_schnet['E'][ii] = row.data['energy']
    targ_schnet['NAC'][ii] = row.data['smooth_nacs']
    targ_schnet['F'][ii] = row.data['forces']

---
### Plotting

#### SchNet

In [None]:
make_partity_plot(targ_schnet, pred_schnet)

#### PaiNN

In [None]:
make_partity_plot(targ_painn, pred_painn)

---

### 🎯 Expectations vs. Reality

- **Energies and Forces**: As expected, both models produce **reasonable predictions**, but the performance is **significantly worse than for models trained only on the electronic ground state**. This is likely due to **competing terms in the loss function**, where energy, force, and NAC targets for multiple states must all be balanced. The added complexity makes optimization more difficult and may limit per-task accuracy.

- **NACs**: The NAC predictions are **consistently poor** for both models:
  - Parity plots show **some correlation** with reference values,
  - **Pearson correlation coefficients** are near 0.5; only the $S_0 \leftrightarrow S_2$ coupling is fitted better,
  - **$R^2$ scores are negative** for all NAC components.

This confirms that while equivariant models like PaiNN can theoretically represent NACs, **in practice**, current training strategies and data coverage are **insufficient to capture their directional and phase-sensitive structure**.

Improving NAC prediction likely requires:
- More targeted loss formulations,
- Better representation of phase and sign consistency,
- Possibly hybrid architectures that combine ML with physics-based constraints.
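As an illustration of what sign consistency means in practice, the sketch below evaluates an error that is invariant to the overall phase of each NAC vector by keeping, per configuration, the smaller of the errors for `+pred` and `-pred`. This is one possible formulation under stated assumptions, not a description of how the models in this notebook were trained; the function name `phase_invariant_mae` is hypothetical.

```python
import numpy as np


def phase_invariant_mae(pred, ref):
    """Mean absolute error that ignores one global sign per configuration.

    pred, ref : arrays of shape (n_conf, n_atoms, 3) for a single state pair
    """
    err_plus = np.abs(ref - pred).mean(axis=(1, 2))   # error with the sign as predicted
    err_minus = np.abs(ref + pred).mean(axis=(1, 2))  # error with the sign flipped
    return float(np.minimum(err_plus, err_minus).mean())


ref = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
                [[0.0, 0.0, 3.0], [1.0, 1.0, 1.0]]])
pred = ref.copy()
pred[1] *= -1.0  # second configuration predicted with the opposite phase

plain_mae = float(np.abs(ref - pred).mean())
invariant_mae = phase_invariant_mae(pred, ref)
print(plain_mae, invariant_mae)
```

The plain error penalizes the second configuration even though it differs from the reference only by the arbitrary phase, while the phase-invariant error treats it as a perfect prediction.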
