# Validation with model.predict

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from pinn import get_model
from pinn.io import load_ds, sparse_batch
# Per-element reference energies, taken from
# https://gitlab.com/matschreiner/Transition1x/-/blob/main/transition1x/dataloader.py
def get_ref(datum, unit=1):
    """Return the summed per-element reference energy of one structure.

    Args:
        datum: mapping with an ``'elems'`` entry holding the element keys of
            the structure (any iterable of ints, e.g. a NumPy array).
        unit: multiplicative conversion factor applied to the result. With
            the default of 1 the value is in the units of ``REF_ATOM``
            (presumably eV, as in the Transition1x loader -- confirm).

    Returns:
        Sum of ``REF_ATOM[e]`` over all elements ``e``, multiplied by ``unit``.

    Raises:
        KeyError: if an element key other than 1, 6, 7, 8 or 9 is present.
    """
    REF_ATOM = {
        1: -13.62222753701504,
        6: -1029.4130839658328,
        7: -1484.8710358098756,
        8: -2041.8396277138045,
        9: -2712.8213146878606,
    }
    ref = sum(REF_ATOM[e] for e in datum['elems'])
    # `unit` was previously accepted but ignored; apply it as a scale factor.
    return ref * unit

## Getting labels and predictions

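- labels (`*_data`) are loaded as in `validate_ds.ipynb`
- the predictor requires the dataset as a function that builds it (here a `lambda` expression), not as a dataset object
- optionally, evaluation can be restricted to a subset (here the first 500 structures, via `.take(500)`)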

In [None]:
# Trained model under evaluation.
model = get_model('../t1x_trial/benchmark/models/qm9-pinet-1/model/')


def ds():
    """Build the evaluation input pipeline.

    `model.predict` expects a callable that creates the dataset rather than a
    dataset object, so the pipeline is rebuilt on every call: structures are
    batched one at a time and only the first 500 batches are kept.
    """
    dataset = load_ds('../t1x_trial/datasets/final.yml')
    return dataset.apply(sparse_batch(1)).take(500)

In [None]:
# Gather reference energies and labels in a single pass over the dataset, so
# the three arrays are aligned sample-by-sample by construction and the input
# pipeline is read once instead of three times.
e_ref, e_data, f_data = [], [], []
for datum in ds().as_numpy_iterator():
    e_ref.append(get_ref(datum))
    # squeeze() presumably drops the size-1 batch axis from sparse_batch(1) -- confirm
    e_data.append(datum['e_data'].squeeze())
    f_data.append(datum['f_data'].flatten())
e_ref = np.array(e_ref)
e_data = np.array(e_data)
f_data = np.concatenate(f_data)

In [None]:
# Run inference. `model.predict` receives the dataset *function* and yields one
# prediction dict per item; progress is printed every 100 predictions.
predictor = model.predict(ds)
e_pred = []
f_pred_3d = []  # per-item force arrays, kept unflattened for the export step
for cnt, pred in enumerate(predictor, start=1):
    if cnt % 100 == 0:
        print(f'\r{cnt}', end='')
    e_pred.append(pred['energy'])
    f_pred_3d.append(pred['forces'])
print()  # terminate the carriage-return progress line

e_pred = np.array(e_pred)
# Flattened predicted forces, concatenated in the same order as `f_data`.
f_pred = np.concatenate([frc.flatten() for frc in f_pred_3d])

## Plotting energy and force errors

Create subplots in matplotlib with [plt.subplots]

[plt.subplots]: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html

In [None]:
# Parity plots of predicted vs. reference values. Energies are shown relative
# to the summed per-element reference energies `e_ref`. np.ravel guards
# against an assumed possible (N, 1) vs (N,) shape mismatch -- confirm shapes.
fig, axs = plt.subplots(1, 2, figsize=[8, 4])

panels = [
    (axs[0], np.ravel(e_data) - e_ref, np.ravel(e_pred) - e_ref, 'energy'),
    (axs[1], f_data, f_pred, 'force component'),
]
for ax, ref_vals, pred_vals, label in panels:
    ax.scatter(ref_vals, pred_vals)
    # y = x guide: points on this line are perfect predictions.
    lo = min(ref_vals.min(), pred_vals.min())
    hi = max(ref_vals.max(), pred_vals.max())
    ax.plot([lo, hi], [lo, hi], 'k--', lw=1)
    ax.set_xlabel(f'reference {label}')
    ax.set_ylabel(f'predicted {label}')
fig.tight_layout()

## Reporting numbers (**check the units!**)


In [None]:
# Root-mean-square errors. The mean of the squared error alone is the MSE; the
# square root is needed to report an RMSE in the same units as the data.
# np.ravel guards against a trailing singleton axis on either side (e.g.
# energies of shape (N, 1) vs (N,)), which would otherwise broadcast into an
# (N, N) difference matrix -- the exact shapes here are assumed, confirm.
# TODO: replace the [UNIT] placeholders once the dataset units are confirmed.
f_rmse = np.sqrt(np.mean((np.ravel(f_data) - np.ravel(f_pred)) ** 2))
e_rmse = np.sqrt(np.mean((np.ravel(e_data) - np.ravel(e_pred)) ** 2))
print(f'E RMSE {e_rmse} [UNIT];  F RMSE {f_rmse} [UNIT].')

## Automated generation of numbers

```python
from glob import glob
all_models = glob('../t1x_trial/benchmark/models/*/model/')
for model_dir in all_models:
    ...
```

## Export trajectories

Write the structures to an extended XYZ file, adding the predicted forces as an extra per-atom column (`forces_pred`). The result can be visualized with OVITO: map `forces_pred` to a vectorial property such as dipole, and adjust the corresponding visual element accordingly.

In [None]:
# Patch objects that append an extra property, `forces_pred`, to ASE's
# property lists for the singlepoint calculator and the extxyz writer, so the
# predicted forces can be attached and written as an extra column. The patches
# are only created here; they take effect inside a `with` statement.
from unittest.mock import patch  # stdlib; same API as the third-party `mock` backport

from ase.calculators.calculator import all_properties
from ase.io.extxyz import per_atom_properties

extra_properties = ['forces_pred']
all_prop_patch = patch("ase.io.extxyz.all_properties", all_properties + extra_properties)
atom_prop_patch = patch("ase.io.extxyz.per_atom_properties", per_atom_properties + extra_properties)
sp_patch = patch("ase.calculators.singlepoint.all_properties", all_properties + extra_properties)

In [None]:
# Export every evaluated structure to `export.xyz`. Reference forces are
# stored as `forces` and model predictions as the extra `forces_pred`
# property; the ASE patches must be active while the calculators are built
# and while the file is written.
from itertools import zip_longest

with sp_patch, atom_prop_patch, all_prop_patch:
    from ase.calculators.singlepoint import SinglePointCalculator
    from ase import Atoms
    from ase.io import write

    _missing = object()  # sentinel marking an exhausted input in zip_longest
    traj = []
    # Descriptive loop names avoid rebinding the notebook-level `d` / `f`.
    for datum, forces_pred in zip_longest(ds(), f_pred_3d, fillvalue=_missing):
        # A plain zip() would silently truncate and leave the file misaligned.
        if datum is _missing or forces_pred is _missing:
            raise ValueError('dataset and predictions have different lengths')
        atoms = Atoms(datum['elems'].numpy(), positions=datum['coord'].numpy())
        atoms.calc = SinglePointCalculator(
            atoms,
            forces=datum['f_data'].numpy(),
            forces_pred=forces_pred)
        traj.append(atoms)
    write('export.xyz', traj)