# Demo for PAO-ML via e3nn

See also https://docs.e3nn.org/en/latest/guide/convolution.html

## Requirements:
```
pip install --upgrade e3nn torch_cluster torch_scatter matplotlib
```

In [None]:
from pao_file_utils import parse_pao_file, write_pao_file, read_cp2k_energy
import torch
from e3nn import o3, nn
from e3nn.math import soft_one_hot_linspace
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import warnings
# Short alias for torch.tensor, used by the cells below to build tensors.
t = torch.tensor

In [None]:
# Convenience wrapper around parse_pao_file that converts the coordinates and
# the per-atom xblocks into float32 torch tensors.
def parse_pao_file_torch(path: Path):
    """Parse a .pao file and return its contents with torch tensors.

    Returns ``(kinds, atom2kind, coords, xblocks)`` in the same order as
    ``parse_pao_file``, except that ``coords`` becomes a float32 tensor and
    every entry of ``xblocks`` becomes a float32 tensor.
    """
    kinds, atom2kind, raw_coords, raw_xblocks = parse_pao_file(path)
    coords_tensor = torch.tensor(raw_coords, dtype=torch.float32)
    xblock_tensors = [torch.tensor(block, dtype=torch.float32) for block in raw_xblocks]
    return kinds, atom2kind, coords_tensor, xblock_tensors

In [None]:
# Load a single training sample: atom kinds, per-atom kind, coordinates and xblocks.
kinds, atom2kind, coords, xblocks = parse_pao_file_torch(
    Path("2H2O_rotations") / "phi_00" / "2H2O_pao44-1_0.pao"
)

In [None]:
TRAINING_KIND = "H"  # The atom kind for which we're training.
# Indices of all atoms of that kind in the training sample; the model is trained
# to predict the xblock of each of them.
TRAINING_ATOMS = [i for i, kind in enumerate(atom2kind) if kind == TRAINING_KIND]
# Every selected atom has TRAINING_KIND by construction, so the meaningful check
# is that the sample contains at least one such atom (otherwise nothing is trained).
assert TRAINING_ATOMS, f"No atoms of kind {TRAINING_KIND!r} in the training sample"

In [None]:
# Irreps Input
# Per-atom input features: two scalars (2x0e) holding a one-hot of the atom
# kind (is_hydrogen, is_oxygen), matching the f_in construction further down.
irreps_input = o3.Irreps("2x0e") # features: is_hydrogen, is_oxygen
# Disabled alternative (HACK): one-hot of the atom index over 6 atoms as features.
#irreps_input = o3.Irreps("6x0e") # feature: atom index as one hot !!! THIS IS A BIG HACK !!!

In [None]:
# Irreps Output
# The model predicts each xblock as a (pao_basis_size, prim_basis_size) matrix
# whose rows transform like the primary basis of TRAINING_KIND.
pao_basis_size = 4
prim_basis_specs = {
    "O": "2x0e + 2x1o + 1x2e",  # DZVP-MOLOPT-GTH for Oxygen: two s-shells, two p-shells, one d-shell
    "H": "2x0e + 1x1o",  # DZVP-MOLOPT-GTH for Hydrogen: two s-shells, one p-shell
}
prim_basis_spec = prim_basis_specs[TRAINING_KIND]
prim_basis_size = o3.Irreps(prim_basis_spec).dim
# pao_basis_size consecutive copies of the primary-basis irreps (one per PAO row).
irreps_output = o3.Irreps(prim_basis_spec) * pao_basis_size
# Sanity check: each training atom's xblock has exactly irreps_output.dim entries.
for iatom in TRAINING_ATOMS:
    assert xblocks[iatom].numel() == irreps_output.dim

In [None]:
# Irreps Spherical Harmonics
# Spherical harmonics of the edge directions, up to the highest l in the output.
irreps_sh = o3.Irreps.spherical_harmonics(irreps_output.lmax)

In [None]:
# Tensor Product
# Couples the node features with the edge spherical harmonics into irreps_output.
# shared_weights=False: the weights are passed in at call time (the model cell
# feeds them from the radial MLP). UserWarnings emitted during construction are silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    tp = o3.FullyConnectedTensorProduct(
        irreps_input,
        irreps_sh,
        irreps_output,
        shared_weights=False,
    )
print(tp.weight_numel)
tp.visualize()

In [None]:
# Perceptron
# Radial MLP: maps the num_distances-dimensional distance embedding of an edge
# to the tp.weight_numel tensor-product weights for that edge.
num_distances = 10
num_layers = 8  # NOTE: this is the width of the single hidden layer, not a layer count.
# Sigmoid activation: ReLU does not work well because many of the distance
# buckets from soft_one_hot_linspace are zero.
fc = nn.FullyConnectedNet([num_distances, num_layers, tp.weight_numel], act=torch.sigmoid)
print("Number of parameters: ", sum(map(torch.numel, fc.parameters())))

In [None]:
# Node input features: one-hot (is_hydrogen, is_oxygen) per atom, scaled by
# sqrt(num_neighbors); the model divides the summed messages by the same factor.
num_neighbors = 6  # TODO: Remove the central atom, it doesn't carry any information.
assert coords.shape[0] == num_neighbors
f_in = torch.tensor(
    [[kind == "H", kind == "O"] for kind in atom2kind], dtype=torch.float32
) * num_neighbors**0.5

#f_in = torch.eye(num_neighbors) #  atom index as one hot #TODO: !!! THIS IS A BIG HACK  !!!
assert f_in.shape == (num_neighbors, irreps_input.dim)

