# Plot and compare QM and ML data

### Imports and Constants

In [None]:
import ase
from ase.db import connect
import nglview
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import ImageGrid

from spainn.interface import NacCalculator
from schnetpack.transform import MatScipyNeighborList

from tqdm import tqdm

from sklearn.metrics import r2_score
import scipy

In [None]:
# Unit conversion factors used throughout the notebook.
BOHR_TO_ANGSTROM = 0.529177249  # length of 1 bohr in Å
HARTREE_TO_meV = 27.2114 * 1e3  # 1 hartree = 27.2114 eV, expressed in meV

## Comparing PaiNN vs. SchNet vs. QM on the angle/bond-length grid

In [None]:
# Scan grid: 91 C-N rotation angles (0-90°, 1° steps) and 101 C-N bond
# lengths (2.4321-4.4321 bohr, converted to Å).
angle_values = np.linspace(0, 90, num=91)
bond_values = BOHR_TO_ANGSTROM * np.linspace(2.4321, 4.4321, num=101)
# X holds the angle and Y the bond length at every grid point; both have shape
# (len(bond_values), len(angle_values)) = (101, 91).
X, Y = np.meshgrid(angle_values, bond_values, indexing="xy")

In [None]:
# ML predictions on the (flattened) scan grid and the QM reference on the grid.
preds_painn = np.load("Predictions_Painn.npz")
preds_schnet = np.load("Predictions_Schnet.npz")
targets = np.load("groundtruth_grid.npz")
# True where the SCF *did* converge: real ground-state energies lie below
# -1 Ha, while non-converged geometries presumably carry a -1 placeholder.
# Later cells use ``~mask`` to hide the non-converged points.
mask = targets['energy'][:, :, 0] < -1

In [None]:
fig, axs = plt.subplots(1, 3, subplot_kw={"projection": "3d"},
                        figsize=(15, 6), layout='constrained')

# One surface per electronic state. The ML predictions are stored flat (one
# row per grid point). The QM reference lives on the (bond, angle) grid, and
# only its converged points (``mask``) are drawn.
for state in range(3):
    for ax, preds in ((axs[0], preds_painn), (axs[1], preds_schnet)):
        Z = preds['energy'][:, state]
        ax.plot_trisurf(X.flatten(), Y.flatten(), Z.flatten(), cmap=cm.coolwarm,
                        linewidth=0, antialiased=False)

    Z = targets['energy'][:, :, state]
    axs[2].plot_trisurf(X[mask], Y[mask], Z[mask], cmap=cm.coolwarm,
                        linewidth=0, antialiased=False)

axs[0].set_title("PaiNN")
axs[1].set_title("SchNet")
axs[2].set_title("QM")

# X is the rotation angle (degrees), Y the bond length (Å).
for ax in axs:
    ax.set_xlabel("C-N rotation [°]")
    ax.set_ylabel("C-N bond length [Å]")
    ax.set_zlabel("Energy / Ha")

plt.show()

## Zooming in on the interaction between the two lowest states


### Energies
Visualizing the energy gap between the ground state and the first excited state.

In [None]:
fig, axs = plt.subplots(1, 3, layout="constrained",
                        sharex=True, sharey=True,
                        figsize=(10, 5))

# Contour levels (Ha) shared by all three panels so the colours are comparable.
levels = [i*0.05 for i in range(8)]

# PaiNN / SchNet: predictions are stored flat (one row per grid point), so the
# S1-S0 gap is reshaped onto the (bond length, angle) grid.
for ax, preds, name in ((axs[0], preds_painn, 'PaiNN'),
                        (axs[1], preds_schnet, 'SchNet')):
    gap = (preds['energy'][:, 1] - preds['energy'][:, 0]).reshape(X.shape)
    ax.contourf(X, Y, gap, cmap='Blues', levels=levels)
    ax.set_title(name, fontsize=20)

# Ground truth: pass the masked array itself (not ``.data``) so geometries
# where the SCF did not converge stay blank instead of being contoured from
# placeholder values.
S0_energy = np.ma.masked_where(~mask, targets['energy'][:, :, 0])
S1_energy = np.ma.masked_where(~mask, targets['energy'][:, :, 1])
cont3 = axs[2].contourf(X, Y, S1_energy - S0_energy,
                        cmap='Blues',
                        levels=levels)
axs[2].set_title('QM', fontsize=20)

cbar = fig.colorbar(cont3, ax=axs[2])
cbar.set_label("S$_1$-S$_0$ / Ha", fontsize=15)

