Tensor Field Networks

Implementation of missing point experiment

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import division
import random
import numpy as np
import tensorflow as tf
import tensorfieldnetworks.layers as layers
import tensorfieldnetworks.utils as utils
from tensorfieldnetworks.utils import EPSILON, FLOAT_TYPE

In [3]:
tf.reset_default_graph()

In [4]:
training_set_size = 1000
from ase.db import connect
with connect('qm9.db') as conn:
    qm9_coords = []
    qm9_atoms = []
    qm9_test_coords = []
    qm9_test_atoms = []
    qm9_test_23_coords = []
    qm9_test_23_atoms = []
    qm9_test_29_coords = []
    qm9_test_29_atoms = []
    for atoms in conn.select('4<natoms<=18', limit=training_set_size):
        qm9_coords.append(atoms.positions)
        qm9_atoms.append(atoms.numbers)
    for atoms in conn.select('natoms=19', limit=training_set_size):
        qm9_test_coords.append(atoms.positions)
        qm9_test_atoms.append(atoms.numbers)
    for atoms in conn.select('natoms=23', limit=training_set_size):
        qm9_test_23_coords.append(atoms.positions)
        qm9_test_23_atoms.append(atoms.numbers)
    for atoms in conn.select('24<natoms<=29', limit=training_set_size):
        qm9_test_29_coords.append(atoms.positions)
        qm9_test_29_atoms.append(atoms.numbers)

In [5]:
atom_order = list(set(np.concatenate(qm9_atoms)))
num_atom_types = len(atom_order)

In [6]:
def atom_type_to_one_hot(atom_numbers, atom_order):
    one_hot_dict = {atom_type: [1 if i == j else 0 for i in range(len(atom_order))]
                    for j, atom_type in enumerate(atom_order)}
    return list(map(lambda x: one_hot_dict[x], atom_numbers))

In [7]:
qm9_one_hot = list(map(lambda x: atom_type_to_one_hot(x, atom_order), qm9_atoms))
qm9_test_one_hot = list(map(lambda x: atom_type_to_one_hot(x, atom_order), qm9_test_atoms))
qm9_test_23_one_hot = list(map(lambda x: atom_type_to_one_hot(x, atom_order), qm9_test_23_atoms))
qm9_test_29_one_hot = list(map(lambda x: atom_type_to_one_hot(x, atom_order), qm9_test_29_atoms))

In [8]:
# radial basis functions
rbf_low = 0.
rbf_high = 2.5
rbf_count = 4
rbf_spacing = (rbf_high - rbf_low) / rbf_count
centers = tf.cast(tf.lin_space(rbf_low, rbf_high, rbf_count), FLOAT_TYPE)

In [9]:
# [N, 3]
r = tf.placeholder(FLOAT_TYPE, shape=(None, 3))

# [N, num_types]
one_hot = tf.placeholder(FLOAT_TYPE, shape=(None, num_atom_types))

# [N, N, 3]
rij = utils.difference_matrix(r)
    
# [N, N, 3]
unit_vectors = rij / tf.expand_dims(tf.norm(rij, axis=-1) + EPSILON, axis=-1)

dij = utils.distance_matrix(r)

# rbf : [N, N, rbf_count]
gamma = 1. / rbf_spacing
rbf = tf.exp(-gamma * tf.square(tf.expand_dims(dij, axis=-1) - centers))

layer_dims = [15, 15, 15, 1]

# EMBEDDING
# [N, layer1_dim, 1]
with tf.variable_scope(None, 'embed', values=[one_hot]):
    embed = layers.self_interaction_layer_with_biases(tf.reshape(one_hot, [-1, num_atom_types, 1]), layer_dims[0])
    input_tensor_list = {0: [embed]}

# LAYERS 1-3
num_layers = len(layer_dims) - 1
for layer in range(num_layers):
    layer_dim = layer_dims[layer + 1]
    with tf.variable_scope(None, 'layer' + str(layer), values=[input_tensor_list]):
        input_tensor_list = layers.convolution(input_tensor_list, rbf, unit_vectors)
        input_tensor_list = layers.concatenation(input_tensor_list)
        if layer == num_layers - 1:
            with tf.variable_scope(None, 'atom_types', values=[input_tensor_list[0]]):
                atom_type_list = layers.self_interaction({0: input_tensor_list[0]}, num_atom_types)
        input_tensor_list = layers.self_interaction(input_tensor_list, layer_dim)
        if layer < num_layers - 1:
            with tf.variable_scope(None, 'nonlinearity', values=[input_tensor_list]):
                input_tensor_list = layers.nonlinearity(input_tensor_list, nonlin=utils.ssp)

