[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/klarh/geometric_algebra_attention/blob/master/examples/Molecular%20force%20regression%20using%20keras.ipynb)

In [None]:
%%sh
# Install the notebook's dependencies, but only when COLAB_GPU is non-empty
# (intended as a Colab-only step; every other environment skips this cell).
if [ -n "$COLAB_GPU" ]; then
    # Packages installed by name.
    pip install flowws-keras-geometry keras-gtar
    # Installed straight from the git repository; --force-reinstall is
    # applied to this package only.
    pip install --force-reinstall git+https://github.com/klarh/flowws-keras-experimental
    # The attention library itself, also from git.
    pip install git+https://github.com/klarh/geometric_algebra_attention
fi

In [None]:
import flowws
from flowws_keras_geometry.data import RMD17
from flowws_keras_experimental import InitializeTF, Train, Save
import geometric_algebra_attention.keras as gala

In [None]:
from flowws_keras_geometry.models.internal import GradientLayer, \
    NeighborhoodReduction, \
    PairwiseValueNormalization, PairwiseVectorDifference, \
    PairwiseVectorDifferenceSum

from geometric_algebra_attention.keras import MomentumNormalization, VectorAttention

import flowws
from flowws import Argument as Arg
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Extra activation functions, keyed by the name used for the `activation`
# and `final_activation` arguments. MoleculeForceRegression wraps a callable
# found here in keras.layers.Lambda; any other name is handed to
# keras.layers.Activation unchanged.
LAMBDA_ACTIVATIONS = {
    # log(1 + swish(x))
    'log1pswish': lambda x: tf.math.log1p(tf.nn.swish(x)),
    'sin': tf.sin,
    # swish(x) with a small (0.01) multiple of swish(-x) subtracted
    'leakyswish': lambda x: tf.nn.swish(x) - 1e-2*tf.nn.swish(-x)
}

# Factories for normalization layers, keyed by the names accepted by the
# `*_normalization` arguments. Every factory is called with the attention
# rank and returns a new layer instance; only 'layer_all' uses the rank.
NORMALIZATION_LAYERS = {
    'layer': lambda _: keras.layers.LayerNormalization(),
    # Normalizes jointly over the last (rank + 1) axes: -1, -2, ..., -(rank + 1).
    'layer_all': lambda rank: keras.layers.LayerNormalization(axis=[-i - 1 for i in range(rank + 1)]),
    'momentum': lambda _: MomentumNormalization(),
    'pairwise': lambda _: PairwiseValueNormalization(),
}

# Suffix appended to the help text of the normalization arguments, listing
# the valid names (iterating the dict yields its keys).
NORMALIZATION_LAYER_DOC = ' (any of {})'.format(','.join(NORMALIZATION_LAYERS))

