In [None]:
%%bash
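# Patch spainn's aseinterface.py so that model outputs are moved to the CPU
# (.cpu()) before .detach(); this is needed when the model runs on the GPU
# (device="cuda"), because CUDA tensors cannot be turned into NumPy arrays directly.
# The grep guard makes the patch idempotent: it is applied only once.
FILE=/opt/conda/envs/T7/lib/python3.10/site-packages/spainn/interface/aseinterface.py

if ! grep -qF ".cpu()" "$FILE"
then
   sudo sed -i 's/\.detach()/.cpu().detach()/g' "$FILE"
fi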

# 🧠 Notebook 2: Loading Pretrained ML Models — SchNet and PaiNN
## Making predictions for all geometries on the 2D grid

In this notebook, we load two pretrained machine learning models that have been trained to predict **multiple adiabatic potential energy surfaces** and **non-adiabatic couplings** for $\mathrm{CH_2NH_2^+}$. These models are based on two state-of-the-art neural network architectures: **SchNet** and **PaiNN**.

### 🔷 SchNet

**SchNet** is one of the earliest deep learning architectures tailored for atomistic systems. It takes atomic positions and nuclear charges directly as input and uses continuous-filter convolutional layers to model interactions between atoms based on their interatomic distances. This makes it:

- **Invariant to rotations and translations**,
- **Well suited for scalar properties** like energies (forces follow as negative gradients of the energy).

However, SchNet does **not** explicitly encode directional (vectorial) information, which can be a limitation when learning vector-valued quantities like **non-adiabatic couplings**.
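To make the "continuous-filter convolution" concrete, here is a heavily simplified NumPy sketch of one SchNet-style interaction. Assumptions: the filter network is a single linear layer on Gaussian radial basis functions followed by SchNet's shifted softplus, a schnetpack-style cosine cutoff is applied, and the weights are random. Function names are illustrative only, not schnetpack's API. Because only interatomic distances enter, rotating, reflecting, or translating the molecule leaves the output unchanged.

```python
import numpy as np


def rbf_expand(d, n_rbf=20, cutoff=5.0, gamma=10.0):
    """Gaussian radial basis expansion of distances: (N, N) -> (N, N, n_rbf)."""
    centers = np.linspace(0.0, cutoff, n_rbf)
    return np.exp(-gamma * (d[..., None] - centers) ** 2)


def shifted_softplus(z):
    """ssp(z) = ln(0.5 e^z + 0.5), the activation used in SchNet."""
    return np.logaddexp(z, 0.0) - np.log(2.0)


def cfconv(x, positions, w_filter, cutoff=5.0):
    """Simplified continuous-filter convolution.

    x:         atom features, shape (N, F)
    positions: Cartesian coordinates, shape (N, 3)
    w_filter:  filter weights, shape (n_rbf, F)
    """
    diff = positions[:, None, :] - positions[None, :, :]
    d = np.linalg.norm(diff, axis=-1)  # (N, N): only distances are used from here on
    filters = shifted_softplus(rbf_expand(d, w_filter.shape[0], cutoff) @ w_filter)
    # Cosine cutoff: interactions fade smoothly to zero at the cutoff radius
    fcut = 0.5 * (np.cos(np.pi * d / cutoff) + 1.0) * (d < cutoff)
    np.fill_diagonal(fcut, 0.0)  # no self-interaction
    # Message to atom i: sum over neighbours j of x_j scaled by the filter W(d_ij)
    return np.einsum("ij,ijf,jf->if", fcut, filters, x)


rng = np.random.default_rng(0)
pos = rng.normal(size=(6, 3))  # 6 atoms, as in CH2NH2+
x = rng.normal(size=(6, 8))    # 8 features per atom
w = rng.normal(size=(20, 8))
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))        # random orthogonal matrix
moved = pos @ q.T + np.array([1.0, -2.0, 0.5])     # rotate/reflect, then translate
print(np.allclose(cfconv(x, pos, w), cfconv(x, moved, w)))  # True
```

The same invariance is exactly why the architecture has no direct handle on directions: any vector output must be built indirectly, for example as a gradient of a scalar.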

### 🔷 PaiNN

**PaiNN** (Polarizable Atom Interaction Neural Network) builds on the SchNet message-passing paradigm but **adds equivariant vector features**: each atom carries both scalar and vector features, and all operations are **equivariant under 3D rotations**, so the vector features rotate together with the molecule. This makes PaiNN especially well-suited for:

- Predicting **vector-valued properties** (e.g., dipole moments, NACs),
- Preserving rotational symmetry while still encoding orientation-dependent interactions.

In other words: PaiNN knows how to “feel” directionality, while SchNet treats the system more like a set of interatomic distances.
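The following NumPy sketch shows the vector part of a PaiNN-style message step in a heavily simplified form. Assumptions: the filters are a single linear layer on Gaussian radial basis functions, and the scalar MLP φ(s_j) of the real architecture is replaced by the raw scalars s_j. Names such as `painn_vector_message` are hypothetical. The two terms follow the PaiNN idea: neighbour vectors gated by invariant scalars, plus new vectors pointing along the bond directions r_ij/|r_ij|. The final check confirms that rotating the geometry and the input vectors rotates the output in the same way.

```python
import numpy as np


def rbf_expand(d, n_rbf=20, cutoff=5.0, gamma=10.0):
    """Gaussian radial basis expansion of distances: (N, N) -> (N, N, n_rbf)."""
    centers = np.linspace(0.0, cutoff, n_rbf)
    return np.exp(-gamma * (d[..., None] - centers) ** 2)


def painn_vector_message(s, v, positions, w_filter, cutoff=5.0):
    """Simplified PaiNN-style vector message.

    s:         scalar features, shape (N, F)
    v:         vector features, shape (N, 3, F)
    positions: Cartesian coordinates, shape (N, 3)
    w_filter:  filter weights, shape (n_rbf, 2F)
    Returns the vector update dv, shape (N, 3, F).
    """
    n_atoms, n_feat = s.shape
    diff = positions[None, :, :] - positions[:, None, :]   # r_ij = r_j - r_i
    d = np.linalg.norm(diff, axis=-1)                       # (N, N)
    unit = diff / (d + np.eye(n_atoms))[..., None]          # r_ij/|r_ij|, 0 on diagonal
    mask = 0.5 * (np.cos(np.pi * d / cutoff) + 1.0) * (d < cutoff)
    np.fill_diagonal(mask, 0.0)                             # no self-messages
    filters = rbf_expand(d, w_filter.shape[0], cutoff) @ w_filter
    w_vv, w_vs = filters[..., :n_feat], filters[..., n_feat:]  # each (N, N, F)
    # Term 1: neighbour vectors v_j scaled by invariant scalars (rotate with v_j)
    term_vv = np.einsum("ij,ijf,jf,jcf->icf", mask, w_vv, s, v)
    # Term 2: new vectors along the bond directions (rotate with the geometry)
    term_vs = np.einsum("ij,ijf,jf,ijc->icf", mask, w_vs, s, unit)
    return term_vv + term_vs


def random_rotation(rng):
    """Random proper rotation matrix (orthogonal, det = +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1.0
    return q


rng = np.random.default_rng(1)
n_atoms, n_feat = 6, 4
pos = rng.normal(size=(n_atoms, 3))
s = rng.normal(size=(n_atoms, n_feat))
v = rng.normal(size=(n_atoms, 3, n_feat))
w = rng.normal(size=(20, 2 * n_feat))
R = random_rotation(rng)

dv = painn_vector_message(s, v, pos, w)
# Rotate geometry and input vectors, then compare with the rotated output
dv_rot = painn_vector_message(s, np.einsum("ab,nbf->naf", R, v), pos @ R.T, w)
print(np.allclose(dv_rot, np.einsum("ab,nbf->naf", R, dv)))  # True: equivariant
```

Equivariant vector features like these are what make a direct prediction of vector-valued quantities such as NACs natural.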

---

### 🧰 What We’ll Do in This Notebook

- Load the pretrained SchNet and PaiNN models,
- Apply them to all geometries of the 2D grid stored in `grid_configuration.db`,
- Extract the predicted **energies** and **non-adiabatic couplings** for the three lowest singlet states.

The models were trained on the 4000 MRCI data points published in Chem. Sci., 2019, 10, 8100, DOI: [10.1039/c9sc01742a](https://doi.org/10.1039/c9sc01742a).

This allows us to compare how the two models represent the electronic structure of $\mathrm{CH_2NH_2^+}$ and to explore where their predictions agree and where they differ.


In [None]:
from ase.db import connect
from spainn.interface import NacCalculator
from schnetpack.transform import MatScipyNeighborList
import ase
import numpy as np
from tqdm import tqdm

### 🔧 SpaiNN: Machine-Learned Excited-State Dynamics with SHARC

In this notebook, we use **SpaiNN**, a specialized interface that makes schnetpack models (here both **PaiNN** and **SchNet**) usable in nonadiabatic dynamics simulations. Like the SchNarc interface for SchNet models, SpaiNN is designed to bridge **machine learning models** with **excited-state dynamics tools**.

The core functionality of SpaiNN is to provide an **ASE-style calculator** that:

- Loads models trained with **schnetpack**,
- Exposes the predictions (energies, forces, and non-adiabatic couplings) as standard **ASE `Atoms` properties**,
- Makes it possible to integrate ML-based predictions seamlessly into molecular dynamics workflows.

This makes the models usable from any ASE-based tool and, importantly, from **PySHARC**, the Python-based interface to the **SHARC** (Surface Hopping including ARbitrary Couplings) dynamics engine.

Unlike the older SHARC versions, **PySHARC** avoids file-based communication wherever possible. This makes it especially suitable for coupling with **ML models**, where predictions are computed rapidly and do not rely on disk I/O. The SpaiNN interface is thus a key component for running **efficient, on-the-fly nonadiabatic dynamics** powered entirely by ML potentials.

In this notebook, we focus on **loading pretrained models through SpaiNN** and demonstrating how they produce predictions in a form directly usable by ASE and PySHARC.

---

### Predictions for the PaiNN model

In [None]:
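# Load the pretrained PaiNN model; device="cuda" runs inference on the GPU
calc = NacCalculator(
    model_file="Painn_model/best_model",
    neighbor_list=MatScipyNeighborList(cutoff=10.0),  # 10 Å neighbor-list cutoff
    device="cuda",
)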
atom = ase.Atoms(symbols="CNHHHH")
atom.calc = calc
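The calculator above assumes that a CUDA GPU is available. A small helper could choose the device automatically; this sketch assumes that `NacCalculator` also accepts other PyTorch device strings such as `"cpu"`, which this notebook does not show. `pick_device` is a hypothetical name.

```python
def pick_device():
    """Return "cuda" if PyTorch sees a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no GPU backend to use
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
```

It would then be passed as `device=pick_device()` when constructing `NacCalculator`.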

In [None]:
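db = connect("grid_configuration.db")
n_geoms = len(db)  # number of grid geometries (9191 here)
pred_energies = np.zeros(shape=(n_geoms, 3))  # 3 electronic states
pred_nacs     = np.zeros(shape=(n_geoms, 3))  # 3 coupled state pairs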

In [None]:
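for idx, row in enumerate(tqdm(db.select(sort="id"), total=len(db))):
    atom.set_positions(row.positions)
    # NOTE: forces have shape (Natoms, Nstates, xyz); NACs have shape
    # (Natoms, Npairs, xyz) with one coupling vector per pair of states.
    # With 3 states there are 3 pairs, so both are (6, 3, 3) here.
    props = atom.get_properties(['energy', 'smooth_nacs'])
    pred_energies[idx] = props['energy']
    # Norm over xyz, then sum over atoms -> one scalar per state pair
    pred_nacs[idx] = np.sum(np.linalg.norm(props['smooth_nacs'], axis=2), axis=0)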
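The last line of the loop reduces each (6, 3, 3) NAC array to three numbers. A tiny example with made-up vectors shows what it computes: the Euclidean length of every atomic coupling vector, summed over atoms for each state pair. The norm discards the direction and the arbitrary overall sign of each coupling vector, so the result is a measure of coupling strength only.

```python
import numpy as np

# Made-up NAC array for one geometry: (Natoms=6, Npairs=3, xyz=3)
nacs = np.zeros((6, 3, 3))
nacs[0, 0] = [3.0, 4.0, 0.0]   # length 5
nacs[1, 0] = [0.0, 0.0, -2.0]  # length 2 (the sign is lost by the norm)
nacs[2, 1] = [1.0, 2.0, 2.0]   # length 3

per_atom = np.linalg.norm(nacs, axis=2)  # (6, 3): length of each atomic vector
per_pair = per_atom.sum(axis=0)          # (3,): one scalar per state pair
print(per_pair)                          # [7. 3. 0.]
```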

In [None]:
# Save the predictions for later analysis
np.savez("Predictions_Painn.npz", energy=pred_energies, nacs=pred_nacs)
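For the later analysis, the keyword names passed to `np.savez` (`energy`, `nacs`) become the keys of the `.npz` archive. A minimal pair of helpers (hypothetical names) that writes and reads files in this layout might look like this:

```python
import numpy as np


def save_predictions(path, energies, nacs):
    """Store predictions under the same keys used in this notebook."""
    np.savez(path, energy=energies, nacs=nacs)


def load_predictions(path):
    """Return (energies, nacs) from an .npz file with keys 'energy' and 'nacs'."""
    # The context manager closes the underlying file after the arrays are read
    with np.load(path) as data:
        return data["energy"], data["nacs"]
```

For example, `energies, nacs = load_predictions("Predictions_Painn.npz")` would recover arrays of shape (9191, 3) from the file written above.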

### Predictions for the SchNet model

In [None]:
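# Load the pretrained SchNet model; device="cuda" runs inference on the GPU
calc = NacCalculator(
    model_file="Schnet_model/best_model",
    neighbor_list=MatScipyNeighborList(cutoff=10.0),  # 10 Å neighbor-list cutoff
    device="cuda",
)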
atom = ase.Atoms(symbols="CNHHHH")
atom.calc = calc

In [None]:
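db = connect("grid_configuration.db")
n_geoms = len(db)  # number of grid geometries (9191 here)
pred_energies = np.zeros(shape=(n_geoms, 3))  # 3 electronic states
pred_nacs     = np.zeros(shape=(n_geoms, 3))  # 3 coupled state pairs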

In [None]:
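for idx, row in enumerate(tqdm(db.select(sort="id"), total=len(db))):
    atom.set_positions(row.positions)
    # NOTE: forces have shape (Natoms, Nstates, xyz); NACs have shape
    # (Natoms, Npairs, xyz) with one coupling vector per pair of states.
    # With 3 states there are 3 pairs, so both are (6, 3, 3) here.
    props = atom.get_properties(['energy', 'smooth_nacs'])
    pred_energies[idx] = props['energy']
    # Norm over xyz, then sum over atoms -> one scalar per state pair
    pred_nacs[idx] = np.sum(np.linalg.norm(props['smooth_nacs'], axis=2), axis=0)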

In [None]:
# Save the predictions for later analysis
np.savez("Predictions_Schnet.npz", energy=pred_energies, nacs=pred_nacs)