In [10]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Dirs - must be absolute paths!
LOG_DIR = '/tmp/tf/mnist_logistic_regression'
MNIST_DIR = "/home/tkornuta/data/mnist"

# Set learning parameters.
LEARNING_RATE = 1e-4
BATCH_SIZE = 100
N_EPOCHS = 1

### A. Import MNIST datset, use one-hot encoding for labels.

In [2]:
mnist_dataset = input_data.read_data_sets(MNIST_DIR, one_hot=True)

Extracting /home/tkornuta/data/mnist/train-images-idx3-ubyte.gz
Extracting /home/tkornuta/data/mnist/train-labels-idx1-ubyte.gz
Extracting /home/tkornuta/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /home/tkornuta/data/mnist/t10k-labels-idx1-ubyte.gz


### B. Start graph definition.

In [3]:
# B. 
# 0. Placeholders for inputs.
with tf.name_scope("Input_data"):
    # Shape - none, as we will feed both training batches as well as test datasets.
    x = tf.placeholder(tf.float32, shape=None, name="x")
    targets = tf.placeholder(tf.float32, shape=None, name="target")
with tf.name_scope("Input_visualization") as scope:
  x_image = tf.reshape(x, [-1,28,28,1])
  image_summ = tf.summary.image("Example_images", x_image)

# 1. Inference ops.
with tf.name_scope("Inference"):
    w = tf.Variable(tf.random_normal(shape=[784,10], stddev=0.01), name="weights") 
    b = tf.Variable(tf.zeros(shape=[1,10]), name="bias")
    logits = tf.add(tf.matmul(x,w), b, name="logits")
    # Add histograms to TensorBoard.
    w_hist = tf.summary.histogram("w", w)
    b_hist = tf.summary.histogram("b", b)
with tf.name_scope("Activation_visualization") as scope:
  x_image = tf.reshape(x, [-1,28,28,1])
  image_summ = tf.summary.image("Example_images", x_image)

# 2. Loss ops.
with tf.name_scope("Loss"):
    entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)
    # Loss = mean over examples in the batch.
    loss = tf.reduce_mean(entropy)
    # Add loss summary.
    loss_summary = tf.summary.scalar("loss", loss)
    
# 3. Training ops.  
with tf.name_scope("Training"):
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)
with tf.name_scope("Evaluating") as scope:
    # Count correct predictions by a simple argmax trick on each sample in a batch.
    correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(targets,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    # Add accuracy summary.
    accuracy_summary = tf.summary.scalar("accuracy", accuracy)

# Merge all summaries.
summaries = tf.summary.merge_all()

# 4. Init global variable.
init = tf.global_variables_initializer()

#### Helper functions

In [4]:
def feed_dict(data_set):
  """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
  if dataset=="train":
    xs, ys = mnist_dataset.train.next_batch(BATCH_SIZE)
  else if dataset=="valid":
    xs, ys = mnist_dataset.valid.images, mnist_dataset.valid.labels
  else: # test
    xs, ys = mnist_dataset.test.images, mnist_dataset.test.labels
  return {x: xs, targets: ys}

### C. Run session.

In [13]:
# Create session.
sess = tf.InteractiveSession()
# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(LOG_DIR + '/train', sess.graph)
valid_writer = tf.summary.FileWriter(LOG_DIR + '/valid')

# Initialize variables.
#tf.global_variables_initializer().run()
sess.run(init)

n_batches = int(mnist_dataset.train.num_examples/BATCH_SIZE)
for e in range (N_EPOCHS):
  for b in range (n_batches):
    if b % 100 == 0:  # Record summaries and valid-set accuracy
      summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict("train"))
      test_writer.add_summary(summary, b+e*n_batches)
      print('Accuracy at step %s: %s' % (b+e*n_batches, acc))
    else:  # Record train set summaries, and train
      summary, acc = sess.run([merged, optimizer], feed_dict=feed_dict("Valid"))
      valid_writer.add_summary(summary, b+e*n_batches)

# Finally, check accuray on test dataset
acc = sess.run(accuracy, feed_dict=feed_dict("test"))
print('Final accuracy on test set: %s' % (acc))

# Close writers and session.
test_writer.flush()
test_writer.close()
valid_writer.flush()
valid_writer.close()
sess.close()

Accuracy at step 0: 0.08
Accuracy at step 100: 0.63
Accuracy at step 200: 0.76
Accuracy at step 300: 0.84
Accuracy at step 400: 0.86
Accuracy at step 500: 0.84
Final accuracy on test set: 0.8
