In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell runs, so edits to the project's helper modules
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
# Make float64 the default floating point type for newly created tensors.
torch.set_default_dtype(torch.float64)

# Select CUDA device index 2 when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")

In [32]:
from os import path
from glob import glob
import numpy as np
from pao_tfn_trainer import train_pao_tfn
from se3cnn.point_utils import difference_matrix
from pao_file_utils import parse_pao_file, write_pao_file

In [43]:
# Primary basis set is assumed to be MOLOPT-DZVP.
# Each entry lists the number of shells per angular momentum as [s, p, d].
prim_basis_shells = {
    'H': [2, 1, 0], # two s-shells, one p-shell, no d-shells
    'O': [2, 2, 1], # two s-shells, two p-shells, one d-shell
}
# Passed straight to the trainer; presumably the number of PAO basis
# functions per atom -- TODO confirm against pao_tfn_trainer.
pao_basis_size = 4

# Sorted so that the frame order (and therefore the training subset below)
# is deterministic across runs.
pao_files = sorted(glob("2H2O_MD/frame_*/2H2O_pao44-1_0.pao"))
if not pao_files:
    # Fail fast: an empty glob would otherwise yield an empty training set.
    raise FileNotFoundError(
        "No PAO files found matching 2H2O_MD/frame_*/2H2O_pao44-1_0.pao "
        "(is the notebook running from the expected working directory?)"
    )

# Train on the first frame only.
training_pao_files = pao_files[:1]
print(training_pao_files)

# One network per atomic kind, both trained on the same files.
net_H = train_pao_tfn(training_pao_files, prim_basis_shells, pao_basis_size, kind_name="H", max_epochs=500)
net_O = train_pao_tfn(training_pao_files, prim_basis_shells, pao_basis_size, kind_name="O", max_epochs=500)

['2H2O_MD/frame_0000/2H2O_pao44-1_0.pao']
Training net for kind H using 4 smaples.
Epoch: 0  Missmatch: 7.010798  Penalty: 1.963789 Loss: 8.974587
Epoch: 10  Missmatch: 3.846207  Penalty: 1.832476 Loss: 5.678683
Epoch: 20  Missmatch: 2.282117  Penalty: 1.744603 Loss: 4.026720
Epoch: 30  Missmatch: 1.443754  Penalty: 1.698261 Loss: 3.142015
Epoch: 40  Missmatch: 1.237001  Penalty: 1.592752 Loss: 2.829754
Epoch: 50  Missmatch: 0.956868  Penalty: 1.607050 Loss: 2.563918
Epoch: 60  Missmatch: 0.710115  Penalty: 1.580735 Loss: 2.290850
Epoch: 70  Missmatch: 0.378234  Penalty: 1.490734 Loss: 1.868968
Epoch: 80  Missmatch: 0.253044  Penalty: 1.436786 Loss: 1.689830
Epoch: 90  Missmatch: 0.145495  Penalty: 1.335010 Loss: 1.480505
Epoch: 100  Missmatch: 0.205658  Penalty: 1.317555 Loss: 1.523213
Epoch: 110  Missmatch: 0.155560  Penalty: 1.237505 Loss: 1.393065
Epoch: 120  Missmatch: 0.158083  Penalty: 1.181922 Loss: 1.340004
Epoch: 130  Missmatch: 0.158004  Penalty: 1.037242 Loss: 1.195246
Epoc

In [44]:
for fn in pao_files:
    kinds, atom2kind, coords, xblocks = parse_pao_file(fn)
    kind_onehot = train_dataset.encode_kind(atom2kind)  # TODO: store kind-encoding in the net
    natoms = coords.shape[0]
    xblocks = []
    for iatom in range(natoms):
        rolled_kinds = np.roll(kind_onehot, shift=-iatom, axis=1)
        rolled_coords =  np.roll(coords, shift=-iatom, axis=0)
        rolled_kinds_torch = torch.from_numpy(rolled_kinds[None,...])
        rolled_coords_torch = torch.from_numpy(rolled_coords[None,...])
        diff_M = difference_matrix(rolled_coords_torch)
        if atom2kind[iatom] == "H":
            output_net = net_H(rolled_kinds_torch, diff_M)
            xblock_net = net_H.decode_xblock(output_net[0,:,0])
        elif atom2kind[iatom] == "O":
            output_net = net_O(rolled_kinds_torch, diff_M)
            xblock_net = net_O.decode_xblock(output_net[0,:,0])
        xblocks.append(xblock_net.detach().numpy())

    frame_dir = path.dirname(fn)
    fn_inferred = frame_dir + "/inferred.pao"
    write_pao_file(coords, xblocks, fn_inferred)
    print("Wrote "+fn_inferred)
    ! cd $frame_dir; /opt/cp2k/exe/local/cp2k.sopt -i 2H2O_pao44_inferred.inp > 2H2O_pao44_inferred.out
    break

Wrote 2H2O_MD/frame_0000/inferred.pao