for ax in axs:
    ax.set_xlabel("C-N rotation [°]", fontsize=15)

axs[0].set_ylabel("C-N bond length [Å]", fontsize=15)

plt.show()

### NACs
Visualizing the magnitude of the nonadiabatic coupling vectors (sum over all atoms of their norms) between the ground state and the first excited state.

In [None]:
fig, axs = plt.subplots(1, 3, layout="constrained",
                        sharex=True, sharey=True,
                        figsize=(10, 5))

# Contour levels 0.0, 1.0, ..., 20.0 shared by every panel.
levels = [float(i) for i in range(21)]

# ML panels: only the first coupling column (NAC_01) is drawn; predictions are
# stored flat, one row per grid point, so reshape them onto the grid.
for ax, preds, name in ((axs[0], preds_painn, 'PaiNN'),
                        (axs[1], preds_schnet, 'SchNet')):
    ax.contourf(X, Y, preds['nacs'][:, 0].reshape(X.shape),
                cmap='Blues', levels=levels)
    ax.set_title(name, fontsize=20)

# QM reference for the same coupling; values above 1000 are masked out.
nac_qm = targets['nacs'][:, :, 0]
NAC_01 = np.ma.array(nac_qm, mask=nac_qm > 1000)

cont1 = axs[2].contourf(X, Y, NAC_01, cmap='Blues', levels=levels)
cbar = fig.colorbar(cont1, ax=axs[2])
cbar.set_label("sum(norm(NAC$_{01}$))", fontsize=12)
axs[2].set_title('QM', fontsize=20)

for ax in axs:
    ax.set_xlabel("C-N rotation [°]", fontsize=15)

axs[0].set_ylabel("C-N bond length [Å]", fontsize=15)

plt.show()

### Comparing SchNet and PaiNN prediction errors

We are computing the absolute deviation between the energy predictions and the ground truth. 

For the nonadiabatic coupling vectors, we have condensed them into a single number for each geometry by computing the norm over xyz and then the sum over all atoms. This roughly compares the magnitude of the vectors but not their orientation in space. We then plot the logarithm of the absolute deviation between these numbers, as the NACs can be quite large.

In [None]:
# Plot settings for the error maps, keyed by the npz property name.
#   vmin / vmax : colour-scale limits (for "nacs" these apply to log10 values)
#   name        : quantity name used in the panel titles
#   label       : subscript per state/coupling column. Energy columns 0, 1, 2
#                 are S0, S1, S2, matching the S1-S0 gap plot and the
#                 E_{S_i} titles of the parity plots.
#   title       : short description of the property
properties = {
    "energy": {
        "vmin": 0.0,
        "vmax": 0.1,
        "name": "Energy",
        "label": ["$_0$", "$_1$", "$_2$"],
        "title": "Energies",
    },
    "nacs": {
        "vmin": -2.0,
        "vmax": 2.0,
        "name": "NACV",
        "label": ["$_{01}$", "$_{02}$", "$_{12}$"],
        "title": "NACs",
    },
}

### Energies

In [None]:
prop = "energy"
settings = properties[prop]

# Absolute error |QM - ML| per grid point and state. Grid points whose QM value
# equals the -1 placeholder are masked so they stay blank in the images.
prop_targ = targets[prop]
prop_targ1 = np.ma.masked_where(prop_targ == -1, prop_targ)
prop_targ2 = np.ma.masked_where(prop_targ == -1, prop_targ)
# Predictions are stored flat (one row per grid point); reshape onto the grid.
prop_targ1 -= preds_painn[prop].reshape(X.shape + (3,))
prop_targ2 -= preds_schnet[prop].reshape(X.shape + (3,))
prop_targ1 = np.abs(prop_targ1)
prop_targ2 = np.abs(prop_targ2)
double_diff = prop_targ1 - prop_targ2

fig, axs = plt.subplots(2, 3, layout="constrained",
                        sharex=True, sharey=True,
                        figsize=(10, 6))

cmap = cm.Reds
targ_idx = [0, 1, 2]
# Image extent derived from the scan grid instead of hard-coded numbers.
extent = [angle_values[0], angle_values[-1], bond_values[0], bond_values[-1]]

for column in range(3):
    state = targ_idx[column]
    # Top row: PaiNN errors, bottom row: SchNet errors.
    for row, (model, error) in enumerate((("PaiNN", prop_targ1),
                                          ("SchNet", prop_targ2))):
        ax = axs[row, column]
        # origin="lower" puts the shortest bond length (grid row 0) at the bottom.
        im = ax.imshow(error[:, :, state], origin="lower", cmap=cmap,
                       aspect="auto", extent=extent,
                       vmin=settings["vmin"], vmax=settings["vmax"])
        ax.set_title(model + " - " + r"$\Delta$" + settings["name"]
                     + settings["label"][column])

        if column == 2:
            cbar = fig.colorbar(im, ax=ax)

