In [1]:
import numpy as np
import pandas as pd
import math

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

%matplotlib inline

In [2]:
!nvidia-smi

Wed Aug 19 10:15:50 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.82       Driver Version: 440.82       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Quadro GV100        Off  | 00000000:37:00.0 Off |                  Off |
| 32%   44C    P2    36W / 250W |   3011MiB / 32508MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|    0  

In [3]:
!pwd

/home/pstjohn/Research/20200608_redox_calculations/spin_gnn


In [4]:
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
# Memory growth has to be configured the same way on every visible GPU.
# When no GPU is found the list is empty and the loop body never runs.
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
import nfp

In [5]:
from tensorflow.keras import layers
from preprocess_inputs import preprocessor
# The return value is discarded, so this relies on from_json updating the
# shared `preprocessor` object in place. Presumably it restores the state that
# was saved when the TFRecords were written -- confirm in preprocess_inputs.
preprocessor.from_json('tfrecords/preprocessor.json')

from loss import AtomInfMask, KLWithLogits

def parse_example(example):
    """Decode one serialized TFRecord example into ``(inputs, spin)``.

    Parameters
    ----------
    example : scalar string tensor
        A single serialized example read from the TFRecord file.

    Returns
    -------
    (dict, tf.Tensor)
        The parsed preprocessor features keyed by name, and the float64
        ``spin`` target that was removed from that dictionary.
    """
    # Feature spec: everything the preprocessor declares plus the target
    feature_spec = dict(preprocessor.tfrecord_features)
    feature_spec['spin'] = tf.io.FixedLenFeature([], dtype=tf.string)
    parsed = tf.io.parse_single_example(example, features=feature_spec)

    # String-typed preprocessor features hold serialized tensors; decode each
    # one to the dtype the preprocessor declares for it
    for key, val in preprocessor.tfrecord_features.items():
        if val.dtype == tf.string:
            parsed[key] = tf.io.parse_tensor(
                parsed[key], out_type=preprocessor.output_types[key])

    # Take the prediction target out of the input dictionary and decode it
    spin = tf.io.parse_tensor(parsed.pop('spin'), out_type=tf.float64)

    return parsed, spin

# NOTE(review): max_atoms / max_bonds are not used in this cell; the padded
# shapes below pass None for both -- confirm whether fixed padding was intended.
max_atoms = 80
max_bonds = 100
batch_size = 128

# Here, we have to add the prediction target padding onto the input padding
padded_shapes = (preprocessor.padded_shapes(max_atoms=None, max_bonds=None), [None])

# Targets are padded with NaN
padding_values = (preprocessor.padding_values,
                  tf.constant(np.nan, dtype=tf.float64))

# np.load on an .npz archive returns a lazily-read NpzFile that holds an open
# file handle; use it as a context manager so the handle is closed again.
# SECURITY: allow_pickle=True lets NumPy unpickle object arrays, which can run
# arbitrary code -- only load split.npz from a trusted source.
with np.load('split.npz', allow_pickle=True) as split_file:
    num_train = len(split_file['train'])

