Tensor Field Networks

Implementation of shape classification demonstration

In [1]:
%load_ext autoreload
%autoreload 2

In [28]:
%matplotlib inline
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import tensorflow as tf
import random
from math import pi, sqrt
import tensorfieldnetworks.utils as utils

from tensorfieldnetworks.ShapeClassificationModel import ShapeClassificationModel

# Eight four-block shapes on the integer lattice; every shape is a list of four
# (x, y, z) block coordinates.  The index of a shape in this list doubles as
# its class label further down (the loops use `enumerate(dataset)`).
tetris = [
    [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],   # chiral_shape_1
    [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2 (mirror image of the first)
    [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],   # square
    [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],   # line
    [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],   # corner
    [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],   # L: three in a row, fourth beside one end (was annotated "T")
    [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],   # T: three in a row, fourth beside the middle (was annotated "zigzag")
    [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],   # zigzag (was annotated "L")
]

# One (4, 3) float32 coordinate array per shape, in the same order as `tetris`.
dataset = list(map(lambda block_coords: np.array(block_coords, dtype='float32'), tetris))
num_classes = len(dataset)

In [76]:
# Build the classifier; the number of shape classes is passed to the constructor.
model = ShapeClassificationModel(num_classes)
# Run one forward pass before summary().  Presumably this is what creates the
# layers' weights so that summary() has something to report -- TODO confirm.
model(dataset[0])
model.summary()
# Last expression of the cell: the notebook displays its result, the model
# output for the second shape (used as logits by the training loop).
model(dataset[1])

Model: "shape_classification_model_38"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
self_interaction_simple_79 ( multiple                  1         
_________________________________________________________________
input_layer_38 (InputLayer)  multiple                  0         
_________________________________________________________________
convolution_layer_56 (Convol multiple                  50        
_________________________________________________________________
concatenation_layer_56 (Conc multiple                  0         
_________________________________________________________________
self_interaction_layer_55 (S multiple                  8         
_________________________________________________________________
nonlinearity_layer_42 (Nonli multiple                  8         
_________________________________________________________________
convolution_layer_57 (Convol multiple

<tf.Tensor: shape=(8,), dtype=float32, numpy=
array([-0.01838526,  0.58214056,  0.64749813, -0.69602853,  0.11692503,
       -0.99306935,  0.1826635 ,  0.62261724], dtype=float32)>

In [80]:
# Evaluate the model on the last shape in `dataset` (index 7); the returned
# tensor is displayed as the cell output.
model(dataset[7])

<tf.Tensor: shape=(8,), dtype=float32, numpy=
array([-0.07966256,  0.58856034,  0.62110114, -0.6583565 ,  0.03326574,
       -0.8877537 ,  0.17916514,  0.55663514], dtype=float32)>

In [93]:
# Adam optimizer with a learning rate of 1e-3; the training loop applies the
# gradients through it.
optimizer = tf.keras.optimizers.Adam(learning_rate=1.e-3)

<tf.Tensor: shape=(8,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

In [101]:
max_epochs = 2001
print_freq = 10

# Training: one optimizer step per shape (i.e. batch size 1), cycling through
# every shape in `dataset` once per epoch.  The index of a shape in `dataset`
# is its class label.
for epoch in range(max_epochs):
    loss_sum = 0.
    for label, shape in enumerate(dataset):
        # One-hot target; it does not depend on trainable variables, so it is
        # built outside the tape.
        truth = tf.one_hot(label, num_classes)
        with tf.GradientTape() as tape:
            pred = model(shape, training=True)
            loss = tf.nn.softmax_cross_entropy_with_logits(labels=truth, logits=pred)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
        # Accumulate a plain Python float outside the tape context: only the
        # number is needed for reporting, not a tensor recorded by the tape.
        loss_sum += float(loss)

    if epoch % print_freq == 0:
        # This is the mean loss over the training shapes themselves (each loss
        # measured just before its own update step); there is no separate
        # validation set, so it is reported as a training loss.
        print("Epoch %d: training loss = %.3f" % (epoch, loss_sum / num_classes))

Epoch 0: validation loss = 0.415
Epoch 10: validation loss = 0.415
Epoch 20: validation loss = 0.415


KeyboardInterrupt: 

In [87]:
# Evaluate the model on the first shape (class 0) and display the returned
# tensor as the cell output.
model(dataset[0])

<tf.Tensor: shape=(8,), dtype=float32, numpy=
array([ 4.1171117 ,  3.805261  , -0.72957575, -4.2494464 ,  1.6895834 ,
       -8.492951  ,  1.5292494 , -1.8216822 ], dtype=float32)>

In [None]:
# Evaluate the model on randomly rotated and translated copies of each shape.
# A single RandomState drives both the rotations and the translations, so
# giving it a seed makes the whole evaluation reproducible.
rng = np.random.RandomState()
test_set_size = 25
# predictions[label] collects every label predicted for shapes of that class.
predictions = [list() for i in range(len(dataset))]

correct_predictions = 0
total_predictions = 0
for i in range(test_set_size):
    for label, shape in enumerate(dataset):
        rotation = utils.random_rotation_matrix(rng)
        rotated_shape = np.dot(shape, rotation)
        # One random offset in [-3, 3) per axis, broadcast over all points.
        translation = np.expand_dims(rng.uniform(low=-3., high=3., size=(3)), axis=0)
        # Cast back to float32 so the input dtype matches the arrays in
        # `dataset` (the rotation/translation arithmetic may upcast to float64).
        translated_shape = (rotated_shape + translation).astype('float32')
        # Call the eager model directly on the rotated *and* translated shape;
        # the predicted label is the index of the largest logit.
        logits = model(translated_shape, training=False)
        output_label = int(tf.argmax(logits))
        predictions[label].append(output_label)
        total_predictions += 1
        if output_label == label:
            correct_predictions += 1
print('Test accuracy: %f' % (float(correct_predictions) / total_predictions))