# Extremely simple dense conv3d demo

This is a very simple model that demonstrates the conv3d layer on dense tensor data. A single input to the network is a 3x3x3x1 block of data with exactly one element set to 1.0 and the rest set to 0.0. The network produces 3 outputs, which should be the [x, y, z] position of that single element. Note that the block has a trailing dimension of length 1 because `tf.nn.conv3d` requires a channel dimension: its input must be a 5-D tensor of shape `[batch, depth, height, width, channels]`.

In [105]:
import numpy as np
import tensorflow as tf

In [106]:
# Copied from the tensorflow tutorial in class
def init_weights(shape, stddev=0.01):
    """
    Create a trainable weight variable with normally distributed initial values.

    Args:
        shape: shape of the weight tensor (list of ints).
        stddev: standard deviation of the initial values; the default 0.01
            is the value the tutorial used.

    Returns:
        A tf.Variable initialised from tf.random_normal(shape, stddev=stddev).
    """
    return tf.Variable(tf.random_normal(shape, stddev=stddev))

def init_bias(shape):
    """Create a trainable bias variable of the given shape, initialised to zeros."""
    zeros = tf.zeros(shape)
    return tf.Variable(zeros)

In [110]:
def training_batch(batch_size, length=3):
    """
    Generate a random batch of one-hot cubes and the positions of their hot element.

    Each sample is a length x length x length x 1 array with exactly one
    element set to 1.0 (the rest 0.0); its label is the [x, y, z] index of
    that element.

    Args:
        batch_size: number of samples to generate.
        length: edge length of each cube (default 3, matching the model).

    Returns:
        (data_batch, labels_batch): float arrays of shape
        (batch_size, length, length, length, 1) and (batch_size, 3).
    """
    data_batch = np.zeros((batch_size, length, length, length, 1))
    labels_batch = np.zeros((batch_size, 3))
    for i in range(batch_size):
        # Draw x, y, z with separate calls in this order so a seeded
        # np.random yields the same batches as before.
        x = np.random.randint(0, length)
        y = np.random.randint(0, length)
        z = np.random.randint(0, length)
        data_batch[i, x, y, z, 0] = 1.0
        labels_batch[i, :] = [x, y, z]
    return data_batch, labels_batch

def test_set(length=3):
    """
    Generate every possible one-hot length x length x length x 1 array
    together with its [x, y, z] label.

    Samples are ordered with x varying slowest and z fastest.

    Args:
        length: edge length of each cube (default 3, matching the model).

    Returns:
        (data, labels): a float array of shape
        (length**3, length, length, length, 1) and an integer array of
        shape (length**3, 3).
    """
    # np.ndindex walks indices in C order, the same order as nested
    # x / y / z loops; reshape keeps a (0, 3) shape when length == 0.
    labels = np.array(list(np.ndindex(length, length, length)),
                      dtype=int).reshape(-1, 3)
    n_samples = len(labels)
    data = np.zeros((n_samples, length, length, length, 1))
    # Set the single hot element of every sample in one fancy-indexed write.
    data[np.arange(n_samples), labels[:, 0], labels[:, 1], labels[:, 2], 0] = 1.0
    return data, labels

In [111]:
length = 3  # edge length of the cubic input block
nvals = 1  # NOTE(review): not referenced by any visible cell
input_tensor = tf.placeholder(tf.float32, [None, length, length, length, 1])

# One 3x3x3 filter, 1 input channel -> 1 output channel. 'SAME' padding with
# unit strides keeps the spatial size at length x length x length.
conv1 = tf.nn.conv3d(input_tensor,
                     init_weights([3, 3, 3, 1, 1]),
                     strides=[1, 1, 1, 1, 1],
                     padding='SAME')
conv1 = tf.nn.relu(tf.nn.bias_add(conv1,
                                  init_bias(1)))

# Flatten the single-channel conv output into length**3 features, then one
# hidden fully connected layer of 18 ReLU units.
n_features = length ** 3
fc = tf.reshape(conv1, [-1, n_features])
fc = tf.add(tf.matmul(fc, init_weights([n_features, 18])), init_bias(18))
fc = tf.nn.relu(fc)

