In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import sys
import os
import os.path
import h5py
import tensorflow as tf
slim = tf.contrib.slim
losses = slim.losses
ops = slim.ops
import tfndlstm as ndlstm
import tfspecs as specs
from urllib2 import urlopen

Download the MNIST dataset in HDF5 format if necessary.

In [3]:
# Fetch the MNIST HDF5 file once and cache it in the working directory;
# later runs reuse the local copy.
url = "http://www.tmbdev.net/ocrdata-hdf5/mnist.h5"
if not os.path.exists("mnist.h5"):
    response = urlopen(url)
    try:
        data = response.read()
    finally:
        # Release the connection explicitly instead of leaving it to the
        # garbage collector.
        response.close()
    # Write under a scratch name and rename afterwards, so a failed write
    # never leaves a truncated "mnist.h5" that the exists() check above
    # would accept on the next run.
    with open("mnist.h5.part", "wb") as stream:
        stream.write(data)
    os.rename("mnist.h5.part", "mnist.h5")

Read the dataset and transform it into the standard TensorFlow image batch format (BHWD).

In [4]:
from collections import namedtuple
# Record describing one data split: the image array, the class labels and
# the number of examples.
Dataset = namedtuple("Dataset", ["images", "classes", "size"])

def loadh5(fname, prefix=""):
    """Load one split of an image dataset from an HDF5 file.

    Reads the datasets named ``prefix + "images"`` and ``prefix + "labels"``
    from `fname`, copies them into in-memory arrays (type code "f" for the
    images, "i" for the labels) and appends a trailing axis of length 1 to
    the images.

    Args:
        fname: path of the HDF5 file to read.
        prefix: string prepended to both dataset names (e.g. "test_").

    Returns:
        Dataset(images, classes, size) where size is len(images).
    """
    # Open read-only and close deterministically: the data is fully copied
    # into arrays inside the block, so the file handle is not needed later.
    with h5py.File(fname, "r") as h5:
        images = array(h5[prefix + "images"], "f")
        labels = array(h5[prefix + "labels"], "i")
    # Add the trailing length-1 axis expected by the image placeholders.
    images = images.reshape(images.shape + (1,))
    return Dataset(images, labels, len(images))
    
# Both splits are stored in the same file; the test split's dataset names
# carry a "test_" prefix.
train = loadh5("mnist.h5", prefix="")
test = loadh5("mnist.h5", prefix="test_")

In [5]:
def info(dataset):
    """Print shape, minimum and maximum of a dataset's images and labels.

    Args:
        dataset: object with `images` and `classes` array attributes.

    Prints two lines, one for the images and one for the labels, each of
    the form ``<shape> <min> <max>``.
    """
    # A single pre-formatted string prints identically under the Python 2
    # print statement and the Python 3 print function.
    for values in (dataset.images, dataset.classes):
        print("%s %s %s" % (values.shape, amin(values), amax(values)))

# Print shape and value range of both splits as a quick sanity check.
info(train)
info(test)

(60000, 28, 28, 1) 0.0 1.0
(60000,) 0 9
(10000, 28, 28, 1) 0.0 1.0
(10000,) 0 9


This is the "long form" network definition of a simple convolutional network using TensorFlow/Slim.

In [6]:
def make_network(net):
    """Build a small convolutional classifier on top of `net`.

    Layers, in order: conv (32 outputs, kernel 3), max-pool (2),
    conv (64 outputs, kernel 3), max-pool (2), flatten,
    fully connected (100 outputs), fully connected (10 outputs).

    Args:
        net: input tensor (the notebook feeds image batches of shape
            [batch, 28, 28, 1]).

    Returns:
        The output tensor of the last fully connected layer, which is built
        with activation_fn=None, i.e. without an output nonlinearity.
    """
    net = slim.conv2d(net, 32, 3)
    net = slim.max_pool2d(net, 2)
    net = slim.conv2d(net, 64, 3)
    net = slim.max_pool2d(net, 2)
    net = slim.flatten(net)
    # Like the other slim layers, fully_connected takes the input tensor as
    # its first argument; without it the layer size would be treated as the
    # input and the call would fail.
    net = slim.fully_connected(net, 100)
    net = slim.fully_connected(net, 10, activation_fn=None)
    return net

The `specs` language lets us express the same network more concisely. You can use `model.funcall` just like the `make_network` function above.