@flowws.add_stage_arguments
class MoleculeForceRegression(flowws.Stage):
    """Build a geometric attention network for the molecular force regression task.

    The network defined here maps the coordinates and types of the
    atoms in a molecule to atomic forces. The forces are conservative
    because they are obtained as the gradient of a single scalar
    output.

    """

    ARGS = [
        Arg('rank', None, int, 2,
            help='Degree of correlations (n-vectors) to consider'),
        Arg('n_dim', '-n', int, 32,
            help='Working dimensionality of point representations'),
        Arg('dilation', None, float, 2,
            help='Working dimension dilation factor for MLP components'),
        Arg('merge_fun', '-m', str, 'concat',
            help='Method to merge point representations'),
        Arg('join_fun', '-j', str, 'concat',
            help='Method to join invariant and point representations'),
        Arg('dropout', '-d', float, 0,
            help='Dropout rate to use, if any'),
        Arg('mlp_layers', None, int, 1,
            help='Number of hidden layers for score/value MLPs'),
        Arg('n_blocks', '-b', int, 2,
            help='Number of deep blocks to use'),
        Arg('block_nonlinearity', None, bool, True,
            help='If True, add a nonlinearity to the end of each block'),
        Arg('residual', '-r', bool, True,
            help='If True, use residual connections within blocks'),
        Arg('activation', '-a', str, 'swish',
            help='Activation function to use inside the network'),
        Arg('final_activation', None, str, 'swish',
            help='Final activation function to use within the network'),
        Arg('score_normalization', None, [str], [],
            help=('Normalizations to apply to score (attention) function' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('value_normalization', None, [str], [],
            help=('Normalizations to apply to value function' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('block_normalization', None, [str], [],
            help=('Normalizations to apply to the output of each attention block' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('invariant_value_normalization', None, [str], [],
            help=('Normalizations to apply to value function, before MLP layers' +
                  NORMALIZATION_LAYER_DOC)),
        Arg('invariant_mode', None, str, 'single',
            help='VectorAttention invariant_mode to use'),
        Arg('include_normalized_products', None, bool, False,
            help='Also include normalized geometric product terms'),
    ]

    def run(self, scope, storage):
        """Assemble the network and publish it through ``scope``.

        Reads ``scope['neighborhood_size']`` and ``scope['num_types']``
        to size the coordinate and type inputs. Writes
        ``input_symbol``, ``output``, ``loss``, ``attention_model`` and
        ``invariant_model`` into ``scope``. ``storage`` is accepted but
        not used by this stage.

        """
        args = self.arguments
        rank = args['rank']

        def activation_factory(name):
            # Names listed in LAMBDA_ACTIVATIONS are wrapped in a Lambda
            # layer; everything else is passed to keras by name.
            if name in LAMBDA_ACTIVATIONS:
                return lambda: keras.layers.Lambda(LAMBDA_ACTIVATIONS[name])
            return lambda: keras.layers.Activation(name)

        activation_layer = activation_factory(args['activation'])
        final_activation_layer = activation_factory(args['final_activation'])

        n_dim = args['n_dim']
        # Width of the hidden MLP layers: n_dim scaled by the dilation factor.
        dilation_dim = int(np.round(n_dim*args['dilation']))

        # Options shared by every VectorAttention layer in the network.
        attention_kwargs = dict(
            rank=rank,
            join_fun=args['join_fun'],
            merge_fun=args['merge_fun'],
            invariant_mode=args['invariant_mode'],
            include_normalized_products=args['include_normalized_products'],
        )

        def build_mlp(output_width, hidden_norm_key, input_norm_key=None):
            # Layers are instantiated strictly in the order they appear in
            # the resulting Sequential model.
            stack = []

            if input_norm_key is not None:
                stack.extend(
                    NORMALIZATION_LAYERS[name](rank)
                    for name in args[input_norm_key])

            for _ in range(args['mlp_layers']):
                stack.append(keras.layers.Dense(dilation_dim))
                stack.extend(
                    NORMALIZATION_LAYERS[name](rank)
                    for name in args[hidden_norm_key])
                stack.append(activation_layer())

                rate = args.get('dropout', 0)
                if rate:
                    stack.append(keras.layers.Dropout(rate))

            stack.append(keras.layers.Dense(output_width))
            return keras.models.Sequential(stack)

        def make_scorefun():
            # Score (attention) network: a single output unit.
            return build_mlp(1, 'score_normalization')

        def make_valuefun(in_network=True):
            # Value network: n_dim outputs. The leading invariant
            # normalizations are only added when in_network is True.
            return build_mlp(
                n_dim, 'value_normalization',
                'invariant_value_normalization' if in_network else None)

        def make_block(block_input):
            attention = VectorAttention(
                make_scorefun(), make_valuefun(), False, **attention_kwargs)
            result = attention([delta_x, block_input])

            if args['block_nonlinearity']:
                result = make_valuefun(in_network=False)(result)

            if args['residual']:
                result = result + block_input

            for name in args.get('block_normalization', []):
                result = NORMALIZATION_LAYERS[name](rank)(result)

            return result

        neighborhood_size = scope['neighborhood_size']
        x_in = keras.layers.Input((neighborhood_size, 3))
        v_in = keras.layers.Input((neighborhood_size, scope['num_types']))

        delta_x = PairwiseVectorDifference()(x_in)
        delta_v = PairwiseVectorDifferenceSum()(v_in)

        hidden = keras.layers.Dense(n_dim)(delta_v)
        for _ in range(args['n_blocks']):
            hidden = make_block(hidden)

        final_attention = VectorAttention(
            make_scorefun(), make_valuefun(), True, name='final_attention',
            **attention_kwargs)
        (hidden, ivs, att) = final_attention(
            [delta_x, hidden], return_invariants=True, return_attention=True)

        # Output head: project to one scalar, then differentiate it through
        # GradientLayer together with the coordinate input.
        hidden = keras.layers.Dense(dilation_dim, name='final_mlp')(hidden)
        hidden = final_activation_layer()(hidden)
        hidden = NeighborhoodReduction()(hidden)
        scalar = keras.layers.Dense(
            1, name='energy_projection', use_bias=False)(hidden)
        output = GradientLayer()((scalar, x_in))

        scope['input_symbol'] = [x_in, v_in]
        scope['output'] = output
        scope['loss'] = 'mse'
        scope['attention_model'] = keras.models.Model([x_in, v_in], att)
        scope['invariant_model'] = keras.models.Model([x_in, v_in], ivs)

In [None]:
# Each stage is built separately, in the order it is listed in the workflow.
tf_stage = InitializeTF()

# Data stage: benzene only; the cache directory is /tmp.
data_stage = RMD17(
    seed=13,
    cache_dir="/tmp",
    n_train=1000,
    n_val=1000,
    y_scale_reduction=4,
    molecules=["benzene"],
)

# Model stage defined earlier in this notebook.
model_stage = MoleculeForceRegression(
    n_dim=32,
    n_blocks=6,
    invariant_mode='full',
    block_normalization=['layer'],
    value_normalization=['layer'],
    include_normalized_products=True,
)

train_stage = Train(
    epochs=150,
    reduce_lr=25,
    early_stopping=70,
    batch_size=4,
    validation_split=0,
    reduce_lr_factor=0.8,
    early_stopping_best=1,
    disable_tqdm=1,
    optimizer='adam',
    accumulate_gradients=8,
    catch_keyboard_interrupt=True,
)

save_stage = Save(save_model=1)

w = flowws.Workflow(
    [tf_stage, data_stage, model_stage, train_stage, save_stage],
    storage=flowws.DirectoryStorage("/tmp"),
)

# The value returned by run() is kept as `scope`; the plotting cell reads
# scope['log_quantities'] and scope['workflow'] from it.
scope = w.run()

In [None]:
import matplotlib.pyplot as pp

# Both curves come from the same run, so the legend label is looked up once.
# stages[2] is the MoleculeForceRegression stage in the workflow built above.
label = scope['workflow'].stages[2].arguments['invariant_mode']

# Training curve. Element [1] of each scope['log_quantities'] entry is indexed
# by metric name here (presumably per-epoch values -- confirm against the
# Train stage).
y = np.concatenate([series[1]['mean_absolute_error'] for series in scope['log_quantities']])
pp.plot(y, label=label)
pp.gca().set_yscale('log')
pp.xlabel('Epoch'); pp.ylabel('MAE')
# Without legend() the label passed to plot() is never drawn.
pp.legend()

# Validation curve, on its own figure.
pp.figure()
y = np.concatenate([series[1]['val_mean_absolute_error'] for series in scope['log_quantities']])
pp.plot(y, label=label)
pp.xlabel('Epoch'); pp.ylabel('Validation MAE')
pp.gca().set_yscale('log')
pp.legend();