# parse -> cache -> shuffle (buffer = whole training set) -> repeat forever
# -> padded batches -> prefetch
train_dataset = (
    tf.data.TFRecordDataset('tfrecords/train.tfrecord.gz', compression_type='GZIP')
    .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(buffer_size=num_train)
    .repeat()
    .padded_batch(batch_size=batch_size,
                  padded_shapes=padded_shapes,
                  padding_values=padding_values)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [6]:
class ConcatDense(layers.Layer):
    """Concatenate a list of inputs and pass the result through two dense layers.

    The hidden layer has twice as many units as the last dimension of the
    first input and uses a ReLU activation; the output layer has the same
    number of units as that dimension and no activation.
    """

    def build(self, input_shape):
        # Layer widths follow the feature size of the first input
        width = input_shape[0][-1]
        self.concat = layers.Concatenate()
        self.dense1 = layers.Dense(2 * width, activation='relu')
        self.dense2 = layers.Dense(width)

    def call(self, inputs, mask=None):
        # `mask` is accepted but not used
        hidden = self.dense1(self.concat(inputs))
        return self.dense2(hidden)

class GraphLayer(layers.Layer):
    """ Base class for all GNN layers.

    Subclasses are called with a list of three inputs, or four when a global
    state is supplied. ``build`` records which case applies in
    ``self.use_global`` and creates ``self.tile`` in the four-input case.
    """

    def build(self, input_shape):
        num_inputs = len(input_shape)
        if num_inputs not in (3, 4):
            raise RuntimeError("wrong input shape")

        # Four inputs means a global state was passed as the last element
        self.use_global = (num_inputs == 4)
        if self.use_global:
            self.tile = Tile()
        
    
class EdgeUpdate(GraphLayer):
    """Residual update of the bond state.

    The new bond state is computed by a ConcatDense block from the current
    bond state and the two atom states selected by the connectivity columns
    (plus the tiled global state when one is given), and is then added back
    onto the incoming bond state.
    """

    def build(self, input_shape):
        """ inputs = [atom_state, bond_state, connectivity] or
        [atom_state, bond_state, connectivity, global_state]
        shape(bond_state) = [batch, num_bonds, bond_features]
        """
        super(EdgeUpdate, self).build(input_shape)

        self.gather = nfp.Gather()
        # Select column 1 / column 0 of the connectivity array
        self.slice1 = nfp.Slice(np.s_[:, :, 1])
        self.slice0 = nfp.Slice(np.s_[:, :, 0])

        self.concat = ConcatDense()
        self.add = layers.Add()

    def call(self, inputs, mask=None):
        """ Inputs: [atom_state, bond_state, connectivity] or
                    [atom_state, bond_state, connectivity, global_state]
            Outputs: bond_state
        """
        if not self.use_global:
            atom_state, bond_state, connectivity = inputs
        else:
            atom_state, bond_state, connectivity, global_state = inputs
            # Repeat the per-graph global state once per bond
            global_state = self.tile([global_state, bond_state])

        # Get nodes at start and end of edge
        # (assumes nfp.Gather picks atom states by index -- confirm in nfp)
        source_atom = self.gather([atom_state, self.slice1(connectivity)])
        target_atom = self.gather([atom_state, self.slice0(connectivity)])

        if not self.use_global:
            new_bond_state = self.concat([bond_state, source_atom, target_atom])
        else:
            new_bond_state = self.concat([bond_state, source_atom, target_atom, global_state])

        # Residual connection
        new_bond_state = self.add([bond_state, new_bond_state])
        return new_bond_state

    def compute_output_shape(self, input_shape):
        # Output has the shape of the bond_state input
        return input_shape[1]
    
    
class NodeUpdate(GraphLayer):
    """Residual update of the atom state from per-bond messages.

    A message is formed for every bond by a ConcatDense block, the messages
    are sum-reduced with ``nfp.Reduce``, passed through a two-layer dense net
    and added back onto the incoming atom state.
    """

    def build(self, input_shape):
        super(NodeUpdate, self).build(input_shape)

        # Width of the post-reduction dense net follows the bond_state input
        num_features = input_shape[1][-1]

        self.gather = nfp.Gather()
        self.slice0 = nfp.Slice(np.s_[:, :, 0])
        self.slice1 = nfp.Slice(np.s_[:, :, 1])

        self.concat = ConcatDense()
        self.reduce = nfp.Reduce(reduction='sum')

        self.dense1 = layers.Dense(2 * num_features, activation='relu')
        self.dense2 = layers.Dense(num_features)
        self.add = layers.Add()

    def call(self, inputs, mask=None):
        """ Inputs: [atom_state, bond_state, connectivity] or
                    [atom_state, bond_state, connectivity, global_state]
            Outputs: atom_state
        """
        if self.use_global:
            atom_state, bond_state, connectivity, global_state = inputs
            # Global state repeated once per bond, appended to each message
            per_bond_global = [self.tile([global_state, bond_state])]
        else:
            atom_state, bond_state, connectivity = inputs
            per_bond_global = []

        # Atom state selected by connectivity column 1 for every bond
        neighbor_state = self.gather([atom_state, self.slice1(connectivity)])
        messages = self.concat([neighbor_state, bond_state] + per_bond_global)

        # Sum the messages, grouped by connectivity column 0
        # (assumed semantics of nfp.Reduce -- confirm in nfp)
        reduced = self.reduce([messages, self.slice0(connectivity), atom_state])

        # Dense net after message reduction, then residual connection
        update = self.dense2(self.dense1(reduced))
        return self.add([atom_state, update])

    def compute_output_shape(self, input_shape):
        # Output has the shape of the atom_state input
        return input_shape[0]


class Tile(layers.Layer):
    """Repeat a per-graph state along a new axis 1 to match a target's length.

    Called with ``[global_state, target]``; the result has as many entries on
    axis 1 as ``target`` does (the number of edges or nodes).
    """

    def call(self, inputs):
        global_state, target = inputs
        # Only the inserted axis is repeated; the other axes keep multiple 1
        multiples = tf.stack([1, tf.shape(target)[1], 1])
        return tf.tile(tf.expand_dims(global_state, 1), multiples)


class GlobalUpdate(GraphLayer):
    """Update the global state by multi-head attention pooling over atoms and bonds.

    Shape legend used in the comments below:
    B = batch, S = atoms + bonds, N = num_heads, H = units (features per head).

    Parameters
    ----------
    units : int
        Number of features per attention head (H).
    num_heads : int
        Number of attention heads (N).
    """

    def __init__(self, units, num_heads, **kwargs):
        super(GlobalUpdate, self).__init__(**kwargs)
        self.units = units          # H
        self.num_heads = num_heads  # N

    def get_config(self):
        # Layers with extra __init__ arguments must report them here,
        # otherwise Keras cannot serialize a model containing this layer.
        config = super(GlobalUpdate, self).get_config()
        config.update({'units': self.units, 'num_heads': self.num_heads})
        return config

    def build(self, input_shape):
        super(GlobalUpdate, self).build(input_shape)
        dense_units = self.units * self.num_heads  # N*H
        # One attention logit per head for every graph element
        self.query_layer = layers.Dense(self.num_heads, name='query')
        self.value_layer = layers.Dense(dense_units, name='value')
        self.add = layers.Add()

    def transpose_scores(self, input_tensor):
        """Reshape [B,S,N*H] into per-head values of shape [B,N,S,H]."""
        input_shape = tf.shape(input_tensor)
        output_shape = [input_shape[0], input_shape[1], self.num_heads, self.units]
        output_tensor = tf.reshape(input_tensor, output_shape)
        return tf.transpose(a=output_tensor, perm=[0, 2, 1, 3])  # [B,N,S,H]

    def call(self, inputs, mask=None):
        """ Inputs: [atom_state, bond_state, connectivity] or
                    [atom_state, bond_state, connectivity, global_state]
            Outputs: global_state, with the attention context added onto the
                     incoming global_state when one was given
        """
        if not self.use_global:
            atom_state, bond_state, connectivity = inputs
        else:
            atom_state, bond_state, connectivity, global_state = inputs

        batch_size = tf.shape(atom_state)[0]

        # Atoms and bonds are pooled together along the sequence axis
        graph_elements = tf.concat([atom_state, bond_state], axis=1)
        query = self.query_layer(graph_elements)  # [B,S,N]
        query = tf.transpose(query, perm=[0, 2, 1])  # [B,N,S]
        value = self.transpose_scores(self.value_layer(graph_elements))  # [B,N,S,H]

        # Softmax over the last axis (S).
        # NOTE(review): `mask` is not used, so every position along S takes
        # part in the softmax -- confirm whether padded atoms/bonds need to be
        # excluded here.
        attention_probs = tf.nn.softmax(query)  # [B,N,S]
        # [B,N,1,S] @ [B,N,S,H] -> [B,N,1,H]
        context = tf.matmul(tf.expand_dims(attention_probs, 2), value)
        context = tf.reshape(context, [batch_size, self.num_heads*self.units])  # [B,N*H]

        if self.use_global:
            global_state = self.add([global_state, context])
        else:
            global_state = context

        return global_state
    

class GraphBlock(GraphLayer):
    """One message-passing block: layer norm, then edge, node and global updates.

    Parameters
    ----------
    units : int
        Passed to GlobalUpdate as the number of features per attention head.
    num_heads : int
        Passed to GlobalUpdate as the number of attention heads.
    """

    def __init__(self, units, num_heads, **kwargs):
        super(GraphBlock, self).__init__(**kwargs)
        self.units = units
        self.num_heads = num_heads

    def get_config(self):
        # Layers with extra __init__ arguments must report them here,
        # otherwise Keras cannot serialize a model containing this layer.
        config = super(GraphBlock, self).get_config()
        config.update({'units': self.units, 'num_heads': self.num_heads})
        return config

    def build(self, input_shape):
        super(GraphBlock, self).build(input_shape)
        self.layer_norm1 = layers.LayerNormalization()
        self.layer_norm2 = layers.LayerNormalization()

        # The global state is only normalized when one is passed in
        if self.use_global:
            self.layer_norm3 = layers.LayerNormalization()

        self.edge_layer = EdgeUpdate()
        self.node_layer = NodeUpdate()
        self.global_layer = GlobalUpdate(self.units, self.num_heads)

    def call(self, inputs, mask=None):
        """ Inputs: [atom_state, bond_state, connectivity] or
                    [atom_state, bond_state, connectivity, global_state]
            Outputs: (atom_state, bond_state, global_state)
        """
        if self.use_global:
            atom_state, bond_state, connectivity, global_state = inputs
        else:
            atom_state, bond_state, connectivity = inputs

        atom_state = self.layer_norm1(atom_state)
        bond_state = self.layer_norm2(bond_state)

        # The normalized global state, when present, is appended to the input
        # list of all three update layers
        if self.use_global:
            global_state = self.layer_norm3(global_state)
            global_inputs = [global_state]
        else:
            global_inputs = []

        # Each update sees the states produced by the update before it
        bond_state = self.edge_layer(
            [atom_state, bond_state, connectivity] + global_inputs)
        atom_state = self.node_layer(
            [atom_state, bond_state, connectivity] + global_inputs)
        global_state = self.global_layer(
            [atom_state, bond_state, connectivity] + global_inputs)

        return atom_state, bond_state, global_state
        

In [11]:
# Embedding width shared by the atom and bond states
atom_features = 128
# Number of times message_block is applied
num_messages = 6

# Define keras model
# n_atom is one scalar per molecule; atom / bond / connectivity are
# variable-length along their first (non-batch) axis
n_atom = layers.Input(shape=[], dtype=tf.int64, name='n_atom')
atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
connectivity = layers.Input(shape=[None, 2], dtype=tf.int64, name='connectivity')

input_tensors = [atom_class, bond_class, connectivity, n_atom]

# Initialize the atom states
# (mask_zero=True treats class index 0 as padding)
atom_state = layers.Embedding(preprocessor.atom_classes, atom_features,
                              name='atom_embedding', mask_zero=True)(atom_class)

# Initialize the bond states
bond_state = layers.Embedding(preprocessor.bond_classes, atom_features,
                              name='bond_embedding', mask_zero=True)(bond_class)


# Initialize the global state as a single learned value looked up by n_atom.
# NOTE(review): the embedding has preprocessor.max_atoms rows, so n_atom must
# stay below that value -- confirm a molecule with n_atom == max_atoms cannot occur.
global_state = layers.Embedding(preprocessor.max_atoms, 1,
                                name='global_state')(n_atom)

def message_block(atom_state, bond_state, connectivity, global_state,
                  units=16, num_heads=8):
    """Apply one message-passing block built from freshly created layers.

    The three states are layer-normalized, then updated in order: bond state
    (EdgeUpdate), atom state (NodeUpdate), global state (GlobalUpdate). Each
    update receives the states produced by the one before it.

    Parameters
    ----------
    atom_state, bond_state, connectivity, global_state : tensors
        Current graph state; connectivity is passed through unchanged.
    units : int, optional
        Features per attention head in GlobalUpdate (default 16).
    num_heads : int, optional
        Number of attention heads in GlobalUpdate (default 8).

    Returns
    -------
    (atom_state, bond_state, global_state)
        The updated states. GlobalUpdate adds a context vector of width
        ``units * num_heads`` onto the normalized global state.
    """
    atom_state = layers.LayerNormalization()(atom_state)
    bond_state = layers.LayerNormalization()(bond_state)
    global_state = layers.LayerNormalization()(global_state)

    bond_state = EdgeUpdate()([atom_state, bond_state, connectivity, global_state])
    atom_state = NodeUpdate()([atom_state, bond_state, connectivity, global_state])
    global_state = GlobalUpdate(units, num_heads)(
        [atom_state, bond_state, connectivity, global_state])

    return atom_state, bond_state, global_state

# Stack the message-passing blocks; every call creates a new set of layers
for i in range(num_messages):
    atom_state, bond_state, global_state = message_block(
        atom_state, bond_state, connectivity, global_state)

# Sub-model mapping the raw graph inputs to the final per-atom states
atom_embedding_model = tf.keras.Model(
    inputs=input_tensors, outputs=atom_state, name='atom_embedding_model')



# Fresh input placeholders for the outer model, which wraps the
# embedding model above as a single layer
n_atom = layers.Input(shape=[], dtype=tf.int64, name='n_atom')
atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
connectivity = layers.Input(shape=[None, 2], dtype=tf.int64, name='connectivity')

input_tensors = [atom_class, bond_class, connectivity, n_atom]

atom_state = atom_embedding_model(input_tensors)

# Learned scalar offset per atom class
atom_mean = layers.Embedding(
    input_dim=preprocessor.atom_classes, output_dim=1,
    name='atom_mean', mask_zero=True)(atom_class)

# Per-atom scalar prediction = dense readout + class offset, then masked
atom_pred = layers.Dense(units=1)(atom_state)
atom_pred = layers.Add()([atom_pred, atom_mean])
atom_pred = AtomInfMask()(atom_pred)

model = tf.keras.Model(inputs=input_tensors, outputs=atom_pred)

# Learning rate starts at 1E-4 and decays with the inverse of the step count
learning_rate = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=1E-4, decay_steps=1, decay_rate=1E-5)
model.compile(
    loss=KLWithLogits(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate))
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
atom (InputLayer)               [(None, None)]       0                                            
__________________________________________________________________________________________________
bond (InputLayer)               [(None, None)]       0                                            
__________________________________________________________________________________________________
connectivity (InputLayer)       [(None, None, 2)]    0                                            
__________________________________________________________________________________________________
n_atom (InputLayer)             [(None,)]            0                                            
____________________________________________________________________________________________

In [12]:
# train_dataset repeats indefinitely, so steps_per_epoch sets the epoch length
model.fit(train_dataset, epochs=5, steps_per_epoch=100, verbose=1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f98357d04d0>