Tensor Field Networks

Implementation of shape classification demonstration

In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import tensorflow as tf
import random
from math import pi, sqrt
import tensorfieldnetworks.layers as layers
import tensorfieldnetworks.utils as utils
from tensorfieldnetworks.utils import FLOAT_TYPE


In [10]:
W = tf.Variable(tf.ones(shape=(2,2)), name="W")
b = tf.Variable(tf.zeros(shape=(2)), name="b")

@tf.function
def forward(x):
  return W * x + b

out_a = forward([1,0])
print(out_a)
tf.zeros(shape=(2))

tf.Tensor(
[[1. 0.]
 [1. 0.]], shape=(2, 2), dtype=float32)


<tf.Tensor: id=254, shape=(2,), dtype=float32, numpy=array([0., 0.], dtype=float32)>

In [None]:
@tf.function
def R(inputs, nonlin=tf.nn.relu, hidden_dim=None, output_dim=1, weights_initializer=None, biases_initializer=None):
    """ Computes: radial = b2 + w2 * nonlin(b1 + w1 * input) """

        if weights_initializer is None:
            weights_initializer = tf.initializers.GlorotNormal()
        if biases_initializer is None:
            biases_initializer = tf.initializers.Constant(0.)

        input_dim = inputs.get_shape()[-1]
        if hidden_dim is None:
            hidden_dim = input_dim

        w1 = tf.Variable(name='weights1', [hidden_dim, input_dim], dtype=FLOAT_TYPE,
                             initializer=weights_initializer)
        b1 = tf.Variable(name='biases1', [hidden_dim], dtype=FLOAT_TYPE, initializer=biases_initializer)

        w2 = tf.Variable(name='weights2', [output_dim, hidden_dim], dtype=FLOAT_TYPE,
                             initializer=weights_initializer)
        b2 = tf.Variable(name='biases2', [output_dim], dtype=FLOAT_TYPE, initializer=biases_initializer)

        hidden_layer = nonlin(b1 + tf.tensordot(inputs, w1, [[2], [1]]))
        radial = b2 + tf.tensordot(hidden_layer, w2, [[2], [1]])

        # [N, N, output_dim]
        return radial


@tf.function
def F_0(inputs, nonlin=tf.nn.relu, hidden_dim=None, output_dim=1):
    # [N, N, output_dim, 1]

    return tf.expand_dims(
        R(inputs, nonlin=nonlin, hidden_dim=hidden_dim, output_dim=output_dim,
          weights_initializer=weights_initializer, biases_initializer=biases_initializer),
        axis=-1)

    
@tf.function
def filter_0(layer_input, rbf_inputs, nonlin=tf.nn.relu, hidden_dim=None, output_dim=1):
    # [N, N, output_dim, 1]
    F_0_out = F_0(rbf_inputs, nonlin=nonlin, hidden_dim=hidden_dim, output_dim=output_dim)
    # [N, output_dim]
    input_dim = layer_input.get_shape().as_list()[-1]
    # Expand filter axis "j"
    cg = tf.expand_dims(tf.eye(input_dim), axis=-2)
    # L x 0 -> L
    return tf.einsum('ijk,abfj,bfk->afi', cg, F_0_out, layer_input)

@tf.function
def convolution(inputs, rbf, unit_vectors):
    for l, input_l in enu
    output_tensor_list = {0: [], 1: []}

    # loop over l
    for key in input_tensor_list:
        with tf.variable_scope(f"L{key}"):

            # loop over atom
            for i, tensor in enumerate(input_tensor_list[key]):
                output_dim = tensor.get_shape().as_list()[-2]  # number of channels

                # L x 0 -> L
                tensor_out = filter_0(tensor, rbf, output_dim=output_dim)
                m = 0 if tensor_out.get_shape().as_list()[-1] == 1 else 1
                output_tensor_list[m].append(tensor_out)

        return output_tensor_list



In [21]:
# radial basis functions
rbf_low = 0.0
rbf_high = 3.5
rbf_count = 4
rbf_spacing = (rbf_high - rbf_low) / rbf_count
centers = tf.cast(tf.linspace(rbf_low, rbf_high, rbf_count), FLOAT_TYPE)

In [35]:
# r : [N, 3]
r = tf.placeholder(FLOAT_TYPE, shape=(4, 3))

# rij : [N, N, 3]
rij = utils.difference_matrix(r)

# dij : [N, N]
dij = utils.distance_matrix(r)

# rbf : [N, N, rbf_count]
gamma = 1. / rbf_spacing
rbf = tf.exp(-gamma * tf.square(tf.expand_dims(dij, axis=-1) - centers))

layer_dims = [1, 4, 4, 4]
num_layers = len(layer_dims) - 1

# embed : [N, layer1_dim, 1]
with tf.variable_scope(None, "embed"):
    embed = layers.self_interaction_layer_without_biases(tf.ones(shape=(4, 1, 1)), layer_dims[0])

# should use tf.one_hot, but for only two kinds it's not needed
#                          O       H       H      O       H       H
input_tensor_list = {0: list([ tf.constant([[v]], tf.float32) for v in [1,0,0,1,0,0] ])}
                         
print(input_tensor_list)
output_tensor_list = layers.convolution(input_tensor_list, rbf, rij)


sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(output_tensor_list)
# for layer, layer_dim in enumerate(layer_dims[1:]):
#     with tf.variable_scope(None, 'layer' + str(layer), values=[input_tensor_list]):
#         input_tensor_list = layers.convolution(input_tensor_list, rbf, rij)
#         input_tensor_list = layers.concatenation(input_tensor_list)
#         input_tensor_list = layers.self_interaction(input_tensor_list, layer_dim)
#         input_tensor_list = layers.nonlinearity(input_tensor_list)

