[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/klarh/geometric_algebra_attention/blob/master/examples/Structure%20identification%20using%20jax.ipynb)

In [None]:
%%sh
# Colab-specific setup: the installs only run when COLAB_GPU is non-empty,
# so this cell is a no-op in other environments.
if [ -n "$COLAB_GPU" ]; then
    pip install flowws-keras-geometry flowws-keras-experimental pyriodic-aflow freud-analysis
    pip install git+https://github.com/klarh/geometric_algebra_attention
fi

In [None]:
# More TPU and colab-specific setup
import os

# Only run jax's colab TPU setup when a TPU runtime is advertised.
if os.environ.get('TPU_NAME') is not None:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()

import pyriodic

# Sanity check: an empty unit_cells table indicates pyriodic's database was
# not initialized, which is reported with a hint to restart the runtime.
for (n,) in pyriodic.db.query('select count(name) from unit_cells'):
    if n != 0:
        continue
    raise RuntimeError(
        'The colab import machinery sometimes makes pyriodic not get\n'
        '        initialized correctly. Restarting the runtime should make it work properly.')

In [None]:
import flowws
from flowws import Argument as Arg
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.stax import serial, Dense, Relu

from geometric_algebra_attention.jax import VectorAttention

def make_layernorm():
    """Create a parameter-free stax-style ``(init_fun, apply_fun)`` layer.

    ``init_fun(rng, input_shape)`` returns the input shape unchanged and an
    empty parameter tuple. ``apply_fun(params, x, ...)`` standardizes ``x``
    with jax's standardization function (default: last axis, zero mean,
    unit variance).
    """
    # ``jax.nn.normalize`` is the deprecated name of ``jax.nn.standardize``;
    # prefer the new name when it exists so the layer works on old and new jax.
    standardize = getattr(jax.nn, 'standardize', None) or jax.nn.normalize

    def init(rng, input_shape):
        # No trainable parameters; shape is preserved.
        return input_shape, ()

    def eval_(params, x, rng=None, **kwargs):
        # rng and any extra stax keyword arguments are accepted and ignored.
        return standardize(x)

    return init, eval_