for ax in axs[-1]:
    ax.set_xlabel("C-N rotation [°]")
for ax in axs[:, 0]:
    ax.set_ylabel("C-N bond length [Å]")

plt.show()

### NACs

In [None]:
prop = "nacs"
settings = properties[prop]

# log10 of the absolute error |QM - ML| per grid point and coupling. Grid points
# whose QM value equals the -1 placeholder are masked; np.ma.log10 also masks
# exact zeros instead of producing -inf.
prop_targ = targets[prop]
prop_targ1 = np.ma.masked_where(prop_targ == -1, prop_targ)
prop_targ2 = np.ma.masked_where(prop_targ == -1, prop_targ)
# Predictions are stored flat (one row per grid point); reshape onto the grid.
prop_targ1 -= preds_painn[prop].reshape(X.shape + (3,))
prop_targ2 -= preds_schnet[prop].reshape(X.shape + (3,))
prop_targ1 = np.ma.log10(np.abs(prop_targ1))
prop_targ2 = np.ma.log10(np.abs(prop_targ2))
double_diff = prop_targ1 - prop_targ2

fig, axs = plt.subplots(2, 3, layout="constrained",
                        sharex=True, sharey=True,
                        figsize=(10, 6))

cmap = cm.Reds
targ_idx = [0, 1, 2]
# Image extent derived from the scan grid instead of hard-coded numbers.
extent = [angle_values[0], angle_values[-1], bond_values[0], bond_values[-1]]

for column in range(3):
    coupling = targ_idx[column]
    # Top row: PaiNN errors, bottom row: SchNet errors.
    for row, (model, error) in enumerate((("PaiNN", prop_targ1),
                                          ("SchNet", prop_targ2))):
        ax = axs[row, column]
        # origin="lower" puts the shortest bond length (grid row 0) at the bottom.
        im = ax.imshow(error[:, :, coupling], origin="lower", cmap=cmap,
                       aspect="auto", extent=extent,
                       vmin=settings["vmin"], vmax=settings["vmax"])
        ax.set_title(model + " - " + r"$\Delta$" + settings["name"]
                     + settings["label"][column])

        if column == 2:
            cbar = fig.colorbar(im, ax=ax)

for ax in axs[-1]:
    ax.set_xlabel("C-N rotation [°]")
for ax in axs[:, 0]:
    ax.set_ylabel("C-N bond length [Å]")

plt.show()

## Test statistics

As a last comparison, we look at the performance of the two models on their test sets. In general, we expect decent (but not great) performance for energies and forces, and poor performance in predicting the nonadiabatic couplings, especially for SchNet.

First, we need to compute the actual predictions of the two models.

In [None]:
# Test-set indices written by each model's training run. The npz files are
# opened in a ``with`` block so their file handles are closed after reading.
with np.load("Painn_model/train_val_test_indices.npz") as split:
    test_indices_painn = split['test_idx']
with np.load("Schnet_model/train_val_test_indices.npz") as split:
    test_indices_schnet = split['test_idx']
# Size of the PaiNN test set.
n_test = len(test_indices_painn)

In [None]:
# ASE database with the QM reference data: geometries (row.positions) and
# energies / forces / smooth_nacs in row.data, as read by the inference loops.
db = connect("methylenimmonium.db")

In [None]:
# NOTE: forces and nacs have the shape (Natoms, Nstates, xyz) -> here (6, 3, 3)
# Buffers for the model predictions. Each model is evaluated on its own test
# split, so the SchNet buffers are sized by the SchNet split rather than by
# ``n_test`` (the PaiNN split length); the two splits need not be equally long.
n_test_schnet = len(test_indices_schnet)
pred_painn = {'E': np.zeros(shape=(n_test, 3)),
              'F': np.zeros(shape=(n_test, 6, 3, 3)),
              'NAC': np.zeros(shape=(n_test, 6, 3, 3))}
pred_schnet = {'E': np.zeros(shape=(n_test_schnet, 3)),
               'F': np.zeros(shape=(n_test_schnet, 6, 3, 3)),
               'NAC': np.zeros(shape=(n_test_schnet, 6, 3, 3))}