In [7]:
# Inside the `specs.ops` context the layer shorthands (Cr, Mp, Flat, Fr, Fl)
# are used as bare names and `|` chains them into a single model.
# NOTE(review): presumably Cr = conv + ReLU, Mp = max-pool, Flat = flatten,
# Fr = fully connected + ReLU, Fl = fully connected linear, mirroring
# make_network above -- confirm against the tfspecs module.
with specs.ops:
    model = Cr(32, 3) | Mp(2) | Cr(64, 3) | Mp(2) | Flat | Fr(100) | Fl(10)

This defines the standard TensorFlow framework for training a network. It's pretty boilerplate.

In [8]:
# Session shared by the training and evaluation cells below.
sess = tf.InteractiveSession()

# Per-batch feeds: 28x28 single-channel images and their integer class ids.
inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
labels = tf.placeholder(tf.int32, shape=[None])

# Instantiate the `specs` model on the image placeholder.
outputs = model.funcall(inputs)

# Loss: squared error between one-hot targets and sigmoid-squashed outputs,
# summed over the whole batch and minimized with Adam.
targets = tf.one_hot(labels, depth=10, on_value=1.0, off_value=0.0)
loss = tf.reduce_sum(
    tf.square(targets - tf.nn.sigmoid(outputs)))
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
train_op = optimizer.minimize(loss)

# Count of examples in the fed batch whose arg-max output differs from the
# arg-max of the target, as a float.
errors = tf.not_equal(tf.argmax(outputs, 1), tf.argmax(targets, 1))
nerrors = tf.reduce_sum(tf.cast(errors, tf.float32))

# Variables must be initialized before the first train/eval step.
sess.run(tf.initialize_all_variables())

In [9]:
def train_epoch(train=train, bs=20):
    """Run one training pass over `train` in consecutive mini-batches.

    Args:
        train: dataset with `images`, `classes` and `size` fields; defaults
            to the training split loaded earlier in the notebook.
        bs: number of examples per mini-batch.

    Returns:
        (count, training_error): the number of examples fed and the fraction
        of them misclassified, measured on each batch during the same run
        that applies the optimizer step.
    """
    n_wrong = 0
    n_seen = 0
    for start in range(0, train.size, bs):
        stop = start + bs
        batch_images = train.images[start:stop]
        batch_classes = train.classes[start:stop]
        # Fetching train_op applies one optimizer step; only the error count
        # (first fetch) is kept.
        fetched = sess.run(
            [nerrors, train_op],
            feed_dict={inputs: batch_images, labels: batch_classes})
        n_wrong += fetched[0]
        n_seen += len(batch_images)
    return n_seen, n_wrong * 1.0 / n_seen

In [10]:
def evaluate(test=test, bs=1000):
    """Count classification errors of the current model on a dataset.

    Runs only the error-count op (no training step) over consecutive
    batches of `test`.

    Args:
        test: dataset with `images`, `classes` and `size` fields; defaults
            to the test split loaded earlier in the notebook.
        bs: number of examples per evaluation batch.

    Returns:
        (total, test_error): the number of misclassified examples and that
        number divided by `test.size`.
    """
    total = 0
    for i in range(0, test.size, bs):
        # Named fields, matching train_epoch, instead of positional indexing.
        feed_dict = {
            inputs: test.images[i:i+bs],
            labels: test.classes[i:i+bs],
        }
        k, = sess.run([nerrors], feed_dict=feed_dict)
        total += k
    test_error = total * 1.0 / test.size
    return total, test_error

Now train for 20 epochs and print the training and test set error.

In [11]:
# Each epoch: one training pass, then an evaluation on the test split.
# Printed columns: epoch index, training error rate, test error rate.
for epoch in range(20):
    _, training_err = train_epoch()
    _, testing_err = evaluate()
    # A single pre-formatted string prints identically under the Python 2
    # print statement and the Python 3 print function.
    print("%s %s %s" % (epoch, training_err, testing_err))

0 0.117633333333 0.0449
1 0.0346 0.0262
2 0.0233 0.0204
3 0.0178666666667 0.0169
4 0.0146 0.016
5 0.0125 0.0129
6 0.0105166666667 0.012
7 0.00928333333333 0.0113
8 0.00818333333333 0.0106
9 0.00725 0.0104
10 0.00651666666667 0.01
11 0.00586666666667 0.0103
12 0.00558333333333 0.0096
13 0.00515 0.0094
14 0.00473333333333 0.0097
15 0.00438333333333 0.0099
16 0.00408333333333 0.0098
17 0.00395 0.0098
18 0.00355 0.0109
19 0.0035 0.0095