# tfn_scalars = input_tensor_list[0][0]
# tfn_output_shape = tfn_scalars.get_shape().as_list()
# tfn_output = tf.reduce_mean(tf.squeeze(tfn_scalars), axis=0)
# fully_connected_layer = tf.get_variable('fully_connected_weights', 
#                                         [tfn_output_shape[-2], len(dataset)], dtype=FLOAT_TYPE)
# output_biases = tf.get_variable('output_biases', [len(dataset)], dtype=FLOAT_TYPE)

# # output : [num_classes]
# output = tf.einsum('xy,x->y', fully_connected_layer, tfn_output) + output_biases

# tf_label = tf.placeholder(tf.int32)

# # truth : [num_classes]
# truth = tf.one_hot(tf_label, num_classes)

# # loss : []
# loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=truth, logits=output)

# optim = tf.train.AdamOptimizer(learning_rate=1.e-3)

# train_op = optim.minimize(loss)

{0: [<tf.Tensor 'Const_7:0' shape=(1, 1) dtype=float32>, <tf.Tensor 'Const_8:0' shape=(1, 1) dtype=float32>, <tf.Tensor 'Const_9:0' shape=(1, 1) dtype=float32>, <tf.Tensor 'Const_10:0' shape=(1, 1) dtype=float32>, <tf.Tensor 'Const_11:0' shape=(1, 1) dtype=float32>, <tf.Tensor 'Const_12:0' shape=(1, 1) dtype=float32>]}


ValueError: Variable convolution/L0/tensor_0/F0_to_L/F_0/radial_function/weights1 already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

  File "/workspace/tensorfieldnetworks/layers.py", line 23, in R
    initializer=weights_initializer)
  File "/workspace/tensorfieldnetworks/layers.py", line 68, in F_0
    weights_initializer=weights_initializer, biases_initializer=biases_initializer),
  File "/workspace/tensorfieldnetworks/layers.py", line 113, in filter_0
    weights_initializer=weights_initializer, biases_initializer=biases_initializer)


In [28]:
[[[1.]], [[2.]]]

[[[1.0]], [[2.0]]]

In [32]:
# parse training data:
from glob import glob
from pao_file_utils import parse_pao_file

class Geometry:
    def __init__(self, coords):
        self.coords = coords

class Sample:
    def __init__(self, geo, iatom, xblock):
        self.geo = geo
        self.iatom = iatom
        self.xblock = xblock

class Dataset:
    def __init__(self, samples, atom2kind, kinds):
        self.samples = samples
        self.atom2kind = atom2kind
        self.kinds = kinds

samples = []
pao_files = sorted(glob("2H2O_MD/frame_*/2H2O_pao44-1_0.pao"))
for fn in pao_files:
    kinds, atom2kind, coords, xblocks = parse_pao_file(fn)
    geo = Geometry(coords)
    for iatom, xblock in enumerate(xblocks):
        s = Sample(geo, iatom, xblock)
        samples.append(s)

dataset = Dataset(samples, atom2kind, kinds)  # Assuming kinds are the consistend across dataset.
print(len(dataset.samples))
print(dataset.atom2kind)

486
['O', 'H', 'H', 'O', 'H', 'H']


In [5]:
max_epochs = 2001
print_freq = 100

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# training
for epoch in range(max_epochs):    
    loss_sum = 0.
    for label, shape in enumerate(dataset):
        loss_value, _ = sess.run([loss, train_op], feed_dict={r: shape, tf_label: label})
        loss_sum += loss_value
        
    if epoch % print_freq == 0:
        print("Epoch %d: validation loss = %.3f" % (epoch, loss_sum / len(dataset)))

Epoch 0: validation loss = 2.115
Epoch 100: validation loss = 0.935
Epoch 200: validation loss = 0.090
Epoch 300: validation loss = 0.022
Epoch 400: validation loss = 0.011
Epoch 500: validation loss = 0.006
Epoch 600: validation loss = 0.004
Epoch 700: validation loss = 0.002
Epoch 800: validation loss = 0.000
Epoch 900: validation loss = 0.000
Epoch 1000: validation loss = 0.000
Epoch 1100: validation loss = 0.000
Epoch 1200: validation loss = 0.000
Epoch 1300: validation loss = 0.000
Epoch 1400: validation loss = 0.000
Epoch 1500: validation loss = 0.000
Epoch 1600: validation loss = 0.000
Epoch 1700: validation loss = 0.000
Epoch 1800: validation loss = 0.000
Epoch 1900: validation loss = 0.000
Epoch 2000: validation loss = 0.000


In [6]:
rng = np.random.RandomState()
test_set_size = 25
predictions = [list() for i in range(len(dataset))]

correct_predictions = 0
total_predictions = 0
for i in range(test_set_size):
    for label, shape in enumerate(dataset):
        rotation = utils.random_rotation_matrix(rng)
        rotated_shape = np.dot(shape, rotation)
        translation = np.expand_dims(np.random.uniform(low=-3., high=3., size=(3)), axis=0)
        translated_shape = rotated_shape + translation
        output_label = sess.run(tf.argmax(output), 
                                feed_dict={r: rotated_shape, tf_label: label})
        total_predictions += 1
        if output_label == label:
            correct_predictions += 1
print('Test accuracy: %f' % (float(correct_predictions) / total_predictions))

Test accuracy: 1.000000
