In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from glob import glob
from pao_utils import parse_pao_file, append_samples
import pandas as pd
import numpy as np

import torch
import torch.utils.data
import se3cnn
import livelossplot as llp

import sys, os
import random
import numpy as np

from se3cnn.utils import torch_default_dtype
import se3cnn.point_utils as point_utils
from se3cnn.non_linearities import NormSoftplus
from se3cnn.convolution import SE3PointConvolution
from se3cnn.blocks.point_norm_block import PointNormBlock 
from se3cnn.point_kernel import gaussian_radial_function
from se3cnn.SO3 import torch_default_dtype

from functools import partial

from numpy.linalg import norm
from scipy.optimize import linear_sum_assignment

torch.set_default_dtype(torch.float64)
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

In [2]:
# Find and parse all .pao files.
# Each file corresponds to a molecular configuration, ie. a frame.
# Since the system contains multiple atoms, each .pao file contains multiple samples.
class PAODataset(object):
    def __init__(self, kind_name):
        self.kind_name = kind_name
        self.sample_coords = []
        self.sample_xblocks = []
        self.sample_iatoms = []

        #TODO split in training and test set
        for fn in glob("2H2O_MD/frame_*/2H2O_pao44-1_0.pao"):
            kinds, atom2kind, coords, xblocks = parse_pao_file(fn)
            for iatom, kind in enumerate(atom2kind):
                if kind != self.kind_name:
                    continue
                rel_coords = coords - coords[iatom,:] # relative coordinates
                self.sample_coords.append(rel_coords)
                self.sample_xblocks.append(xblocks[iatom])
                self.sample_iatoms.append(iatom)

        # assuming kinds and atom2kind are the same across whole training data
        kinds_enum = list(kinds.keys())
        self.kinds_onehot = np.zeros((len(kinds), len(atom2kind)))
        for iatom, kind in enumerate(atom2kind):
            idx = kinds_enum.index(kind)
            self.kinds_onehot[idx, iatom] = 1.0
        
    def __getitem__(self, idx):
        # roll central atom to the front
        iatom = self.sample_iatoms[idx]
        rolled_kinds = np.roll(self.kinds_onehot, shift=-iatom, axis=1)
        rolled_coords =  np.roll(self.sample_coords[idx], shift=-iatom, axis=0)  
        return rolled_kinds, rolled_coords, idx

    def __len__(self):
        return len(self.sample_xblocks)

In [3]:
class PAONet(torch.nn.Module):
    def __init__(self, num_kinds, pao_basis_size, prim_basis_shells, num_radial=4, max_radius=2.5):
        super().__init__()
        self.num_kinds = num_kinds
        self.prim_basis_shells = prim_basis_shells
        self.pao_basis_size = pao_basis_size
        
        features = []
        features.append([num_kinds, 0, 0])  # L=0 for atom type as one-hot encoding
        features.append([8, 8, 8]) # hidden layer with filters L=0,1,2
        features.append([8, 8, 8]) # hidden layer with filters L=0,1,2
        features.append([ i * pao_basis_size for i in prim_basis_shells])

        nonlinearity = lambda x: torch.log(0.5 * torch.exp(x) + 0.5)
        sigma = max_radius / num_radial
        radii = torch.linspace(0, max_radius, steps=num_radial, dtype=torch.float64)
        radial_function = partial(gaussian_radial_function, sigma=2*sigma)
        radii_args = {'radii': radii, 'radial_function': radial_function}

        # Convolutions with Norm nonlinearity layers
        self.layers = torch.nn.ModuleList()
        
        # input layer
        self.layers.append(PointNormBlock(features[0], features[1], activation=nonlinearity, **radii_args))
        
        # hidden layer
        self.layers.append(PointNormBlock(features[1], features[2], activation=nonlinearity, **radii_args))
        
        # output layer
        Rs_repr = lambda features: [(m, l) for l, m in enumerate(features)]
        self.layers.append(SE3PointConvolution(Rs_repr(features[2]), Rs_repr(features[3]), **radii_args))
                        
        
    def forward(self, input, difference_mat, relative_mask=None):
        output = input
        for layer in self.layers:
            output = layer(output, difference_mat, relative_mask)
        return output

        #TODO: this could make things a lot simpler:
        ## decode network's 1-D output into 2-D xblock with shape [num_pao, num_prim].
        #xblock = output.reshape(-1, self.pao_basis_size).transpose()
        #return xblock

In [4]:
def mirror(xblock):
    """ duplicate pao vectors with flipped sign """
    m, n = xblock.shape # size of pao and prim basis
    result = np.zeros((2*m, n))
    result[:m,:] = xblock
    result[m:,:] = -xblock
    return result

