In [133]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [141]:
import numpy as np
from glob import glob
import itertools
from pao_file_utils import parse_pao_file
from sympy.physics.quantum.cg import CG

import tensorflow as tf
from tensorflow import keras

In [135]:
#https://github.com/cp2k/cp2k/blob/master/src/common/spherical_harmonics.F
def Y_l(r, l):
    """Evaluate the real spherical harmonics of degree ``l`` at ``r``.

    Parameters
    ----------
    r : array_like
        Cartesian vector(s); the last axis must have length 3 (x, y, z).
        Leading axes, if any, are treated as batch dimensions.
    l : int
        Angular momentum quantum number. Only 0, 1 and 2 are implemented.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        * ``l == 0``: the constant ``sqrt(1 / (4 pi))`` as a scalar.
        * ``l == 1``: ``sqrt(3 / (4 pi)) * r`` (same shape as ``r``).
        * ``l == 2``: array of shape ``r.shape[:-1] + (5,)`` with the
          components ordered ``m = 2, 1, 0, -1, -2``.

    Raises
    ------
    ValueError
        If ``l`` is negative.
    NotImplementedError
        If ``l`` is greater than 2.

    Notes
    -----
    ``r`` is used exactly as passed (no normalisation is applied here).
    NOTE(review): the ``m = 0`` term for ``l == 2`` is ``3 z**2 - 1``, which
    matches ``3 cos(theta)**2 - 1`` only for unit-length ``r`` -- confirm
    whether callers are expected to normalise ``r`` first.
    """
    r = np.asarray(r)
    assert r.shape[-1] == 3

    if l < 0:
        raise ValueError("Negative l value")
    if l == 0:
        return np.sqrt(1.0 / (4.0 * np.pi))
    if l == 1:
        pf = np.sqrt(3.0 / (4.0 * np.pi))
        return pf * r
    if l == 2:
        x = r[..., 0]
        y = r[..., 1]
        z = r[..., 2]
        pf_m2 = np.sqrt(15.0 / (16.0 * np.pi))  # shared by m = +2 and m = -2
        pf_m1 = np.sqrt(15.0 / (4.0 * np.pi))   # shared by m = +1 and m = -1
        pf_m0 = np.sqrt(5.0 / (16.0 * np.pi))
        # Stacking along the last axis keeps any batch dimensions of r intact.
        return np.stack(
            [
                pf_m2 * (x**2 - y**2),        # m = 2 (prefactor scales the whole difference)
                pf_m1 * z * x,                # m = 1
                pf_m0 * (3.0 * z**2 - 1.0),   # m = 0
                pf_m1 * z * y,                # m = -1
                pf_m2 * 2.0 * x * y,          # m = -2
            ],
            axis=-1,
        )
    raise NotImplementedError("Y_l is only implemented for l <= 2, got l = %d" % l)