@flowws.add_stage_arguments
class CrystalStructureClassification(flowws.Stage):
    """Build a geometric attention network for the structure identification task.

    This module specifies the architecture of a network to classify
    local environments of crystal structures in a rotation-invariant
    manner.

    Running the stage stores a stax-style ``(init_fun, apply_fun)`` pair in
    ``scope['model_functions']``, sets ``scope['loss']`` to
    ``'sparse_categorical_crossentropy'`` and appends ``'sparse_accuracy'``
    to ``scope['metrics']``. The width of the output layer is read from
    ``scope['num_classes']``.

    """

    ARGS = [
        Arg('rank', None, int, 2,
            help='Degree of correlations (n-vectors) to consider'),
        Arg('n_dim', '-n', int, 32,
            help='Working dimensionality of point representations'),
        Arg('dilation', None, float, 2,
            help='Working dimension dilation factor for MLP components'),
        Arg('merge_fun', '-m', str, 'mean',
            help='Method to merge point representations'),
        Arg('join_fun', '-j', str, 'mean',
            help='Method to join invariant and point representations'),
        Arg('n_blocks', '-b', int, 2,
            help='Number of deep blocks to use'),
        Arg('block_nonlinearity', None, bool, True,
            help='If True, add a nonlinearity to the end of each block'),
        Arg('residual', '-r', bool, True,
            help='If True, use residual connections within blocks'),
        Arg('invariant_mode', None, str, 'single',
            help='Attention mechanism rotation-invariant attribute mode'),
    ]

    def run(self, scope, storage):
        rank = self.arguments['rank']
        n_dim = self.arguments['n_dim']
        # Hidden width of the MLPs: n_dim scaled by the dilation factor.
        dilation_dim = int(np.round(n_dim*self.arguments['dilation']))
        merge_fun = self.arguments['merge_fun']
        join_fun = self.arguments['join_fun']
        invar_mode = self.arguments['invariant_mode']

        # Score network handed to VectorAttention; it ends in Dense(1), i.e.
        # a single scalar per input.
        score = serial(
            Dense(dilation_dim),
            Relu,
            Dense(1)
            )

        # Value network handed to VectorAttention (output width n_dim). The
        # same stax pair is also reused below as the optional per-block
        # nonlinearity.
        value = serial(
            Dense(dilation_dim),
            make_layernorm(),
            Relu,
            Dense(n_dim)
            )

        # Build one attention layer as a stax (init_fun, apply_fun) pair.
        # reduce=True is used only for the final layer; presumably it pools
        # over the neighbor axis -- confirm against VectorAttention.
        def make_attention(reduce=False):
            attention = VectorAttention(
                score, value, reduce=reduce, rank=rank, merge_fun=merge_fun,
                join_fun=join_fun, invariant_mode=invar_mode).stax_functions
            return attention

        # stax init_fun. input_shape is a (r_shape, v_shape) pair -- presumably
        # neighbor coordinates and per-neighbor values (TODO confirm). Returns
        # (output_shape, params) where params is a flat list in the order the
        # layers are visited; eval_ must consume them in exactly that order.
        def init(rng, input_shape):
            (r_shape, v_shape) = input_shape

            # Infinite stream of fresh PRNG subkeys, one per layer init.
            def rngs_(rng):
                while True:
                    (next_rng, rng) = jax.random.split(rng)
                    yield next_rng
            rngs = rngs_(rng)

            # Initialize one layer, record its params and return its output shape.
            def param(layer, sh):
              (last_shape, p) = layer[0](next(rngs), sh)
              params.append(p)
              return last_shape

            params = []
            last_shape = param(vscale, v_shape)
            for i, att in enumerate(attentions):
                last_shape = param(att, (r_shape, last_shape))
                if self.arguments['block_nonlinearity']:
                    last_shape = param(block_nonlins[i], last_shape)
            last_shape = param(final_attention, (r_shape, last_shape))
            last_shape = param(final_mlp, last_shape)

            return last_shape, params

        # stax apply_fun: x is the (r, v) pair matching init's input_shape.
        # The rng argument is accepted but not used.
        def eval_(params, x, rng=None):
            # Reverse so that pop() yields params in the order init appended them.
            pstack = list(reversed(params))

            # Apply one layer with its next parameter set. (This local `run`
            # only shadows the method name inside eval_.)
            def run(layer, x):
                return layer[1](pstack.pop(), x)

            (r, v) = x
            last = run(vscale, v)
            for i, att in enumerate(attentions):
                residual_in = last
                last = run(att, (r, last))
                if self.arguments['block_nonlinearity']:
                    last = run(block_nonlins[i], last)
                # Residual connection: block input added to block output.
                if self.arguments['residual']:
                    last = residual_in + last
            last = run(final_attention, (r, last))
            last = run(final_mlp, last)
            return last

        # Projects the input values to the working width n_dim.
        vscale = Dense(n_dim)
        attentions = [make_attention() for _ in range(self.arguments['n_blocks'])]
        # n_blocks references to the same `value` stax pair; each use still
        # gets its own parameters because init() calls its init separately.
        block_nonlins = []
        if self.arguments['block_nonlinearity']:
            block_nonlins = self.arguments['n_blocks']*[value]
        final_attention = make_attention(True)
        # Classifier head: one unnormalized output per class.
        final_mlp = serial(
            Dense(dilation_dim), Relu, Dense(scope['num_classes']))

        scope['model_functions'] = init, eval_
        scope['loss'] = 'sparse_categorical_crossentropy'
        scope.setdefault('metrics', []).append('sparse_accuracy')

In [None]:
import functools

import flowws
from flowws import Argument as Arg
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jax.experimental.optimizers as optimizers

# Optimizer constructors keyed by the name accepted by Train's
# ``optimizer`` argument.
OPTIMIZERS = {
    'adam': optimizers.adam,
    'sgd': optimizers.sgd,
}

# Positional arguments passed to each optimizer constructor (step size).
OPTIMIZER_ARGS = {
    'adam': [.005],
    'sgd': [.001],
}

