In [1]:
import os
import sys
import time
import numpy as np
import types
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [3]:
def expert(i, x, hparams):
    with tf.compat.v1.variable_scope("expert"):
        sizes = [hparams.n_inputs] + [hparams.e_hidden for _ in range(hparams.e_layers)] + [hparams.e_output]
        for i in range(len(sizes) - 1):
            w = tf.Variable(tf.truncated_normal([sizes[i], sizes[i+1]], stddev=0.1))
            b = tf.Variable(tf.constant(0.1, shape=[sizes[i+1]]))
            x = tf.matmul(x, w) + b
    return x

def student(x, hparams):
    with tf.compat.v1.variable_scope("student"):
        sizes = [hparams.n_inputs] +[hparams.s_hidden for _ in range(hparams.s_layers)] + [hparams.n_embedding]
        for i in range(len(sizes) - 1):
            w = tf.Variable(tf.truncated_normal([sizes[i], sizes[i+1]], stddev=0.1))
            b = tf.Variable(tf.constant(0.1, shape=[sizes[i+1]]))
            x = tf.matmul(x, w) + b
    return x

In [4]:
def encoder(expert_concat, hparams):
    with tf.compat.v1.variable_scope("encoder"):
        w1 = tf.Variable(tf.truncated_normal([hparams.e_output * hparams.n_experts, hparams.n_embedding], stddev=0.1), name='w')
        b1 = tf.Variable(tf.constant(0.1, shape=[hparams.n_embedding]), name='b')
        return tf.matmul(expert_concat, w1) + b1
                         
def decoder(embedding, hparams):
    with tf.compat.v1.variable_scope("decoder"):
        w1 = tf.Variable(tf.truncated_normal([hparams.n_embedding, hparams.e_output * hparams.n_experts], stddev=0.1), name='w')
        b1 = tf.Variable(tf.constant(0.1, shape=[hparams.e_output * hparams.n_experts]), name='b')
        return tf.matmul(embedding, w1) + b1

In [13]:

def distillation_loss(student, embedding, hparams):
    with tf.compat.v1.variable_scope("distillation_loss"):
        distillation_loss = tf.reduce_mean(
            tf.nn.l2_loss(tf.stop_gradient(embedding) - student))
        return distillation_loss
    
def autoencoder_loss(decoded_output, expert_concat, hparams):
    with tf.compat.v1.variable_scope("autoencoder_loss"):
        distillation_loss = tf.reduce_mean(
            tf.nn.l2_loss(tf.stop_gradient(expert_concat) - decoded_output))
    return distillation_loss

def target_loss(embedding, targets, hparams):
    with tf.compat.v1.variable_scope("target_loss"):
        w = tf.Variable(tf.truncated_normal([hparams.n_embedding, hparams.n_targets], stddev=0.1))
        b = tf.Variable(tf.constant(0.1, shape=[hparams.n_targets])),
        logits = tf.add(tf.matmul(embedding, w), b)
        target_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=targets, logits=logits))
        correct = tf.equal(tf.argmax(logits, 1), tf.argmax(targets, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        return target_loss, accuracy

In [27]:

def model_fn(hparams):    
    x_inputs = tf.placeholder("float", [None, hparams.n_inputs], 'inputs')
    y_targets = tf.placeholder("float", [None, hparams.n_targets], 'targets')

    experts = []
    for ei in range(hparams.n_experts):
        expert_output = expert(ei, x_inputs, hparams)
        experts.append(expert_output)
    expert_concat = tf.concat(experts, axis = 1)
    
    embedding = encoder(expert_concat, hparams)
    embedding_norm = tf.reduce_mean(tf.linalg.normalize(embedding, axis=1))

    student_output = student(x_inputs, hparams)

    decoded_output = decoder(embedding, hparams)
    
    dist_loss = distillation_loss(student_output, embedding, hparams)
    
    auto_loss = autoencoder_loss(decoded_output, expert_concat, hparams)
    
    student_loss, student_accuracy = target_loss(student_output, y_targets, hparams)
    
    teacher_loss, teacher_accuracy = target_loss(embedding, y_targets, hparams)
    
    full_loss = dist_loss + auto_loss + student_loss + teacher_loss
    
    train_step = tf.train.AdamOptimizer(hparams.learning_rate).minimize(full_loss)

    metrics = {
        'teacher_loss': teacher_loss,
        'student_loss': student_loss,
        'dist_loss': dist_loss,
        'auto_loss': auto_loss,
        'full_loss': full_loss,
        'student_accuracy': student_accuracy,
        'teacher_accuracy': teacher_accuracy,
        'embedding_norm': embedding_norm,
    }
    
    return train_step, metrics

In [None]:
hparams = types.SimpleNamespace(
    batch_size=128,
    learning_rate=1e-3,
    n_iterations = 100000,
    n_print = 300,
    n_inputs = 784,
    n_targets = 10,
    n_experts = 20,
    e_layers = 2,
    e_hidden = 256,
    e_output = 50,
    s_hidden = 256,
    s_layers = 2,
    n_embedding = 256,
)

graph = tf.Graph()
session = tf.Session(graph=graph)
with graph.as_default():
    train_step, metrics = model_fn(hparams)
    session.run(tf.global_variables_initializer())

    
for i in range(hparams.n_iterations):
    batch_x, batch_y = mnist.train.next_batch(hparams.batch_size)
    feeds = {'inputs:0': batch_x, 'targets:0': batch_y}
    session.run(train_step, feeds)

    if i % hparams.n_print == 0:
        test_x = mnist.test.images
        test_y = mnist.test.labels
        feeds = {'inputs:0': test_x, 'targets:0': test_y}
        train_metrics = session.run(metrics, feeds)
        for key in train_metrics:
            print (key + ': ' + str(train_metrics[key]))
        print ('')


teacher_loss: 7.8373485
student_loss: 3.2501142
dist_loss: 31634462.0
auto_loss: 244827660.0
full_loss: 276462100.0
student_accuracy: 0.1746
teacher_accuracy: 0.179
embedding_norm: (array([[-1.80408155e-04, -8.97849241e-05,  3.80441692e-04, ...,
        -6.83656253e-05, -6.18652848e-04,  5.32463077e-04],
       [-3.44881089e-04, -2.87002738e-04,  4.28492276e-05, ...,
         9.86690982e-04, -4.91664628e-04,  9.92592541e-04],
       [-1.15620925e-04, -5.09959296e-04, -1.21191160e-05, ...,
         7.98441833e-05, -5.32699516e-04,  1.90781051e-04],
       ...,
       [ 7.63469798e-05, -7.67544319e-04, -2.09683683e-04, ...,
         2.30925980e-05, -6.13551354e-04,  5.26093470e-04],
       [ 5.18808898e-04, -3.01958411e-04,  3.82867031e-04, ...,
         6.21166313e-04, -1.09477369e-04,  4.01121069e-04],
       [-6.47655455e-04, -3.33181786e-04,  9.87649400e-05, ...,
         3.24043678e-04, -1.48108497e-03,  2.29417207e-03]], dtype=float32), array([[7352.762]], dtype=float32))

teacher_