In [186]:
def convolute(coords, kinds, central_atom, max_l, atom2kind=None,
              sigmas=(0.5, 1.0, 2.0, 3.0, 4.0)):
    """Project the neighbourhood of one atom onto Gaussian x spherical-harmonic channels.

    For every ``l`` in ``0..max_l``, every width in ``sigmas`` and every kind in
    ``sorted(kinds)``, all atoms of that kind other than the central atom
    contribute ``exp(-|r|**2 / sigma**2) * Y_l(r, l)`` with
    ``r = coords[central_atom] - coords[iatom]``.

    Parameters
    ----------
    coords : numpy.ndarray
        Atomic positions, shape ``(natoms, 3)``.
    kinds : iterable
        Kind identifiers; iterated in sorted order.
    central_atom : int
        Index of the atom whose environment is described; it is excluded
        from the sums.
    max_l : int
        Largest angular momentum to generate channels for.
    atom2kind : sequence, optional
        ``atom2kind[iatom]`` is the kind identifier of atom ``iatom``. When
        omitted, a module-level variable named ``atom2kind`` is used, which is
        what this function relied on before the parameter existed.
    sigmas : iterable of float, optional
        Gaussian widths of the radial weights.

    Returns
    -------
    list of numpy.ndarray
        One 1-D array of length ``2*l + 1`` per ``(l, sigma, kind)``
        combination, in that nesting order.

    Raises
    ------
    NameError
        If ``atom2kind`` is not passed and no module-level ``atom2kind`` exists.
    """
    natoms = coords.shape[0]
    assert coords.shape[1] == 3

    if atom2kind is None:
        # Fallback to notebook-global state; passing atom2kind explicitly is
        # preferred because it survives Restart & Run All.
        try:
            atom2kind = globals()["atom2kind"]
        except KeyError:
            raise NameError(
                "atom2kind was not passed and no module-level 'atom2kind' is defined"
            ) from None

    integrals = []
    for l in range(max_l + 1):
        for sigma in sigmas:
            for ikind in sorted(kinds):
                channel = np.zeros(2 * l + 1)
                for iatom in range(natoms):
                    if atom2kind[iatom] == ikind and iatom != central_atom:
                        r = coords[central_atom] - coords[iatom]
                        angular_part = Y_l(r, l)
                        radial_part = np.exp(-np.dot(r, r) / sigma**2)
                        channel += radial_part * angular_part
                integrals.append(channel)
    return integrals

In [187]:
# Memo of coefficient tables, keyed by the (li, lj, lo) triple.
cg_cache = dict()

def get_clebsch_gordan_coefficients(li, lj, lo):
    """Return the table of Clebsch-Gordan coefficients for coupling li and lj to lo.

    The table has shape ``(2*li+1, 2*lj+1, 2*lo+1)`` and entry
    ``[mi+li, mj+lj, mo+lo]`` holds ``CG(li, mi, lj, mj, lo, mo).doit()``.
    Tables are computed once per ``(li, lj, lo)`` and then served from
    ``cg_cache``.

    Raises ``AssertionError`` for an uncached triple that violates
    ``abs(li-lj) <= lo <= abs(li+lj)``.
    """
    triple = (li, lj, lo)
    try:
        return cg_cache[triple]
    except KeyError:
        pass

    assert abs(li-lj) <= lo <= abs(li+lj)
    table = np.zeros((2*li+1, 2*lj+1, 2*lo+1))
    # ndindex walks the table in C order; subtracting l maps index -> m.
    for ii, jj, oo in np.ndindex(*table.shape):
        # https://docs.sympy.org/latest/modules/physics/quantum/cg.html
        table[ii, jj, oo] = CG(li, ii - li, lj, jj - lj, lo, oo - lo).doit()
    cg_cache[triple] = table
    return cg_cache[triple]

In [188]:
def combinations(channels, max_l):
    """Couple every unordered pair of input channels into output channels.

    Each channel is a 1-D array whose length ``2*l + 1`` determines its ``l``.
    For a pair with ``li`` and ``lj``, one output channel is produced for every
    ``lo`` from ``abs(li - lj)`` up to ``min(li + lj, max_l)`` by contracting the
    two channels with the matching Clebsch-Gordan table. Pairs are taken with
    replacement, so a channel is also combined with itself.

    Returns the list of output channels in generation order.
    """
    coupled = []
    pairs = itertools.combinations_with_replacement(channels, 2)
    for first, second in pairs:
        assert first.ndim == second.ndim == 1
        l_first = (first.size - 1) // 2
        l_second = (second.size - 1) // 2
        # Allowed output l values, truncated at max_l.
        lowest = abs(l_first - l_second)
        highest = min(l_first + l_second, max_l)
        coupled.extend(
            np.einsum(
                "i,j,ijo->o",
                first,
                second,
                get_clebsch_gordan_coefficients(l_first, l_second, l_out),
            )
            for l_out in range(lowest, highest + 1)
        )
    return coupled

In [189]:
# Training data: every file matching the frame_* pattern, in sorted path order.
pao_files = glob("2H2O_MD/frame_*/2H2O_pao44-1_0.pao")
pao_files.sort()