# Linear output layer: regress the [x, y, z] position directly.
out = tf.add(tf.matmul(fc, init_weights([18, 3])), init_bias(3))

y = tf.placeholder(tf.float32, [None, 3])

# tf.sub was removed in TensorFlow 1.0 (renamed tf.subtract); the overloaded
# `-` operator builds the same subtraction op on every release.
# tf.nn.l2_loss computes sum(t ** 2) / 2 over the whole batch.
loss = tf.nn.l2_loss(out - y)
train_optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

# Element-wise comparison on [None, 3] tensors: `accuracy` is the fraction of
# individual coordinates that are correct after rounding, not the fraction
# of whole positions.
correct_pred = tf.equal(tf.round(out), y)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Deprecated in favour of tf.global_variables_initializer on TF >= 0.12,
# but still available throughout TF 1.x.
init = tf.initialize_all_variables()

In [112]:
training_steps = 1000
batch_size = 100

# The session is kept open at module level (rather than in a `with` block)
# so later cells can keep evaluating `out` with the trained variables.
sess = tf.Session()
sess.run(init)

for step in range(training_steps):
    data_batch, labels_batch = training_batch(batch_size)
    train_feed = {input_tensor: data_batch, y: labels_batch}
    sess.run(train_optimizer, feed_dict=train_feed)

# Evaluate on every possible one-hot input produced by test_set().
data_test, labels_test = test_set()
test_feed = {input_tensor: data_test, y: labels_test}
acc = sess.run(accuracy, feed_dict=test_feed)
print("acc:", acc)

# Raw (unrounded) prediction for the second test block, whose label is [0, 0, 1].
sample_prediction = sess.run(out, feed_dict={input_tensor: data_test[1:2]})
print(sample_prediction)

acc: 1.0
[[ 0.05053852 -0.05796198  0.97312009]]


# Checking the answer


In [114]:
# Every rounded prediction must equal its [x, y, z] label exactly.
rounded_predictions = sess.run(tf.round(out), feed_dict={input_tensor: data_test})
assert np.all(rounded_predictions == labels_test)

In [115]:
# Show the rounded predictions for all test blocks.
rounded_predictions = sess.run(tf.round(out), feed_dict={input_tensor: data_test})
print(rounded_predictions)

[[ 0.  0.  0.]
 [ 0.  0.  1.]
 [ 0.  0.  2.]
 [ 0.  1.  0.]
 [ 0.  1.  1.]
 [ 0.  1.  2.]
 [ 0.  2.  0.]
 [ 0.  2.  1.]
 [ 0.  2.  2.]
 [ 1.  0.  0.]
 [ 1.  0.  1.]
 [ 1.  0.  2.]
 [ 1.  1.  0.]
 [ 1.  1.  1.]
 [ 1.  1.  2.]
 [ 1.  2.  0.]
 [ 1.  2.  1.]
 [ 1.  2.  2.]
 [ 2.  0.  0.]
 [ 2.  0.  1.]
 [ 2.  0.  2.]
 [ 2.  1.  0.]
 [ 2.  1.  1.]
 [ 2.  1.  2.]
 [ 2.  2.  0.]
 [ 2.  2.  1.]
 [ 2.  2.  2.]]


In [116]:
# Display the ground-truth labels for comparison with the rounded predictions.
labels_test

array([[0, 0, 0],
       [0, 0, 1],
       [0, 0, 2],
       [0, 1, 0],
       [0, 1, 1],
       [0, 1, 2],
       [0, 2, 0],
       [0, 2, 1],
       [0, 2, 2],
       [1, 0, 0],
       [1, 0, 1],
       [1, 0, 2],
       [1, 1, 0],
       [1, 1, 1],
       [1, 1, 2],
       [1, 2, 0],
       [1, 2, 1],
       [1, 2, 2],
       [2, 0, 0],
       [2, 0, 1],
       [2, 0, 2],
       [2, 1, 0],
       [2, 1, 1],
       [2, 1, 2],
       [2, 2, 0],
       [2, 2, 1],
       [2, 2, 2]])

In [61]:
# Release the TensorFlow session opened for training and evaluation.
sess.close()