probabilty_scalars = input_tensor_list[0][0]
missing_coordinates = input_tensor_list[1][0]
atom_type_scalars = atom_type_list[0][0]

# [N]
p = tf.nn.softmax(tf.squeeze(probabilty_scalars))

# [N, 3], when layer3_dim == 1
output = tf.squeeze(missing_coordinates)

# votes : [N, 3]
votes = r + output

# guess : [3]
guess_coord = tf.tensordot(p, votes, [[0], [0]])
# guess_coord = tf.einsum('a,ai->i', p, votes)
guess_atom = tf.tensordot(p, tf.squeeze(atom_type_scalars), [[0], [0]])
# guess_atom = tf.einsum('a,ai->i', p, tf.squeeze(atom_type_scalars))

# missing_point [3]
missing_point = tf.placeholder(FLOAT_TYPE, shape=(3))
missing_atom_type = tf.placeholder(FLOAT_TYPE, shape=(num_atom_types))

# loss : []
loss = tf.nn.l2_loss(missing_point - guess_coord) 
loss += tf.nn.l2_loss(missing_atom_type - guess_atom)

In [10]:
sess = tf.Session()
# sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, "miniteacup/experiments/paper_tmp/qm9_model_50.ckpt")

INFO:tensorflow:Restoring parameters from miniteacup/experiments/paper_tmp/qm9_model_50.ckpt


In [11]:
guesses = []
for shape, types in zip(qm9_coords, qm9_one_hot):
    if len(shape) < 3:
            # Shape stuff fails with shape length of 2 -- skipping for now
        continue
    for remove_index in range(len(shape)):
        new_shape = np.delete(shape, remove_index, 0)
        new_types = np.delete(types, remove_index, 0)
        removed_point = shape[remove_index]
        removed_types = types[remove_index]
        #embedding = np.array([1 for _ in range(len(new_shape))])
        loss_value, guess_point, guess_type, votes_points, probs = sess.run(
            [loss, guess_coord, guess_atom, votes, p], 
            feed_dict={r: new_shape,
                       missing_point: removed_point,
                       missing_atom_type: removed_types,
                       one_hot: new_types})
        guesses.append([new_shape, removed_point, removed_types, loss_value, 
                        guess_point, guess_type, votes_points, probs])

In [12]:
test_guesses = []
for shape, types in zip(qm9_test_coords, qm9_test_one_hot):
    for remove_index in range(len(shape)):
        new_shape = np.delete(shape, remove_index, 0)
        new_types = np.delete(types, remove_index, 0)
        removed_point = shape[remove_index]
        removed_types = types[remove_index]
        loss_value, guess_point, guess_type, votes_points, probs = sess.run(
            [loss, guess_coord, guess_atom, votes, p], 
            feed_dict={r: new_shape,
                       missing_point: removed_point,
                       missing_atom_type: removed_types,
                       one_hot: new_types})
        test_guesses.append([new_shape, removed_point, removed_types, loss_value, 
                        guess_point, guess_type, votes_points, probs])

In [13]:
test_23_guesses = []
for shape, types in zip(qm9_test_23_coords, qm9_test_23_one_hot):
    for remove_index in range(len(shape)):
        new_shape = np.delete(shape, remove_index, 0)
        new_types = np.delete(types, remove_index, 0)
        removed_point = shape[remove_index]
        removed_types = types[remove_index]
        loss_value, guess_point, guess_type, votes_points, probs = sess.run(
            [loss, guess_coord, guess_atom, votes, p], 
            feed_dict={r: new_shape,
                       missing_point: removed_point,
                       missing_atom_type: removed_types,
                       one_hot: new_types})
        test_23_guesses.append([new_shape, removed_point, removed_types, loss_value, 
                        guess_point, guess_type, votes_points, probs])

