In [1]:
%load_ext autoreload
%autoreload 2

In [149]:
from glob import glob
from pao_utils import parse_pao_file, append_samples
import pandas as pd
import numpy as np

import torch
import torch.utils.data
import se3cnn
import livelossplot as llp

import sys, os
import random
import numpy as np

from se3cnn.utils import torch_default_dtype
import se3cnn.point_utils as point_utils
from se3cnn.non_linearities import NormSoftplus
from se3cnn.convolution import SE3PointConvolution
from se3cnn.blocks.point_norm_block import PointNormBlock 
from se3cnn.point_kernel import gaussian_radial_function
from se3cnn.SO3 import torch_default_dtype

from functools import partial

from numpy.linalg import norm
from scipy.optimize import linear_sum_assignment

torch.set_default_dtype(torch.float64)
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

In [154]:
# Find and parse all .pao files.
# Each file corresponds to a molecular configuration, ie. a frame.
# Since the system contains multiple atoms, each .pao file contains multiple samples.
class PAODataset(object):
    def __init__(self, kind_name):
        self.kind_name = kind_name
        self.sample_iatoms = []
        self.sample_coords = []
        self.sample_xblocks = []
        self.sample_compl_projector = []
        

        #TODO split in training and test set
        for fn in glob("2H2O_MD/frame_*/2H2O_pao44-1_0.pao"):
            kinds, atom2kind, coords, xblocks = parse_pao_file(fn)
            for iatom, kind in enumerate(atom2kind):
                if kind != self.kind_name:
                    continue
                rel_coords = coords - coords[iatom,:] # relative coordinates
                self.sample_coords.append(rel_coords)
                self.sample_xblocks.append(xblocks[iatom])
                self.sample_iatoms.append(iatom)
                
                # orthonormalize xblock_sample's basis vectors (they deviate slightly)
                #TODO: add a regularization term for this.
                #
                #https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process#Alternatives
                V = torch.t(torch.from_numpy(xblocks[iatom]))  #TODO: use torch.tensor() instead?
                VV  = torch.matmul(torch.t(V), V)
                L = torch.cholesky(VV)
                L_inv = torch.inverse(L)
                U = torch.matmul(V, torch.t(L_inv))
                projector = torch.matmul(U, torch.t(U))
                identity = torch.eye(projector.shape[0])
                compl_projector = identity - projector
                self.sample_compl_projector.append(compl_projector)
            

        # assuming kinds and atom2kind are the same across whole training data
        kinds_enum = list(kinds.keys())
        self.kinds_onehot = np.zeros((len(kinds), len(atom2kind)))
        for iatom, kind in enumerate(atom2kind):
            idx = kinds_enum.index(kind)
            self.kinds_onehot[idx, iatom] = 1.0
        
    def __getitem__(self, idx):
        # roll central atom to the front
        iatom = self.sample_iatoms[idx]
        rolled_kinds = np.roll(self.kinds_onehot, shift=-iatom, axis=1)
        rolled_coords =  np.roll(self.sample_coords[idx], shift=-iatom, axis=0)  
        return rolled_kinds, rolled_coords, idx

    def __len__(self):
        return len(self.sample_xblocks)

In [155]:
class PAONet(torch.nn.Module):
    def __init__(self, num_kinds, pao_basis_size, prim_basis_shells, num_hidden=1, num_radial=4, max_radius=2.5):
        super().__init__()
        self.num_kinds = num_kinds
        self.prim_basis_shells = prim_basis_shells
        self.pao_basis_size = pao_basis_size
                
        nonlinearity = lambda x: torch.log(0.5 * torch.exp(x) + 0.5)
        sigma = max_radius / num_radial
        radii = torch.linspace(0, max_radius, steps=num_radial, dtype=torch.float64)
        radial_function = partial(gaussian_radial_function, sigma=2*sigma)
        radii_args = {'radii': radii, 'radial_function': radial_function}
        
        # Convolutions with Norm nonlinearity layers
        self.layers = torch.nn.ModuleList()
        
        # features
        input_features = [num_kinds, 0, 0]  # L=0 for atom type as one-hot encoding
        hidden_features = [8, 8, 8] # hidden layer with filters L=0,1,2
        output_features = [i * pao_basis_size for i in prim_basis_shells]

        # input layer
        self.layers.append(PointNormBlock(input_features, hidden_features, activation=nonlinearity, **radii_args))
       
        # hidden layer
        for _ in range(num_hidden):
            self.layers.append(PointNormBlock(hidden_features, hidden_features, activation=nonlinearity, **radii_args))
       
        # output layer
        Rs_repr = lambda features: [(m, l) for l, m in enumerate(features)]
        self.layers.append(SE3PointConvolution(Rs_repr(hidden_features), Rs_repr(output_features), **radii_args))
                        
        
    def forward(self, input, difference_mat, relative_mask=None):
        output = input
        for layer in self.layers:
            output = layer(output, difference_mat, relative_mask)
        #TODO: things could be much simpler if the network directly returned decoded 2-D xblocks
        return output
        