In [None]:
# Buffers for the QM targets of each model's test split. Shapes mirror the
# prediction buffers; the SchNet buffers are sized by the SchNet split rather
# than by ``n_test`` (the PaiNN split length).
n_test_schnet = len(test_indices_schnet)
targ_painn = {'E': np.zeros(shape=(n_test, 3)),
              'F': np.zeros(shape=(n_test, 6, 3, 3)),
              'NAC': np.zeros(shape=(n_test, 6, 3, 3))}
targ_schnet = {'E': np.zeros(shape=(n_test_schnet, 3)),
               'F': np.zeros(shape=(n_test_schnet, 6, 3, 3)),
               'NAC': np.zeros(shape=(n_test_schnet, 6, 3, 3))}

#### Making the predictions and getting the targets from the DB

In [None]:
# Run the PaiNN model on every test geometry and collect the matching QM targets.
calc = NacCalculator(model_file="Painn_model/best_model", neighbor_list=MatScipyNeighborList(cutoff=10.0))
atom = ase.Atoms(symbols="CNHHHH")
atom.calc = calc
# Wrap the index array (not the enumerate object) in tqdm so the progress bar
# knows the total number of iterations.
for ii, idx in enumerate(tqdm(test_indices_painn)):
    # ASE database ids are 1-based; the split indices are presumably 0-based.
    row = db.get(int(idx) + 1)
    atom.set_positions(row.positions)
    props = atom.get_properties(['energy', 'smooth_nacs', 'forces'])
    pred_painn['E'][ii] = props['energy']
    pred_painn['NAC'][ii] = props['smooth_nacs']
    pred_painn['F'][ii] = props['forces']

    targ_painn['E'][ii] = row.data['energy']
    targ_painn['NAC'][ii] = row.data['smooth_nacs']
    targ_painn['F'][ii] = row.data['forces']

In [None]:
# Run the SchNet model on every test geometry and collect the matching QM targets.
calc = NacCalculator(model_file="Schnet_model/best_model", neighbor_list=MatScipyNeighborList(cutoff=10.0))
atom = ase.Atoms(symbols="CNHHHH")
atom.calc = calc
# Wrap the index array (not the enumerate object) in tqdm so the progress bar
# knows the total number of iterations.
for ii, idx in enumerate(tqdm(test_indices_schnet)):
    # ASE database ids are 1-based; the split indices are presumably 0-based.
    row = db.get(int(idx) + 1)
    atom.set_positions(row.positions)
    props = atom.get_properties(['energy', 'smooth_nacs', 'forces'])
    pred_schnet['E'][ii] = props['energy']
    pred_schnet['NAC'][ii] = props['smooth_nacs']
    pred_schnet['F'][ii] = props['forces']

    targ_schnet['E'][ii] = row.data['energy']
    targ_schnet['NAC'][ii] = row.data['smooth_nacs']
    targ_schnet['F'][ii] = row.data['forces']

### Plotting

In [None]:
def _parity_values(targ_dict: dict, pred_dict: dict, prop: str, state: int):
    """Return flattened ``(target, prediction)`` arrays for one parity panel.

    Energies ('E') are shifted by the minimum target ground-state (column 0)
    energy and scaled by HARTREE_TO_meV per atom (the atom count is taken from
    the forces array). Forces and NACs are scaled by
    HARTREE_TO_meV / BOHR_TO_ANGSTROM, i.e. presumably converted from Ha/bohr
    to meV/Å.
    """
    if prop == 'E':
        targ_min = targ_dict['E'][:, 0].min()
        scale = HARTREE_TO_meV / targ_dict['F'].shape[1]
        targ = (targ_dict['E'][:, state] - targ_min) * scale
        pred = (pred_dict['E'][:, state] - targ_min) * scale
    else:
        scale = HARTREE_TO_meV / BOHR_TO_ANGSTROM
        targ = targ_dict[prop][:, :, state, :].ravel() * scale
        pred = pred_dict[prop][:, :, state, :].ravel() * scale
    return targ, pred


def _parity_stats(targ, pred):
    """Return ``(MAE, RMSE, R^2, Pearson r)`` of ``pred`` against ``targ``."""
    # Explicit submodule import: ``import scipy`` alone does not expose
    # ``scipy.stats`` on older SciPy releases.
    from scipy.stats import pearsonr

    err = pred - targ
    mae = np.abs(err).mean()
    rmse = np.sqrt(np.mean(err ** 2))
    return mae, rmse, r2_score(targ, pred), pearsonr(targ, pred)[0]