In [14]:
test_29_guesses = []
for shape, types in zip(qm9_test_29_coords, qm9_test_29_one_hot):
    for remove_index in range(len(shape)):
        new_shape = np.delete(shape, remove_index, 0)
        new_types = np.delete(types, remove_index, 0)
        removed_point = shape[remove_index]
        removed_types = types[remove_index]
        #embedding = np.array([1 for _ in range(len(new_shape))])
        loss_value, guess_point, guess_type, votes_points, probs = sess.run(
            [loss, guess_coord, guess_atom, votes, p], 
            feed_dict={r: new_shape,
                       missing_point: removed_point,
                       missing_atom_type: removed_types,
                       one_hot: new_types})
        test_29_guesses.append([new_shape, removed_point, removed_types, loss_value, 
                        guess_point, guess_type, votes_points, probs])

In [15]:
sort_guesses = list(sorted(guesses, key=lambda x: -x[3]))
sort_test_guesses = sorted(test_guesses, key=lambda x: -x[3])
sort_test_23_guesses = sorted(test_23_guesses, key=lambda x: -x[3])
sort_test_29_guesses = sorted(test_29_guesses, key=lambda x: -x[3])

In [16]:
# number of predictions
print(len(sort_guesses))
print(len(sort_test_guesses))
print(len(sort_test_23_guesses))
print(len(sort_test_29_guesses))

15863
19000
23000
25356


In [17]:
# This should be the same as what's output during training for validation
print("train", np.sqrt(2 * np.sum(np.array(sort_guesses)[:,3]) / len(sort_guesses)))
print("test19", np.sqrt(2 * np.sum(np.array(sort_test_guesses)[:,3]) / len(sort_test_guesses)))
print("test23", np.sqrt(2 * np.sum(np.array(sort_test_23_guesses)[:,3]) / len(sort_test_23_guesses)))
print("test29", np.sqrt(2 * np.sum(np.array(sort_test_29_guesses)[:,3]) / len(sort_test_29_guesses)))

('train', 0.54773882108609684)
('test19', 0.49300219749530261)
('test23', 0.64658839612859753)
('test29', 1.1776781400643486)


In [18]:
sort_test_18_dist = np.linalg.norm(
    np.vstack(np.array(guesses)[:,1].tolist()) -\
    np.vstack(np.array(guesses)[:,4].tolist()), axis=-1)
sort_test_19_dist = np.linalg.norm(
    np.vstack(np.array(sort_test_guesses)[:,1].tolist()) -\
    np.vstack(np.array(sort_test_guesses)[:,4].tolist()), axis=-1)
sort_test_23_dist = np.linalg.norm(
    np.vstack(np.array(sort_test_23_guesses)[:,1].tolist()) -\
    np.vstack(np.array(sort_test_23_guesses)[:,4].tolist()), axis=-1)
sort_test_29_dist = np.linalg.norm(
    np.vstack(np.array(sort_test_29_guesses)[:,1].tolist()) -\
    np.vstack(np.array(sort_test_29_guesses)[:,4].tolist()), axis=-1)

In [19]:
# True or False for correct atom type
test_18_atom_type = np.equal(np.argmax(np.vstack(np.array(sort_guesses)[:,2].tolist()), axis=-1),
                             np.argmax(np.vstack(np.array(sort_guesses)[:,5].tolist()), axis=-1))
test_19_atom_type = np.equal(np.argmax(np.vstack(np.array(sort_test_guesses)[:,2].tolist()), axis=-1),
                             np.argmax(np.vstack(np.array(sort_test_guesses)[:,5].tolist()), axis=-1))
test_23_atom_type = np.equal(np.argmax(np.vstack(np.array(sort_test_23_guesses)[:,2].tolist()), axis=-1),
                             np.argmax(np.vstack(np.array(sort_test_23_guesses)[:,5].tolist()), axis=-1))
test_29_atom_type = np.equal(np.argmax(np.vstack(np.array(sort_test_29_guesses)[:,2].tolist()), axis=-1),
                             np.argmax(np.vstack(np.array(sort_test_29_guesses)[:,5].tolist()), axis=-1))

In [20]:
onehot_to_number = lambda x: atom_order[x]
atoms_18 = list(map(lambda x: onehot_to_number(x), 
                    list(map(lambda x: np.argmax(x), np.array(sort_guesses)[:,2].tolist())))) 
atoms_19 = list(map(lambda x: onehot_to_number(x), 
                    list(map(lambda x: np.argmax(x), np.array(sort_test_guesses)[:,2].tolist())))) 