# Hard-coded metadata: shell counts per element.
prim_basis_shells = dict(
    H=[2, 1, 0],  # two s-shells, one p-shell, no d-shells
    O=[2, 2, 1],  # two s-shells, two p-shells, one d-shell
)

pao_basis_size = 4

In [213]:
class Sample:
    """One training sample: descriptor channels plus the target block.

    Attributes
    ----------
    channels : sequence of numpy.ndarray
        All channels exactly as passed in.
    xblock
        The target stored alongside the channels (kept as passed).
    s_channels, p_channels, d_channels : numpy.ndarray
        Channels of size 1, 3 and 5 respectively, stacked along axis 0.
        A class with no members is an empty array of shape ``(0, size)``.
    num_channels : tuple of int
        ``(number of s, number of p, number of d)`` channels.
    """

    def __init__(self, channels, xblock):
        self.channels = channels
        self.xblock = xblock
        self.s_channels = self._stack_channels(channels, 1)
        self.p_channels = self._stack_channels(channels, 3)
        self.d_channels = self._stack_channels(channels, 5)
        self.num_channels = (len(self.s_channels), len(self.p_channels), len(self.d_channels))

    @staticmethod
    def _stack_channels(channels, size):
        """Stack all channels with ``size`` entries; empty ``(0, size)`` array if none."""
        selected = [c for c in channels if c.size == size]
        if not selected:
            # np.stack rejects an empty list, e.g. no d channels when max_l == 1.
            return np.zeros((0, size))
        return np.stack(selected)

def build_dataset(kind_name, max_l):
    """Build one Sample per atom of kind ``kind_name`` across all files in ``pao_files``.

    For every file in the module-level ``pao_files`` list, the file is parsed with
    ``parse_pao_file`` and, for each atom whose kind equals ``kind_name``, the
    channels from ``convolute`` are expanded with ``combinations`` and stored
    together with that atom's ``xblocks`` entry.

    Parameters
    ----------
    kind_name
        Kind identifier compared against the entries of ``atom2kind``.
    max_l : int
        Maximum angular momentum handed to ``convolute`` and ``combinations``.

    Returns
    -------
    list of Sample
        All samples in file order, then atom order. Empty when no file or no
        matching atom was found.

    Prints the number of samples and, when there is at least one, the number of
    s and p channels of the first sample.
    """
    samples = []
    for fn in pao_files:
        kinds, atom2kind, coords, xblocks = parse_pao_file(fn)
        #kind_onehot = encode_kind(atom2kind)
        natoms = coords.shape[0]
        for iatom in range(natoms):
            if atom2kind[iatom] != kind_name:
                continue
            # NOTE(review): this call does not pass the per-file `atom2kind`;
            # confirm that convolute() sees the same mapping.
            initial_channels = convolute(coords, kinds, iatom, max_l)
            comb_channels = combinations(initial_channels, max_l)
            samples.append(Sample(comb_channels, xblocks[iatom]))

    print("samples: ", len(samples))
    if not samples:
        # Nothing to summarise; avoids an IndexError on samples[0].
        return samples
    print("s channels: ", samples[0].s_channels.shape[0])
    print("p channels: ", samples[0].p_channels.shape[0])
    #print("d channels: ", samples[0].d_channels.shape)
    return samples
    
# One dataset per element; each is built with its own max_l.
H_dataset = build_dataset(kind_name="H", max_l=1)
O_dataset = build_dataset(kind_name="O", max_l=2)

samples:  324
s channels:  110
p channels:  155
samples:  162
s channels:  165
p channels:  310


In [209]:
# Scratch check: shape of the stacked s channels of the first H sample.
np.stack(H_dataset[0].s_channels).shape

(110, 1)