class Losses:
    """Loss and metric functions, looked up by attribute name."""

    @staticmethod
    def sparse_categorical_crossentropy(prediction, y):
        """Per-sample negative log-likelihood of integer labels ``y``.

        ``prediction`` holds unnormalized logits along its last axis. The
        result keeps a trailing singleton axis.
        """
        log_probs = prediction - logsumexp(prediction, axis=-1, keepdims=True)
        chosen = jnp.take_along_axis(log_probs, jnp.expand_dims(y, -1), axis=-1)
        return -chosen

    @staticmethod
    def sparse_accuracy(prediction, y):
        """Boolean mask that is True where argmax(prediction) equals ``y``."""
        predicted_class = jnp.argmax(prediction, axis=-1)
        return predicted_class == y

@flowws.add_stage_arguments
class Train(flowws.Stage):
    """Train a jax model.

    Reads ``scope['x_train']``, ``scope['y_train']``,
    ``scope['model_functions']`` (a stax-style ``(init_fun, apply_fun)``
    pair), ``scope['loss']`` and, optionally, ``scope['metrics']`` and
    ``scope['validation_data']``. Stores a prediction function (softmax of
    the model output with the trained parameters bound) in
    ``scope['model']`` and the per-epoch history in ``scope['train_log']``.

    Each ``train_log`` row holds the mean training loss followed by the mean
    of each metric. When validation data is available and at least one
    metric is configured, the validation loss and metrics are appended.

    """

    ARGS = [
        Arg('optimizer', '-o', str, 'adam',
            help='optimizer to use'),
        Arg('epochs', '-e', int, 32,
            help='Max number of epochs'),
        Arg('batch_size', '-b', int, 256,
            help='Batch size'),
        Arg('validation_split', '-v', float, .3,
            help='Fraction of the training data to hold out for validation'),
        Arg('seed', '-s', int, 13,
            help='Random seed for parameter initialization and batch shuffling'),
        Arg('verbose', None, bool, True,
            help='If True, print the training progress'),
    ]

    def run(self, scope, storage):
        x_train, y_train = scope['x_train'], scope['y_train']
        init_fun, eval_fun = scope['model_functions']
        batch_size = self.arguments['batch_size']

        # Explicit validation data in the scope takes precedence. Otherwise
        # hold out the leading ``validation_split`` fraction of the training
        # set (no shuffling is done before splitting).
        validation_data = None
        if 'validation_data' in scope:
            validation_data = scope['validation_data']
        elif self.arguments['validation_split']:
            if isinstance(x_train, (list, tuple)):
                N = int(len(x_train[0])*self.arguments['validation_split'])
                splits = [np.split(piece, [N]) for piece in x_train]
                x_val = [piece[0] for piece in splits]
                x_train = [piece[1] for piece in splits]
            else:
                N = int(len(x_train)*self.arguments['validation_split'])
                x_val, x_train = np.split(x_train, [N])
            y_val, y_train = np.split(y_train, [N])
            validation_data = (x_val, y_val)

        opt = self.arguments['optimizer']
        (opt_init, opt_update, opt_params) = OPTIMIZERS[opt](*OPTIMIZER_ARGS[opt])

        # Multi-input models receive a list/tuple of arrays, and init_fun
        # then gets a matching list of shapes.
        if isinstance(x_train, (list, tuple)):
            x_shape = [v.shape for v in x_train]
        else:
            x_shape = x_train.shape

        params = init_fun(jax.random.PRNGKey(self.arguments['seed']), x_shape)[1]
        opt_state = opt_init(params)

        lossfun = getattr(Losses, scope['loss'])

        def loss(params, batch):
            (x, y) = batch
            prediction = eval_fun(params, x)
            # Mean over the batch axis, then sum any remaining axes to a scalar.
            return jnp.sum(jnp.mean(lossfun(prediction, y), axis=0))

        metric_names = scope.get('metrics', [])

        @jax.jit
        def metrics(params, batch):
            (x, y) = batch
            prediction = eval_fun(params, x)
            result = [jnp.mean(getattr(Losses, name)(prediction, y))
                      for name in metric_names]
            return jnp.array(result)

        @jax.jit
        def step(i, opt_state, batch):
            # Metrics are reported for the parameters *before* this update.
            params = opt_params(opt_state)
            value, grads = jax.value_and_grad(loss)(params, batch)
            opt_state = opt_update(i, grads, opt_state)
            return value, opt_state, metrics(params, batch)

        @jax.jit
        def predict(params, x):
            return jax.nn.softmax(eval_fun(params, x))

        @jax.jit
        def evaluate_batch(params, batch):
            loss_val = loss(params, batch)
            metric_vals = metrics(params, batch)
            return jnp.concatenate([jnp.array([loss_val]), metric_vals])

        def evaluate(params, batches):
            # Unweighted mean over batches, matching how training losses are averaged.
            return np.mean([evaluate_batch(params, batch_) for batch_ in batches], axis=0)

        def batchfun(xs, ys):
            # Split (xs, ys) into consecutive batches of at most batch_size.
            is_multi = isinstance(xs, (list, tuple))
            N = len(xs[0]) if is_multi else len(xs)
            batches = []
            for i in range(0, N, batch_size):
                batch = slice(i, i + batch_size)
                x = [piece[batch] for piece in xs] if is_multi else xs[batch]
                batches.append((x, ys[batch]))
            return batches

        batches = batchfun(x_train, y_train)

        # Without validation data (validation_split == 0 and no
        # scope['validation_data']) there is nothing to evaluate. Building the
        # partial unconditionally would try to unpack None and crash.
        val_evaluate = None
        if validation_data is not None:
            val_evaluate = functools.partial(evaluate, batches=batchfun(*validation_data))

        step_count = 0
        rng = np.random.default_rng(self.arguments['seed'])
        batch_indices = np.arange(len(batches))
        epoch_losses = []
        for epoch in range(self.arguments['epochs']):
            # Only the order of batches is shuffled; batch membership is fixed.
            rng.shuffle(batch_indices)
            batch_losses = []
            for batch_index in batch_indices:
                (last_loss, opt_state, batch_metrics) = step(step_count, opt_state, batches[batch_index])
                batch_losses.append([last_loss] + list(batch_metrics))
                step_count += 1
            epoch_losses.append(np.mean(batch_losses, axis=0))
            if val_evaluate is not None and metric_names:
                val_evaluation = val_evaluate(opt_params(opt_state))
                epoch_losses[-1] = np.concatenate([epoch_losses[-1], val_evaluation])
            if self.arguments['verbose']:
                print(epoch, epoch_losses[-1])

        scope['model'] = functools.partial(predict, opt_params(opt_state))
        scope['train_log'] = epoch_losses