In [None]:
# CP2K uses the yzx convention, while e3nn uses xyz.
# https://docs.e3nn.org/en/stable/guide/change_of_basis.html
change_of_coord = torch.tensor([
    [0., 0., 1.],
    [1., 0., 0.],
    [0., 1., 0.],
])  # yzx -> xyz
# Representation of change_of_coord on the full (flattened) output irreps.
D = irreps_output.D_from_matrix(change_of_coord)

In [None]:
# Prepare model and loss function.
max_radius = 2  # upper end of the distance embedding range (same units as coords)

def loss_function(pred, label):
    """Return the squared Frobenius distance between ``pred.T @ pred`` and ``label.T @ label``.

    With orthonormal label rows, ``label.T @ label`` is the projector onto their
    span. Because ``(Q M).T @ (Q M) == M.T @ M`` for orthogonal ``Q``, the loss is
    unchanged by any orthogonal mixing of the rows of either argument.
    """
    pred_gram = torch.matmul(pred.T, pred)
    label_projector = torch.matmul(label.T, label)  # is a projector because labels are orthonormal
    return torch.sub(pred_gram, label_projector).pow(2).sum()

def model(edge_vec):
    """Predict the xblock of one atom from its edge vectors.

    ``edge_vec`` holds the vectors from the central atom to every atom (one row
    per atom, aligned with the rows of the global ``f_in``). Returns a
    ``(pao_basis_size, prim_basis_size)`` tensor in CP2K's (yzx) ordering.
    """
    # Angular part: spherical harmonics of the normalized edge directions.
    edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
    # Radial part: smooth one-hot embedding of the edge lengths on [0, max_radius].
    edge_length = edge_vec.norm(dim=1)
    radial_emb = soft_one_hot_linspace(
        edge_length, 0.0, max_radius, num_distances, basis='smooth_finite', cutoff=True
    ) * num_distances**0.5
    # The radial MLP supplies per-edge tensor-product weights; messages are summed over edges.
    messages = tp(f_in, edge_sh, fc(radial_emb))
    flat_xyz = messages.sum(dim=0) / num_neighbors**0.5
    # Convert the flattened output to CP2K's axis convention.
    flat_yzx = flat_xyz @ D
    return flat_yzx.reshape(pao_basis_size, prim_basis_size)

In [None]:
# Prepare features and labels.
def labelfy(xblock):
    """Return the right singular vectors ``Vh`` of ``xblock`` (reduced SVD).

    The rows of the result are orthonormal, as loss_function requires of its
    labels; they span the row space of ``xblock`` when it has full row rank.
    """
    return torch.linalg.svd(xblock, full_matrices=False).Vh

# One orthonormal label and one set of edge vectors per training atom.
labels = [labelfy(xblocks[atom]) for atom in TRAINING_ATOMS]
edge_vecs = [coords - coords[atom] for atom in TRAINING_ATOMS]  # vectors from the atom to all atoms
#edge_vecs[0][2,2] += 0.5 #HACK!!!!


In [None]:
# Train the radial MLP. Gradients of all training atoms are accumulated before
# each optimizer step; the per-atom losses are printed every 1000 steps.
optim = torch.optim.Adam(fc.parameters())
num_steps = 15001
for step in range(num_steps):
    optim.zero_grad()
    step_losses = []
    for edge_vec, label in zip(edge_vecs, labels):
        pred = model(edge_vec)
        loss = loss_function(pred, label)
        loss.backward()
        step_losses.append(loss.detach())
    if step % 1000 == 0:
        # Format only on the steps that are actually printed.
        loss_values = "".join(f"  {value:.8e}" for value in step_losses)
        print(f"training {step:5d} | loss {loss_values}")
    optim.step()

# Validation

In [None]:
# Test against (up to 10) rotated training samples using the loss function.
# Inference only: autograd is disabled so no computation graphs are built.
with torch.no_grad():
    for path in sorted(Path().glob("2H2O_rotations/rand_*/2H2O_pao44-1_0.pao"))[:10]:
        _, _, test_coords, test_xblocks = parse_pao_file_torch(path)
        for i in TRAINING_ATOMS:
            edge_vec = test_coords - test_coords[i]
            test_loss = loss_function(model(edge_vec), labelfy(test_xblocks[i]))
            print(f"{path}: atom: {i} loss: {test_loss:e}")

In [None]:
# Test against randomly rotated training samples using CP2K.
for path in sorted(Path().glob("2H2O_rotations/rand_*/2H2O_pao44-1_0.pao")):
    _, _, sample_coords, sample_xblocks = parse_pao_file_torch(path)
    pred_xblocks = sample_xblocks.copy()
    for i in TRAINING_ATOMS:
        edge_vec = sample_coords - sample_coords[i]
        pred_xblocks[i] = model(edge_vec)
    write_pao_file(path.parent / "2H2O_pao44_eval.pao", kinds, atom2kind, sample_coords, pred_xblocks)
    ! cd {path.parent}; OMP_NUM_THREADS=8 ~/git/cp2k/exe/local/cp2k.sdbg 2H2O_pao44_eval.inp > 2H2O_pao44_eval.out
    test_energy = read_cp2k_energy(path.parent / "2H2O_pao44_eval.out")
    ref_energy = read_cp2k_energy(path.parent / "2H2O_pao44.out")
    rel_diff_energy = (test_energy - ref_energy) / ref_energy
    print(f"{path}: Relative Energy Diff: {rel_diff_energy:e}")