atoms_23 = list(map(lambda x: onehot_to_number(x), 
                    list(map(lambda x: np.argmax(x), np.array(sort_test_23_guesses)[:,2].tolist())))) 
atoms_29 = list(map(lambda x: onehot_to_number(x), 
                    list(map(lambda x: np.argmax(x), np.array(sort_test_29_guesses)[:,2].tolist())))) 

In [21]:
# Accuracy by atom
acc_dist = 0.5
for atom_int, atom_name in zip([1, 6, 7, 8, 9], ['H', 'C', 'N', 'O', 'F']):
    print(atom_name)
    len_18 = len(list(filter(lambda x: x == atom_int, atoms_18)))
    len_19 = len(list(filter(lambda x: x == atom_int, atoms_19)))
    len_23 = len(list(filter(lambda x: x == atom_int, atoms_23)))
    len_29 = len(list(filter(lambda x: x == atom_int, atoms_29))) 
    if len_18 > 0:
        print("5-18", "%.1f" % (len(list(
            filter(lambda x: x[0] < acc_dist and x[1] and x[2] == atom_int, 
                   zip(sort_test_18_dist, test_18_atom_type, atoms_18)))) /\
                   len_18 * 100),
             len_18)
    else:
        print(None)
    if len_19 > 0:
        print("19", "%.1f" % (len(list(
            filter(lambda x: x[0] < acc_dist and x[1] and x[2] == atom_int, 
                   zip(sort_test_19_dist, test_19_atom_type, atoms_19)))) /\
                   len_19 * 100),
              len_19)
    else:
        print(None)
    if len_23 > 0:
        print("23", "%.1f" % (len(list(
            filter(lambda x: x[0] < acc_dist and x[1] and x[2] == atom_int, 
                   zip(sort_test_23_dist, test_23_atom_type, atoms_23)))) /\
                   len_23 * 100), 
              len_23)
    else:
        print(None)
    if len_29 > 0:
        print("24-29", "%.1f" % (len(list(
            filter(lambda x: x[0] < acc_dist and x[1] and x[2] == atom_int, 
                   zip(sort_test_29_dist, test_29_atom_type, atoms_29)))) /\
                   len_29 * 100), 
              len_29)
    else:
        print(None)

H
('5-18', '90.2', 7207)
('19', '91.9', 10088)
('23', '86.9', 14005)
('24-29', '27.6', 16362)
C
('5-18', '90.4', 5663)
('19', '99.6', 6751)
('23', '87.4', 7901)
('24-29', '45.4', 8251)
N
('5-18', '37.0', 1407)
('19', '15.7', 616)
('23', '0.0', 37)
('24-29', '0.0', 16)
O
('5-18', '15.7', 1536)
('19', '26.1', 1539)
('23', '38.2', 1057)
('24-29', '36.7', 727)
F
('5-18', '0.0', 50)
('19', '0.0', 6)
None
None


In [22]:
# Accuracy over all predictions
print(len(list(
    filter(lambda x: x[0] < 0.5 and x[1], 
           zip(sort_test_18_dist, test_18_atom_type)))) / len(list(sort_test_18_dist)))
print(len(list(
    filter(lambda x: x[0] < 0.5 and x[1], 
           zip(sort_test_19_dist, test_19_atom_type)))) / len(list(sort_test_19_dist)))
print(len(list(
    filter(lambda x: x[0] < 0.5 and x[1], 
           zip(sort_test_23_dist, test_23_atom_type)))) / len(list(sort_test_23_dist)))
print(len(list(
    filter(lambda x: x[0] < 0.5 and x[1], 
           zip(sort_test_29_dist, test_29_atom_type)))) / len(list(sort_test_29_dist)))

0.780621572212
0.868
0.846956521739
0.336172897933


