# WeightWatcher demo (TensorFlow)

In [None]:
from __future__ import absolute_import, division, print_function
import sys
import tensorflow as tf

from ww import builder_tf, watcher
from tf_cifar10 import CIFAR10

In [None]:
# TODO: Remove in final version. Auto reloads imported libraries if they change
# IPython magics: load the autoreload extension and select mode 2, so edits to
# imported modules (e.g. the local `ww` package) are picked up before each cell
# runs without restarting the kernel. Not valid in a plain Python script.
%load_ext autoreload
%autoreload 2

## Visualize Graph

In [None]:
# CIFAR10 dataset and model
# batch_size is reused below as the fixed leading dimension of the input/label
# placeholders and as the slice width of the training loop, so every fed batch
# must contain exactly this many samples.
batch_size = 128
cifar10 = CIFAR10(batch_size=batch_size)

In [None]:
# Inspect data and labels: hand each train/test array (and its labels) to the
# watcher under a "cifar10.<attribute>" caption, in train-then-test order.
for _attr in ("train_data", "train_labels", "test_data", "test_labels"):
    watcher.show("cifar10." + _attr, getattr(cifar10, _attr))
del _attr  # keep the notebook namespace free of the loop variable

In [None]:
# Setup TF "graphing" session
# Short-lived session used only to initialize variables and convert the graph
# for drawing; it is closed and the default graph reset before training.
sess = tf.Session()

In [None]:
# Setup placeholders/vars
# Image batch input; the batch dimension is fixed to batch_size, and the
# spatial/channel sizes come from the dataset wrapper.
inputs = tf.placeholder(
    tf.float32,
    shape=(batch_size, cifar10.img_size, cifar10.img_size, cifar10.num_channels),
)

In [None]:
# Build model
# Adds the network to the default graph on top of `inputs`; here the result is
# only used for its op name, which build_tf_graph receives below.
predictions = cifar10.model(inputs)

In [None]:
# Run the initializer: give every global variable in the current graph its
# initial value so the graph can be evaluated/inspected.
sess.run(fetches=tf.global_variables_initializer())

In [None]:
# Convert TF graph to directed graph.
# Arguments: the default TF graph, the live session, and the name of the
# prediction op (presumably the output node to trace from — confirm in
# builder_tf). An earlier note recorded 78 nodes for this model.
dg = builder_tf.build_tf_graph(
    tf.get_default_graph(),
    sess,
    predictions.op.name,
)

In [None]:
# Draw full graph (flags: simplify and output_shapes on, verbose off).
dg.draw_graph(verbose=False, output_shapes=True, simplify=True)

In [None]:
# Terminate "graphing" session
# Close the session first, then discard the graph, so the training section
# below rebuilds its placeholders and model in a fresh default graph.
sess.close()
tf.reset_default_graph()

## Visualize Training Progress

In [None]:
# Setup TF training session.
# allow_growth: allocate GPU memory on demand instead of reserving it all up front.
# allow_soft_placement: fall back to another device when an op cannot run on
# the one it was assigned to.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(allow_growth=True),
    allow_soft_placement=True,
)
sess = tf.Session(config=config)

In [None]:
# Setup placeholders/vars.
# inputs:  one fixed-size batch of images.
# outputs: per-sample label vectors of length num_classes (presumably one-hot;
#          they feed softmax cross-entropy and argmax below).
# g_step:  non-trainable step counter handed to minimize() as global_step.
inputs = tf.placeholder(
    tf.float32,
    shape=(batch_size, cifar10.img_size, cifar10.img_size, cifar10.num_channels),
)
outputs = tf.placeholder(tf.float32, shape=(batch_size, cifar10.num_classes))
g_step = tf.Variable(0, trainable=False)

In [None]:
# Build model
# Rebuilt on the fresh default graph; the output is used as logits for the
# loss and as class scores for the accuracy metric.
predictions = cifar10.model(inputs)

In [None]:
# Setup loss and optimizer.
# The loss is the batch-mean softmax cross-entropy between the model logits and
# the labels, minimized with momentum SGD.
# NOTE: despite its name, `optimizer` holds the training op returned by
# minimize(); running it applies one update and increments g_step.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=outputs, logits=predictions)
)
optimizer = tf.train.MomentumOptimizer(momentum=0.9, learning_rate=0.01).minimize(
    loss, global_step=g_step
)

In [None]:
# Setup metric: fraction of the batch whose highest-scoring predicted class
# equals the argmax of its label vector.
accurate_preds = tf.equal(
    tf.argmax(predictions, axis=1),
    tf.argmax(outputs, axis=1),
)
accuracy = tf.reduce_mean(tf.cast(accurate_preds, dtype=tf.float32))

In [None]:
# Instantiate watcher
# Receives per-step metrics via w.step() and renders plots/histograms in the
# training loop.
w = watcher.Watcher()

In [None]:
# Visual customizations: display names for the tracked series, presumably
# matched by the keyword names later passed to w.step().
w.legend = dict(loss="Training Loss", accuracy="Training Accuracy")

In [None]:
# Run the initializer: set every global variable of the training graph
# (model weights, momentum slots, g_step) to its initial value.
sess.run(fetches=tf.global_variables_initializer())

In [None]:
# Run the training loop.
# NOTE: in TF1 graph mode, device placement is fixed when ops are *created*
# (the model/loss/optimizer cells above). Wrapping only `sess.run` calls in
# `tf.device(...)` has no effect, so no device context is used here; to pin a
# device, wrap the graph-construction cells instead.
epochs = 4
report_every = 100  # refresh the watcher every N batches (batch 0 is skipped)

# The placeholders have a fixed batch dimension, so the trailing partial batch
# (train_len % batch_size samples) is intentionally dropped.
batches = cifar10.train_len // batch_size

# Loop-invariant values, hoisted out of the training loop.
train_ops = [g_step, optimizer, loss, accuracy]
# Fails fast (before training) if the model has no tensor with this name.
conv1_kernel = tf.get_default_graph().get_tensor_by_name('conv1/conv2d/kernel:0')

for epoch in range(epochs):
    for batch in range(batches):

        # Fetch training samples
        start, stop = batch * batch_size, (batch + 1) * batch_size
        _input = cifar10.train_data[start:stop]
        _output = cifar10.train_labels[start:stop]

        # Train model: one optimizer step, also fetching step/loss/accuracy
        step, _, _loss, _accuracy = sess.run(
            train_ops, feed_dict={inputs: _input, outputs: _output}
        )

        # Print stats. Use logical `and`, not bitwise `&`:
        # `batch & batch % 100 == 0` parses as `(batch & (batch % 100)) == 0`,
        # which also fires at unintended batches (e.g. 104).
        if batch > 0 and batch % report_every == 0:
            _weights = conv1_kernel.eval(session=sess)
            w.step(step, loss=_loss, accuracy=_accuracy, conv1_weights=_weights)
            with w:
                w.plot(["loss"])
                w.plot(["accuracy"])
                w.hist(["conv1_weights"])

In [None]:
# Terminate training session
# Releases the resources held by the training session; the graph itself is
# left in place.
sess.close()