## Gradient Bidding.

In [6]:
import tensorflow as tf
import types
from utils import noisy_top_k_gating
from utils import SparseDispatcher
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)

Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz


In [7]:

# FF NN with biases.
def expert(i, x, hparams):
    with tf.compat.v1.variable_scope("expert"):
        sizes = [hparams.n_inputs] + [hparams.e_hidden for _ in range(hparams.e_layers)] + [hparams.n_embedding]
        for i in range(len(sizes) - 1):
            w = tf.Variable(tf.truncated_normal([sizes[i], sizes[i+1]], stddev=0.1))
            b = tf.Variable(tf.constant(0.1, shape=[sizes[i+1]]))
            x = tf.matmul(x, w) + b
    return x

# Cross entropy loss + accuracy.
def target_loss(embedding, targets, hparams):
    with tf.compat.v1.variable_scope("target_loss"):
        w = tf.Variable(tf.truncated_normal([hparams.n_embedding, hparams.n_targets], stddev=0.1))
        b = tf.Variable(tf.constant(0.1, shape=[hparams.n_targets])),
        logits = tf.add(tf.matmul(embedding, w), b)
        target_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=targets, logits=logits))
        correct = tf.equal(tf.argmax(logits, 1), tf.argmax(targets, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        return target_loss, accuracy

In [8]:
# Incentive function inputs weights, outputs revenue.
# This the most basic, just takes the inloop weight as your score
def incentive_fn(weights, hparams):
    weights = tf.linalg.normalize(weights)
    return tf.slice(weights, [0], [1])

In [9]:
def model_fn(hparams):    
    x_inputs = tf.placeholder("float", [None, hparams.n_inputs], 'inputs')
    y_targets = tf.placeholder("float", [None, hparams.n_targets], 'targets')
    
    
    # Sparsely gated mixture of experts with choice k. Produces an importance score 
    # for each x_input then chooses the topk. These children recieve the outgoing query.
    # expert_inputs is a list of tensors, inputs for each expert.
    gates, load = noisy_top_k_gating(x_inputs, hparams.n_experts, train = True, k = hparams.k)
    dispatcher = SparseDispatcher(hparams.n_experts, gates)
    expert_inputs = dispatcher.dispatch(x_inputs)
    

    # Basic importance scores can attained from the gating network by summing over the importance 
    # of each example. We choose a 'self-importance' score here which counts as the in loop in our
    # incentive function. The network should try to maximize this value.
    importance = tf.linalg.normalize(tf.reduce_sum(gates, 0))[0]
    self_weight = tf.Variable(tf.constant([1.0]))
    weights = tf.linalg.normalize(tf.concat([self_weight, importance], axis=0))[0]
    revenue = tf.slice(weights, [0], [1])
    
    # Dispatch the inputs to the experts. We mask the responses with a faux-bidding system,
    # here, we set a mask w.r.t the bids with a hparams.market_shift shifted relu. Bids that
    # drop bellow the market shift should zero out.
    expert_outputs = []
    expert_masks = []
    for i in range(hparams.n_experts):
        expert_output = expert(i, expert_inputs[i], hparams)
        
        # Apply mask to the output.
        expert_mask = tf.nn.relu(tf.slice(weights, [i], [1]) - hparams.market_shift)
        masked_output = expert_mask * expert_output
        
        expert_masks.append(expert_mask)
        expert_outputs.append(masked_output)
    expert_masks = tf.concat(expert_masks, axis=0)

    
    # Combine the expert_inputs.
    embedding = dispatcher.combine(expert_outputs)
        
    # Loss and accuracy stuff.
    loss, accuracy = target_loss(embedding, y_targets, hparams)
    
    # Run the step: optimize for loss and revenue. 
    train_step = tf.train.AdamOptimizer(hparams.learning_rate).minimize(loss - revenue)
    
    metrics = {
        'loss': loss,
        'revenue': revenue,
        'accuracy': accuracy,
        'importance': importance,
        'weights': weights,
        'masks': expert_masks,
    }
    return train_step, metrics

In [10]:
hparams = types.SimpleNamespace( 
    n_inputs = 784,
    n_targets = 10,
    k = 3,
    n_experts = 3,
    e_layers = 2,
    e_hidden = 256,
    n_embedding = 256,
    batch_size=256,
    learning_rate=1e-3,
    n_iterations = 10000,
    n_print = 100,
    market_shift = 0.2,
)

graph = tf.Graph()
session = tf.Session(graph=graph)
with graph.as_default():
    train_step, metrics = model_fn(hparams)
    session.run(tf.global_variables_initializer())

for i in range(hparams.n_iterations):
    batch_x, batch_y = mnist.train.next_batch(hparams.batch_size)
    feeds = {'inputs:0': batch_x, 'targets:0': batch_y}
    session.run(train_step, feeds)

    if i % hparams.n_print == 0:
        feeds = {'inputs:0': batch_x, 'targets:0': batch_y}
        train_metrics = session.run(metrics, feeds)
        for key in train_metrics:
            print (str(key) + ":  " + str(train_metrics[key]))
        print ('-')


loss:  1.8619063
revenue:  [0.7074601]
accuracy:  0.48046875
importance:  [0.53903747 0.57394654 0.61646086]
weights:  [0.7074601  0.3809665  0.4056386  0.43568575]
masks:  [0.5074601  0.18096651 0.2056386 ]
-
loss:  0.23665716
revenue:  [0.7398423]
accuracy:  0.9375
importance:  [0.6713292  0.72482896 0.15472582]
weights:  [0.7398423  0.4516571  0.48765066 0.10409649]
masks:  [0.5398423  0.25165707 0.28765064]
-
loss:  0.20587385
revenue:  [0.765721]
accuracy:  0.93359375
importance:  [0.5428449  0.80560696 0.2373118 ]
weights:  [0.765721   0.34914306 0.5181445  0.15263249]
masks:  [0.56572104 0.14914306 0.3181445 ]
-
loss:  0.1442188
revenue:  [0.78689605]
accuracy:  0.96484375
importance:  [0.6238978  0.73202115 0.27367246]
weights:  [0.78689605 0.38499832 0.4517197  0.16887933]
masks:  [0.58689606 0.18499832 0.2517197 ]
-
loss:  0.19727267
revenue:  [0.80496234]
accuracy:  0.9453125
importance:  [0.5529173  0.75980455 0.3420225 ]
weights:  [0.80496234 0.32806012 0.45081168 0.202930

KeyboardInterrupt: 