In [None]:
import flowws
from flowws_keras_geometry.data import PyriodicDataset

# Structure names passed to PyriodicDataset; each one is a class to predict.
structure_names = [
    "hP2-Mg",
    "cI2-W",
    "cF4-Cu",
    "cF8-C",
    "cF8-SZn",
    "cP46-Si",
    "cF136-Si",
    "cP2-ClCs",
]

dataset_stage = PyriodicDataset(
    noise=[1e-3, 5e-2, 0.1],
    structures=structure_names,
    size=2048,
    num_neighbors=12,
    test_fraction=0.2,
    seed=13,
)

# Stages run in order: data generation, model definition, training.
w = flowws.Workflow([
    dataset_stage,
    CrystalStructureClassification(),
    Train(epochs=32, batch_size=8),
])

scope = w.run()

In [None]:
import matplotlib.pyplot as pp

# Each train_log row is [train loss, train metrics..., val loss, val metrics...].
# With the single 'sparse_accuracy' metric this is
# [train loss, train acc, val loss, val acc]. The validation columns exist
# only when training had validation data, so plot them only if present.
log = np.array(scope['train_log'])
has_validation = log.ndim == 2 and log.shape[1] >= 4

pp.plot(log[:, 0], label='Train')
if has_validation:
    pp.plot(log[:, 2], label='Val')
pp.xlabel('Epoch'); pp.ylabel('Loss')
pp.legend()

pp.figure()
pp.plot(log[:, 1], label='Train')
if has_validation:
    pp.plot(log[:, 3], label='Val')
pp.xlabel('Epoch'); pp.ylabel('Accuracy')
pp.legend();