def align(xblock, ref_xblock):
    """ align xblock onto ref_xblock in-place """
    m, n = xblock.shape # size of pao and prim basis
    
    # We can treat sign-flips as permutations by including each basis vector with both signs.
    a = mirror(xblock)
    b = mirror(ref_xblock)
    
    # build distance matrix
    dist = np.zeros((2*m,2*m))
    for i in range(2*m):
        for j in range(2*m):
            dist[i,j] = norm(a[i,:] - b[j,:])
            
    # run Hungarian algorithm
    row_ind, col_ind = linear_sum_assignment(dist)

    # permute pao basis vectors in-place
    permutations = 0
    for i, j in enumerate(col_ind[:m]):
        permutations += int(i != j)
        xblock[i,:] = a[j,:]

    return permutations # number of permutations, should approach zero as training progresses

In [49]:
def encode_xblock(xblock, prim_basis_shells):
    """Encodes a [num_pao, num_prim] 2D-block into a 1-D array"""
    xvec = []
    i = 0
    for l, m in enumerate(prim_basis_shells):
        n = m * (2 * l + 1)
        xvec.append(xblock[:, i:i+n].flatten())
        i += n
    return np.concatenate(xvec)

def decode_xblock(xvec, num_pao, prim_basis_shells):
    """Decodes a 1-D array into a [num_pao, num_prim] 2-D block."""
    xblock = []
    i = 0
    for l, m in enumerate(prim_basis_shells):
        n = m * num_pao * (2 * l + 1)
        xblock.append(xvec[i:i+n].reshape(num_pao, m * (2 * l + 1)))
        i += n
    return np.concatenate(xblock, axis=1)

In [54]:
prim = ("s1", "s2", "p1x", "p1y", "p1z", "p2x", "p2y", "p2z", "d1xy", "d1yz", "d1zx", "d1xx", "d1zz")
xblock = np.array([["%s,%i"%(x,p) for x in prim ] for p in range(4)])

In [None]:
# assuming MOLOPT-DZVP as primary basis set
prim_basis_shells = {
    'H': [2, 1, 0], # two s-shells, one p-shell, no d-shells
    'O': [2, 2, 1], # two s-shells, two p-shells, one d-shell
}

net = PAONet(num_kinds=2, pao_basis_size=4, prim_basis_shells=prim_basis_shells['O'])
net.train()
dataset = PAODataset("O")
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True)

for epoch in range(100):
    epoch_permutations = 0
    epoch_loss = 0
    for batch in dataloader:
        kind_onehot, coords, sample_indices = batch    
        diff_M = se3cnn.point_utils.difference_matrix(coords)

        # forward pass
        output_net = net(kind_onehot, diff_M)

        # Use Hungarian algorithm to align training data sample to network's output.

        output_sample = []
        for i, idx in enumerate(sample_indices):  # loop over batch
            # We only care about the xblock of the central atom, which we rolled to the front.
            xblock_enc_net = output_net[i,:,0]
            
            # decode xblock returned by network
            xblock_net = decode_xblock(xblock_enc_net.detach().numpy(), 4, prim_basis_shells['O'])
            
            # get xblock from training data
            xblock_sample = dataset.sample_xblocks[idx]
            
            # aligh sample xblock onto xblock outputed by the network
            epoch_permutations += align(xblock_sample, xblock_net)
            
            # encode aligned sample xblock
            output_sample.append(encode_xblock(xblock_sample, prim_basis_shells['O']))


        # Compute loss
        output_sample = torch.tensor(output_sample)
        loss = loss_fn(output_net[:,:,0], output_sample)
        epoch_loss += loss.item()
        
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %i Loss: %f Permutations: %i"%(epoch, epoch_loss, epoch_permutations))

Epoch: 0 Loss: 2.087192 Permutations: 437
Epoch: 1 Loss: 1.185242 Permutations: 136
Epoch: 2 Loss: 0.807268 Permutations: 92
Epoch: 3 Loss: 0.509907 Permutations: 53
Epoch: 4 Loss: 0.380037 Permutations: 24
Epoch: 5 Loss: 0.291385 Permutations: 12
Epoch: 6 Loss: 0.247453 Permutations: 14
Epoch: 7 Loss: 0.229705 Permutations: 10
Epoch: 8 Loss: 0.225291 Permutations: 12
Epoch: 9 Loss: 0.208763 Permutations: 14
Epoch: 10 Loss: 0.194447 Permutations: 12
Epoch: 11 Loss: 0.181160 Permutations: 12
Epoch: 12 Loss: 0.171960 Permutations: 12
Epoch: 13 Loss: 0.162915 Permutations: 8
Epoch: 14 Loss: 0.158637 Permutations: 8