def _style_parity_axis(ax, labelsize, pad, tickwidth, tick_size):
    """Thicken the frame and draw inward-pointing ticks on one parity panel."""
    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_linewidth(tickwidth)
    ax.tick_params(axis='both', length=tick_size, width=tickwidth,
                   labelsize=labelsize, pad=pad, direction='in')


def make_partity_plot(targ_dict: dict,
                      pred_dict: dict,
                      title=None):
    """Draw a 3x3 grid of hexbin parity plots (target vs. prediction).

    Rows are energies, forces and NACs. Columns are the three states (S0, S1,
    S2) for energies/forces and the 01, 02, 12 couplings for NACs. Each panel
    is annotated with MAE, RMSE, R^2 and the Pearson correlation. Energy and
    force panels of a row share square axis limits; NAC panels keep their own
    autoscaled limits.

    Parameters
    ----------
    targ_dict, pred_dict : dict
        Keys 'E' with shape (n_samples, 3), and 'F' and 'NAC' with shape
        (n_samples, n_atoms, 3, 3), as allocated in the buffer cells.
    title : str, optional
        Figure title (e.g. the model name). No title is drawn when None.
    """
    properties = ['E', 'F', 'NAC']
    nac_labels = ["$_{01}$", "$_{02}$", "$_{12}$"]

    labelsize = 14
    titelsize = labelsize * 1.2
    pad = labelsize / 3
    tickwidth = 3
    maj_tick_size = 6

    units = {"E": r"meV / atom",
             "F": r"meV / Å",
             "NAC": r"meV / Å"}

    fig, axs = plt.subplots(3, 3, figsize=(10, 10), layout="constrained")
    text_size = labelsize * 0.9

    for row, prop in enumerate(properties):
        # Joint data range of the row, used for the shared axis limits.
        lo, hi = np.inf, -np.inf
        for state in range(3):
            targ, pred = _parity_values(targ_dict, pred_dict, prop, state)
            lo = min(lo, targ.min(), pred.min())
            hi = max(hi, targ.max(), pred.max())

            mae, rmse, R2, pearson_r = _parity_stats(targ, pred)

            ax = axs[row, state]
            ax.hexbin(targ, pred,
                      cmap='Blues',
                      gridsize=25,
                      mincnt=1,
                      bins="log",
                      edgecolors=None,
                      linewidths=(0.2,),
                      xscale="linear",
                      yscale="linear")

            text_kw = dict(transform=ax.transAxes, fontsize=text_size, zorder=10)
            ax.text(0.05, 0.9, 'MAE: %.2f' % (mae), **text_kw)
            ax.text(0.05, 0.8, 'RMSE: %.2f' % (rmse), **text_kw)
            ax.text(0.6, 0.2, r'$R^2$: %.2f' % (R2), **text_kw)
            ax.text(0.6, 0.1, r'$\rho$: %.2f' % (pearson_r), **text_kw)

        shared_limits = prop != 'NAC'
        if shared_limits:
            # Pad by 5 % of the data range; unlike scaling the bounds by 1.1,
            # this never cuts off data when the minimum is positive.
            margin = 0.05 * (hi - lo) or 1.0
            limits = (lo - margin, hi + margin)

        for col, ax in enumerate(axs[row]):
            if shared_limits:
                ax.set_xlim(limits)
                ax.set_ylim(limits)
            _style_parity_axis(ax, labelsize, pad, tickwidth, maj_tick_size)
            # Inner y tick labels are redundant only when the row shares limits.
            if col > 0 and shared_limits:
                ax.tick_params(axis='y', labelleft=False)

    # Panel titles and axis labels, set once after all panels are drawn.
    for col in range(3):
        axs[0, col].set_title("$E_{S_%i}$" % col, fontsize=titelsize, pad=2*pad)
        axs[1, col].set_title(r"$\partial E_{S_%i} / \partial \mathbf{R}$" % col,
                              fontsize=titelsize, pad=2*pad)
        axs[2, col].set_title("NAC" + nac_labels[col], fontsize=titelsize, pad=2*pad)

    for row, prop in enumerate(properties):
        for ax in axs[row]:
            ax.set_xlabel(f"Target [{units[prop]}]", fontsize=labelsize)
        axs[row, 0].set_ylabel(f"Prediction [{units[prop]}]", fontsize=labelsize)

    if title is not None:
        fig.suptitle(title, fontsize=titelsize)

    plt.show()

#### SchNet

In [None]:
# Parity plots for the SchNet model on its own test set.
make_partity_plot(targ_schnet, pred_schnet)

#### PaiNN

In [None]:
# Parity plots for the PaiNN model on its own test set.
make_partity_plot(targ_painn, pred_painn)