In [23]:
# MAE by atom
acc_dist = 0.5
for atom_int, atom_name in zip([1, 6, 7, 8, 9], ['H', 'C', 'N', 'O', 'F']):
    print(atom_name)
    len_18 = len(list(filter(lambda x: x == atom_int, atoms_18)))
    len_19 = len(list(filter(lambda x: x == atom_int, atoms_19)))
    len_23 = len(list(filter(lambda x: x == atom_int, atoms_23)))
    len_29 = len(list(filter(lambda x: x == atom_int, atoms_29))) 
    if len_18 > 0:
        print("5-18", "%.2f" % np.mean(np.array(list(filter(lambda x: x[1] == atom_int,
                                                   list(zip(sort_test_18_dist,
                                                            atoms_18)))))[:,0]))
    else:
        print(None)
    if len_19 > 0:
        print("19", "%.2f" % np.mean(np.array(list(filter(lambda x: x[1] == atom_int,
                                                 list(zip(sort_test_19_dist,
                                                          atoms_19)))))[:,0]))
    else:
        print(None)
    if len_23 > 0:
        print("23", "%.2f" % np.mean(np.array(list(filter(lambda x: x[1] == atom_int,
                                                 list(zip(sort_test_23_dist,
                                                          atoms_23)))))[:,0]))
    else:
        print(None)
    if len_29 > 0:
        print("24-29", "%.2f" % np.mean(np.array(list(filter(lambda x: x[1] == atom_int,
                                                    list(zip(sort_test_29_dist,
                                                             atoms_29)))))[:,0]))
    else:
        print(None)

H
('5-18', '0.25')
('19', '0.24')
('23', '0.25')
('24-29', '0.33')
C
('5-18', '0.24')
('19', '0.18')
('23', '0.32')
('24-29', '0.46')
N
('5-18', '0.24')
('19', '0.28')
('23', '0.36')
('24-29', '0.38')
O
('5-18', '0.24')
('19', '0.36')
('23', '0.47')
('24-29', '0.56')
F
('5-18', '0.23')
('19', '0.27')
None
None


In [24]:
# MAE for distance
print(np.mean(sort_test_18_dist))
print(np.mean(sort_test_19_dist))
print(np.mean(sort_test_23_dist))
print(np.mean(sort_test_29_dist))

0.244574690677
0.231942903935
0.282319850819
0.376985247931


In [25]:
# True or False for correct atom type
test_18_atom_type_vector = np.linalg.norm(np.vstack(np.array(sort_guesses)[:,2].tolist()) -\
    np.vstack(np.array(sort_guesses)[:,5].tolist()), axis=-1)
test_19_atom_type_vector = np.linalg.norm(np.vstack(np.array(sort_test_guesses)[:,2].tolist()) -\
    np.vstack(np.array(sort_test_guesses)[:,5].tolist()), axis=-1)
test_23_atom_type_vector = np.linalg.norm(np.vstack(np.array(sort_test_23_guesses)[:,2].tolist()) -\
    np.vstack(np.array(sort_test_23_guesses)[:,5].tolist()), axis=-1)
test_29_atom_type_vector = np.linalg.norm(np.vstack(np.array(sort_test_29_guesses)[:,2].tolist()) -\
    np.vstack(np.array(sort_test_29_guesses)[:,5].tolist()), axis=-1)

In [26]:
# Accuracy of atom type (binary)
print(float(np.count_nonzero(test_18_atom_type)) / test_18_atom_type.shape[0])
print(float(np.count_nonzero(test_19_atom_type)) / test_19_atom_type.shape[0])
print(float(np.count_nonzero(test_23_atom_type)) / test_23_atom_type.shape[0])
print(float(np.count_nonzero(test_29_atom_type)) / test_29_atom_type.shape[0])

0.864212317973
0.925210526316
0.914782608696
0.43011516012


In [27]:
# MAE atom type
print(np.mean(test_18_atom_type_vector))
print(np.mean(test_19_atom_type_vector))
print(np.mean(test_23_atom_type_vector))
print(np.mean(test_29_atom_type_vector))

0.342844342205
0.324000269223
0.526471918943
0.993965169418


In [28]:
# Accuracy by distance
print(len(list(filter(lambda x: x < 0.5, sort_test_18_dist))) / len(list(sort_test_18_dist)))
print(len(list(filter(lambda x: x < 0.5, sort_test_19_dist))) / len(list(sort_test_19_dist)))
print(len(list(filter(lambda x: x < 0.5, sort_test_23_dist))) / len(list(sort_test_23_dist)))
print(len(list(filter(lambda x: x < 0.5, sort_test_29_dist))) / len(list(sort_test_29_dist)))

0.903990417954
0.928789473684
0.926608695652
0.804148919388
