In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before
# each cell execution, so edits to the local helper modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

# Make float64 the default floating-point dtype for newly created tensors.
torch.set_default_dtype(torch.float64)

# Pick the CUDA device with index 2 when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")

In [None]:
from os import path
from glob import glob
import numpy as np
import matplotlib.pylab as plt

#from se3cnn.point_utils import difference_matrix
from ole_se3cnn import difference_matrix
from pao_file_utils import parse_pao_file, write_pao_file
from pao_tfn_dataset import encode_kind
from pao_tfn_trainer import train_pao_tfn, loss_function
from cp2k_file_utils import read_energy

In [None]:
# Collect the PAO files of all frames in sorted order; the basis-set metadata
# below is hard-coded rather than read from the files.
pao_files = glob("2H2O_MD/frame_*/2H2O_pao44-1_0.pao")
pao_files.sort()

# Number of primitive basis shells per element, listed as [s, p, d].
prim_basis_shells = {
    "H": [2, 1, 0],  # 2 s-shells, 1 p-shell, 0 d-shells
    "O": [2, 2, 1],  # 2 s-shells, 2 p-shells, 1 d-shell
}

pao_basis_size = 4

In [None]:
# Train one network per element (Hydrogen and Oxygen) on the leading frames.
train_params = {
    "prim_basis_shells": prim_basis_shells,
    "pao_basis_size": pao_basis_size,
    "pao_files": pao_files[:10],  # frames used as training data
    "num_hidden": 1,              # hidden layers
    "max_epochs": 201,            # training epochs
}

# Both networks share the same settings and differ only in the atomic kind.
net_H = train_pao_tfn(**train_params, kind_name="H")
net_O = train_pao_tfn(**train_params, kind_name="O")

In [None]:
# Use the trained networks to infer xblocks for all frames and calculate their loss.
#
# For every frame this cell:
#   1. parses the reference PAO file,
#   2. evaluates the network of each atomic kind present in the frame,
#   3. accumulates the per-atom loss (averaged over the atoms of the frame),
#   4. writes the inferred xblocks to "inferred.pao" next to the input file.

# Dispatch table from kind name to trained network. An unknown kind raises
# instead of silently reusing the output of the previous atom.
nets_by_kind = {"H": net_H, "O": net_O}

losses = []  # one averaged loss value per frame, in pao_files order

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for fn in pao_files:
        kinds, atom2kind, coords, xblocks = parse_pao_file(fn)
        kind_onehot = encode_kind(atom2kind)
        natoms = coords.shape[0]

        # The network inputs depend only on the frame, not on the atom,
        # so they are built once per frame. [None, ...] adds a leading axis
        # (presumably the batch dimension -- TODO confirm against the trainer).
        kind_onehot_torch = torch.as_tensor(kind_onehot)[None, ...]
        coords_torch = torch.as_tensor(coords)[None, ...]
        diff_M = difference_matrix(coords_torch)

        # Each network is evaluated at most once per frame and its output is
        # reused for all atoms of that kind.
        # NOTE(review): assumes the forward pass is deterministic for
        # identical inputs (e.g. no active dropout) -- confirm.
        outputs_by_kind = {}

        frame_loss = 0.0
        xblocks_inferred = []
        for iatom in range(natoms):
            kind = atom2kind[iatom]
            if kind not in nets_by_kind:
                raise ValueError(
                    "No trained network for kind {!r} (atom {} in {})".format(kind, iatom, fn)
                )
            if kind not in outputs_by_kind:
                outputs_by_kind[kind] = nets_by_kind[kind](kind_onehot_torch, diff_M)

            # The last axis of the network output is indexed by atom.
            xblock_net = outputs_by_kind[kind][..., iatom]
            # .cpu() keeps the NumPy conversion valid if the output is not on the CPU.
            xblocks_inferred.append(xblock_net.detach().cpu().numpy()[0, ...])

            loss = loss_function(xblock_net, torch.as_tensor(xblocks[iatom][None, ...]))
            frame_loss += loss.item() / natoms
        losses.append(frame_loss)

        # path.join avoids writing to "/inferred.pao" should dirname ever be empty.
        fn_inferred = path.join(path.dirname(fn), "inferred.pao")
        write_pao_file(fn_inferred, kinds, atom2kind, coords, xblocks_inferred)

# Per-frame loss on a logarithmic axis.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel("Frame")
ax.set_ylabel("Loss")
ax.set_yscale("log");

In [None]:
# Run CP2K on a few of the inferred frames. DFT Energy difference below 1 milliHartree would be nice.
verify_frames = range(20)

for iframe in verify_frames:
    frame_dir = path.dirname(pao_files[iframe])
    ! cd $frame_dir; /opt/cp2k/exe/local/cp2k.ssmp -i 2H2O_pao44_inferred.inp > 2H2O_pao44_inferred.out
    pao_energy = read_energy(frame_dir+"/2H2O_pao44.out")
    pao_ml_energy = read_energy(frame_dir+"/2H2O_pao44_inferred.out")
    energy_diff = pao_ml_energy - pao_energy
    tmpl = "Frame: {} Dir: {} Loss: {:0.4e} Energy-diff: {:0.4e} Hartree"
    print(tmpl.format(iframe, frame_dir, losses[iframe], energy_diff))