In [216]:
def build_model(first_sample, pao_basis_size, prim_basis_shells):
    """Build the (work-in-progress) Keras functional model for one element.

    Parameters
    ----------
    first_sample : Sample
        Representative sample; the shapes of its ``s_channels`` and
        ``p_channels`` arrays define the two model inputs.
    pao_basis_size
        Currently unused; kept so the call signature stays stable.
    prim_basis_shells
        Currently unused; kept so the call signature stays stable.

    Returns
    -------
    keras.Model
        Model taking ``[s_input, p_input]`` and returning a ``Dense(8)`` layer
        applied to the s input.
    """
    # define two sets of inputs
    s_input = keras.layers.Input(shape=first_sample.s_channels.shape)
    p_input = keras.layers.Input(shape=first_sample.p_channels.shape)

    # Only `keras` is imported, so the layer must be referenced through it.
    x = keras.layers.Dense(8)(s_input)

    # TODO: p_input is declared as a model input but not consumed yet; the
    # output is the s branch only until the p branch is designed.
    outputs = x

    model = keras.Model(inputs=[s_input, p_input], outputs=outputs, name='mnist_model')
    return model

# # the first branch operates on the first input
# x = Dense(8, activation="relu")(inputA)
# x = Dense(4, activation="relu")(x)
# x = Model(inputs=inputA, outputs=x)
 
# # the second branch operates on the second input
# y = Dense(64, activation="relu")(inputB)
# y = Dense(32, activation="relu")(y)
# y = Dense(4, activation="relu")(y)
# y = Model(inputs=inputB, outputs=y)
 
# # combine the output of the two branches
# combined = concatenate([x.output, y.output])
 
# # apply a FC layer and then a regression prediction on the
# # combined outputs
# z = Dense(2, activation="relu")(combined)
# z = Dense(1, activation="linear")(z)
 
# # our model will accept the inputs of the two branches and
# # then output a single value
# model = Model(inputs=[x.input, y.input], outputs=z)
#     output_size = pao_basis_size * (prim_basis_shells[0] * num_channels[0] +
#                                     prim_basis_shells[1] * num_channels[1] +
#                                     prim_basis_shells[2] * num_channels[2])
    
#     model = keras.Sequential()
#     #model.add(keras.layers.Dense(10, input_shape=samples[0].inputs.shape))
#     #model.add(keras.layers.Dense(10)) # hidden layer
#     #model.add(keras.layers.Dense(10)) # hidden layer
    
#     # let's try a single layer
#     model.add(keras.layers.Dense(output_size, input_shape=samples[0].inputs.shape))

    
#     def assemble_xblock(x):
        

# #        # add a layer that returns the concatenation
# #         # of the positive part of the input and
# #         # the opposite of the negative part

# #         def antirectifier(x):
# #             x -= K.mean(x, axis=1, keepdims=True)
# #             x = K.l2_normalize(x, axis=1)
# #             pos = K.relu(x)
# #             neg = K.relu(-x)
# #             return K.concatenate([pos, neg], axis=1)

# #         model.add(Lambda(antirectifier))

#     model.compile(optimizer='adam',
#                   loss='binary_crossentropy',
#                   metrics=['accuracy', 'binary_crossentropy'])

#     model.summary()

# Build the hydrogen model from the first hydrogen sample.
H_model = build_model(
    first_sample=H_dataset[0],
    pao_basis_size=pao_basis_size,
    prim_basis_shells=prim_basis_shells['H'],
)
#O_model = build_model(O_dataset[0].num_channels, pao_basis_size, prim_basis_shells['O'])

In [202]:
# NOTE(review): `model` and `samples` are not assigned in any cell visible here, and
# the Sample class defines no `inputs` attribute -- presumably leftovers from an
# earlier kernel session; confirm before relying on this cell.
model.predict(samples[0].inputs[None,:], batch_size=1)
#samples[0].inputs[None,:].shape

array([[-0.17002359,  0.54494166, -0.67605853, -0.53623134,  0.11146726,
         0.22992396,  0.22856215,  0.42034036, -0.26474568,  0.25754395]],
      dtype=float32)

In [172]:
# NOTE(review): `samples` is not assigned in any cell visible here and Sample has no
# `inputs` attribute -- presumably stale kernel state; confirm or remove.
samples[0].inputs.shape

(165,)