## Gradient Bidding.

In [1]:
import tensorflow as tf
import types
from utils import noisy_top_k_gating
from utils import SparseDispatcher
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [2]:
def expert(i, x, hparams):
    with tf.compat.v1.variable_scope("expert"):
        sizes = [hparams.n_inputs] + [hparams.e_hidden for _ in range(hparams.e_layers)] + [hparams.n_embedding]
        for i in range(len(sizes) - 1):
            w = tf.Variable(tf.truncated_normal([sizes[i], sizes[i+1]], stddev=0.1))
            b = tf.Variable(tf.constant(0.1, shape=[sizes[i+1]]))
            x = tf.matmul(x, w) + b
    return x

def target_loss(embedding, targets, hparams):
    with tf.compat.v1.variable_scope("target_loss"):
        w = tf.Variable(tf.truncated_normal([hparams.n_embedding, hparams.n_targets], stddev=0.1))
        b = tf.Variable(tf.constant(0.1, shape=[hparams.n_targets])),
        logits = tf.add(tf.matmul(embedding, w), b)
        target_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=targets, logits=logits))
        correct = tf.equal(tf.argmax(logits, 1), tf.argmax(targets, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        return target_loss, accuracy

In [3]:
def incentive_fn(weights, hparams):
    weights = tf.linalg.normalize(weights)
    return tf.slice(weights, [0], [1])

In [87]:
def model_fn(hparams):    
    x_inputs = tf.placeholder("float", [None, hparams.n_inputs], 'inputs')
    y_targets = tf.placeholder("float", [None, hparams.n_targets], 'targets')
    
    gates, load = noisy_top_k_gating(x_inputs, hparams.n_experts, train = True, k = hparams.k)
    
    importance = tf.linalg.normalize(tf.reduce_sum(gates, 0))[0]
    self_weight = tf.Variable(tf.constant([1.0]))
    weights = tf.linalg.normalize(tf.concat([self_weight, importance], axis=0))[0]
    revenue = tf.slice(weights, [0], [1])

    dispatcher = SparseDispatcher(hparams.n_experts, gates)

    expert_inputs = dispatcher.dispatch(x_inputs)
    
    expert_outputs = []
    expert_masks = []
    for i in range(hparams.n_experts):
        
        expert_output = expert(i, expert_inputs[i], hparams)
        expert_mask = tf.nn.relu(tf.slice(weights, [i], [1]) - 0.2)
        masked_output = expert_mask * expert_output
        
        expert_masks.append(expert_mask)
        expert_outputs.append(masked_output)
        
    expert_masks = tf.concat(expert_masks, axis=0)

    embedding = dispatcher.combine(expert_outputs)
    
    embedding = tf.Print(embedding, [tf.shape(embedding)])
    
    loss, accuracy = target_loss(embedding, y_targets, hparams)
    
    train_step = tf.train.AdamOptimizer(hparams.learning_rate).minimize(loss - revenue)
    
    metrics = {
        'loss': loss,
        'revenue': revenue,
        'accuracy': accuracy,
        'importance': importance,
        'weights': weights,
        'masks': expert_masks,
    }
    return train_step, metrics

In [88]:
hparams = types.SimpleNamespace( 
    n_inputs = 784,
    n_targets = 10,
    k = 3,
    n_experts = 3,
    e_layers = 2,
    e_hidden = 256,
    n_embedding = 256,
    batch_size=256,
    learning_rate=1e-3,
    n_iterations = 10000,
    n_print = 100,
)

graph = tf.Graph()
session = tf.Session(graph=graph)
with graph.as_default():
    train_step, metrics = model_fn(hparams)
    session.run(tf.global_variables_initializer())

for i in range(hparams.n_iterations):
    batch_x, batch_y = mnist.train.next_batch(hparams.batch_size)
    feeds = {'inputs:0': batch_x, 'targets:0': batch_y}
    session.run(train_step, feeds)

    if i % hparams.n_print == 0:
        feeds = {'inputs:0': batch_x, 'targets:0': batch_y}
        train_metrics = session.run(metrics, feeds)
        for key in train_metrics:
            print (str(key) + ":  " + str(train_metrics[key]))
        print ('-')


loss:  1.8755696
revenue:  [0.7074601]
accuracy:  0.44140625
importance:  [0.5506503  0.55532986 0.62321186]
weights:  [0.7074601  0.38917392 0.3924812  0.44045705]
masks:  [0.5074601  0.18917392 0.1924812 ]
-
loss:  0.27419168
revenue:  [0.74003845]
accuracy:  0.9140625
importance:  [0.5455215  0.83052945 0.11237027]
weights:  [0.74003845 0.36689845 0.55858475 0.07557628]
masks:  [0.54003847 0.16689844 0.35858476]
-
loss:  0.14204794
revenue:  [0.76587707]
accuracy:  0.95703125
importance:  [0.5496849  0.8198732  0.16016997]
weights:  [0.76587707 0.35344025 0.52716786 0.10298721]
masks:  [0.5658771  0.15344025 0.32716787]
-
loss:  0.10527551
revenue:  [0.7869224]
accuracy:  0.9765625
importance:  [0.56753904 0.805232   0.17175774]
weights:  [0.7869224  0.35020107 0.49686998 0.10598345]
masks:  [0.5869224  0.15020107 0.29687   ]
-
loss:  0.27466145
revenue:  [0.80468553]
accuracy:  0.94921875
importance:  [0.62512124 0.75280845 0.20616253]
weights:  [0.80468553 0.37113526 0.4469433  0.

loss:  0.08843975
revenue:  [0.93092257]
accuracy:  0.9765625
importance:  [0.7340931  0.6440999  0.21504082]
weights:  [0.93092257 0.2681031  0.23523606 0.0785365 ]
masks:  [0.7309226  0.06810309 0.03523606]
-
loss:  0.11278808
revenue:  [0.9318372]
accuracy:  0.96484375
importance:  [0.7110174  0.66743946 0.22131185]
weights:  [0.9318372  0.2580115  0.2421981  0.08030887]
masks:  [0.7318372  0.05801149 0.04219809]
-
loss:  0.03994143
revenue:  [0.9320613]
accuracy:  0.98828125
importance:  [0.69293237 0.67361015 0.25708768]
weights:  [0.9320613  0.25104985 0.24404939 0.09314303]
masks:  [0.7320613  0.05104984 0.04404938]
-
loss:  0.07165096
revenue:  [0.93264365]
accuracy:  0.98046875
importance:  [0.63184816 0.71634424 0.296005  ]
weights:  [0.93264365 0.22797015 0.25845626 0.1067983 ]
masks:  [0.73264366 0.02797015 0.05845626]
-
loss:  0.062119935
revenue:  [0.93309575]
accuracy:  0.98046875
importance:  [0.67930204 0.7012797  0.21623042]
weights:  [0.93309575 0.24429601 0.2521998 

loss:  0.11456588
revenue:  [0.9363023]
accuracy:  0.96875
importance:  [0.71936107 0.66935265 0.18570568]
weights:  [0.9363023  0.2526362  0.23507346 0.06521895]
masks:  [0.7363023  0.05263619 0.03507346]
-
loss:  0.08287234
revenue:  [0.9362461]
accuracy:  0.96484375
importance:  [0.6821744  0.6871531  0.24991742]
weights:  [0.9362461  0.2396784  0.24142765 0.08780718]
masks:  [0.7362461  0.03967839 0.04142764]
-
loss:  0.07216907
revenue:  [0.9374479]
accuracy:  0.98046875
importance:  [0.7500242  0.6067304  0.26332846]
weights:  [0.9374479  0.2611026  0.21121836 0.09167136]
masks:  [0.7374479  0.06110258 0.01121835]
-
loss:  0.039332256
revenue:  [0.9373575]
accuracy:  0.98828125
importance:  [0.7247723  0.65841395 0.20296852]
weights:  [0.9373575  0.25248823 0.22937103 0.07070795]
masks:  [0.7373575  0.05248822 0.02937102]
-
loss:  0.07274256
revenue:  [0.9372486]
accuracy:  0.9765625
importance:  [0.70559347 0.66062844 0.25633568]
weights:  [0.9372486  0.24601354 0.23033595 0.089

KeyboardInterrupt: 