In [157]:
def encode_xblock(xblock, prim_basis_shells):
    """Encodes a [num_pao, num_prim] 2D-block into a 1-D array"""
    xvec = []
    i = 0
    for l, m in enumerate(prim_basis_shells):
        n = m * (2 * l + 1)
        xvec.append(xblock[:, i:i+n].flatten())
        i += n
    #return np.concatenate(xvec)
    return torch.cat(xvec)

def decode_xblock(xvec, num_pao, prim_basis_shells):
    """Decodes a 1-D array into a [num_pao, num_prim] 2-D block."""
    xblock = []
    i = 0
    for l, m in enumerate(prim_basis_shells):
        n = m * num_pao * (2 * l + 1)
        xblock.append(xvec[i:i+n].reshape(num_pao, m * (2 * l + 1)))
        i += n
    return torch.cat(xblock, dim=1)
    #return np.concatenate(xblock, axis=1)

In [158]:
# prim = ("s1", "s2", "p1x", "p1y", "p1z", "p2x", "p2y", "p2z", "d1xy", "d1yz", "d1zx", "d1xx", "d1zz")
# xblock = np.array([["%s,%i"%(x,p) for x in prim ] for p in range(4)])
# #print(xblock)
# print(decode_xblock(encode_xblock(xblock, [2, 2, 1]), 4, [2, 2, 1]))

In [None]:
# assuming MOLOPT-DZVP as primary basis set
prim_basis_shells = {
    'H': [2, 1, 0], # two s-shells, one p-shell, no d-shells
    'O': [2, 2, 1], # two s-shells, two p-shells, one d-shell
}

net = PAONet(num_kinds=2, pao_basis_size=4, prim_basis_shells=prim_basis_shells['O'], num_hidden=2)
net.train()
dataset = PAODataset("O")
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True)

import torch.nn.functional

for epoch in range(100):
    epoch_missmatch = 0
    epoch_penalty = 0
    epoch_loss = 0
    for batch in dataloader:
        kind_onehot, coords, sample_indices = batch    
        diff_M = se3cnn.point_utils.difference_matrix(coords)

        # forward pass
        output_net = net(kind_onehot, diff_M)

        # Use Hungarian algorithm to align training data sample to network's output.

        missmatch = torch.tensor(0.0)
        penalty = torch.tensor(0.0)
        for i, idx in enumerate(sample_indices):  # loop over batch
            # We only care about the xblock of the central atom, which we rolled to the front.
            xblock_enc_net = output_net[i,:,0]
            
            # decode xblock returned by network
            #xblock_net = decode_xblock(xblock_enc_net.detach().numpy(), 4, prim_basis_shells['O'])
            xblock_net = decode_xblock(xblock_enc_net, 4, prim_basis_shells['O'])
            
            # get complementary projector from training data
            sample_compl_projector = dataset.sample_compl_projector[idx]
            
            # force spanning same space as training data
            residual = torch.matmul(compl_projector, torch.t(xblock_net))
            missmatch += torch.norm(residual)
            
            # force returning orthonormal vectors
            identity_pao = torch.eye(xblock_net.shape[0])
            non_orthonormality = identity_pao - torch.matmul(xblock_net, torch.t(xblock_net))
            penalty += 0.001 * torch.norm(non_orthonormality)
        
        loss = missmatch + penalty
        
        epoch_missmatch += missmatch.item()
        epoch_penalty += penalty.item()
        epoch_loss += loss.item()
        
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %i  Missmatch: %f  Penalty: %f  Loss: %f"%(epoch, epoch_missmatch, epoch_penalty, epoch_loss))

Epoch: 0  Missmatch: 3.097995  Penalty: 0.323813  Loss: 3.421809
Epoch: 1  Missmatch: 1.436323  Penalty: 0.323923  Loss: 1.760246
Epoch: 2  Missmatch: 1.182246  Penalty: 0.323959  Loss: 1.506204
Epoch: 3  Missmatch: 0.933909  Penalty: 0.323977  Loss: 1.257886
Epoch: 4  Missmatch: 0.743917  Penalty: 0.323988  Loss: 1.067905
Epoch: 5  Missmatch: 0.636164  Penalty: 0.323991  Loss: 0.960155
Epoch: 6  Missmatch: 0.431013  Penalty: 0.323996  Loss: 0.755010
Epoch: 7  Missmatch: 0.457076  Penalty: 0.323997  Loss: 0.781074
Epoch: 8  Missmatch: 0.466386  Penalty: 0.323997  Loss: 0.790384
Epoch: 9  Missmatch: 0.308929  Penalty: 0.323998  Loss: 0.632928
Epoch: 10  Missmatch: 0.313987  Penalty: 0.323999  Loss: 0.637986
Epoch: 11  Missmatch: 0.373679  Penalty: 0.323998  Loss: 0.697678
Epoch: 12  Missmatch: 0.234489  Penalty: 0.323999  Loss: 0.558489
Epoch: 13  Missmatch: 0.232583  Penalty: 0.323999  Loss: 0.556582
Epoch: 14  Missmatch: 0.208619  Penalty: 0.323999  Loss: 0.532619
Epoch: 15  Missmatch