In [47]:
import numpy
import tensorflow as tf
import types
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset from ../MNIST_data/ with labels encoded as one-hot vectors.
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
# Print numpy arrays with 3 decimal places and without scientific notation.
numpy.set_printoptions(precision=3, suppress=True)

Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz


In [48]:
def ffnn(x, hparams):
    """Build a stack of dense (affine) layers on top of ``x`` and return its output.

    The layer widths are ``n_inputs``, then ``n_layers`` hidden layers of
    ``n_hidden`` units, then ``n_targets``. Every layer creates a fresh weight
    matrix (truncated normal, stddev 0.1) and bias vector (constant 0.1).

    Args:
        x: Input tensor passed to ``tf.matmul``; its last dimension must equal
            ``hparams.n_inputs``.
        hparams: Object exposing ``n_inputs``, ``n_hidden``, ``n_layers`` and
            ``n_targets``.

    Returns:
        The output tensor of the last layer (``hparams.n_targets`` columns).
    """
    sizes = [hparams.n_inputs] + [hparams.n_hidden] * hparams.n_layers + [hparams.n_targets]
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        w = tf.Variable(tf.truncated_normal([n_in, n_out], stddev=0.1))
        b = tf.Variable(tf.constant(0.1, shape=[n_out]))
        # NOTE(review): no activation is applied between layers, so the stack
        # is a composition of affine maps -- confirm this is intended.
        x = tf.matmul(x, w) + b

    # No weight masking is applied here; model_fn builds the masked variant.
    return x


In [89]:

def model_fn(hparams):
    """Build the masked feed-forward classifier in the current default graph.

    Each layer's weight matrix is gated by a mean-shifted ReLU mask: the mean
    of every row of ``w`` is subtracted from that row and the result is passed
    through ``relu``. The weights are then multiplied element-wise by the
    mask, so entries at or below their row mean contribute nothing to the
    layer output.

    Args:
        hparams: Object exposing ``n_inputs``, ``n_hidden``, ``n_layers``,
            ``n_targets`` and ``learning_rate``.

    Returns:
        A ``(train_step, metrics)`` tuple. ``train_step`` is the Adam op that
        minimises the softmax cross-entropy loss. ``metrics`` is a dict with
        the keys 'loss', 'accuracy', 'norms' (list of per-layer mask sums),
        'norm_sum' (their total) and one 'mask_<i>' tensor per layer.
    """
    # Named placeholders: callers feed them as 'inputs:0' / 'targets:0'.
    inputs = tf.placeholder(tf.float32, shape=[None, hparams.n_inputs], name='inputs')
    targets = tf.placeholder(tf.float32, shape=[None, hparams.n_targets], name='targets')

    sizes = [hparams.n_inputs] + [hparams.n_hidden] * hparams.n_layers + [hparams.n_targets]

    masks = []
    x = inputs
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        # Declare layer weights.
        w = tf.Variable(tf.truncated_normal([n_in, n_out], stddev=0.1))
        b = tf.Variable(tf.constant(0.1, shape=[n_out]))

        # Mask the weights using a mean shifted relu.
        # Weights are pushed above the mean if they are useful.
        # keepdims=True leaves shape [n_in, 1], which broadcasts against w,
        # so no explicit tf.tile is needed.
        row_mean = tf.reduce_mean(w, axis=1, keepdims=True)
        mask = tf.nn.relu(w - row_mean)
        masks.append(mask)

        # Apply the mask to the weights and use them.
        # NOTE(review): there is no activation between layers, so the only
        # non-linearity is in the weight mask -- confirm this is intended.
        x = tf.matmul(x, tf.multiply(w, mask)) + b

    logits = x

    # Per-layer sum of mask values, and their total across layers.
    norms = [tf.reduce_sum(m) for m in masks]
    norm_sum = tf.add_n(norms)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=targets, logits=logits))
    correct = tf.equal(tf.argmax(logits, 1), tf.argmax(targets, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    # The mask penalty (hparams.alpha * norm_sum) is currently disabled, so
    # only the cross-entropy term is optimised.
    full_loss = loss

    train_step = tf.train.AdamOptimizer(hparams.learning_rate).minimize(full_loss)

    metrics = {
        'loss': loss,
        'accuracy': accuracy,
        'norms': norms,
        'norm_sum': norm_sum,
    }
    for idx, mask in enumerate(masks):
        metrics['mask_' + str(idx)] = mask

    return train_step, metrics

In [90]:
# Hyperparameters for the masked-weight MNIST experiment.
hparams = types.SimpleNamespace(
    batch_size=256,
    learning_rate=1e-3,
    n_inputs=784,
    n_targets=10,
    n_layers=2,
    n_hidden=256,
    n_iterations=10000,
    n_print=100,
    alpha=0.00001,
)

# Build the model in a dedicated graph and initialise its variables.
graph = tf.Graph()
session = tf.Session(graph=graph)
with graph.as_default():
    train_step, metrics = model_fn(hparams)
    session.run(tf.global_variables_initializer())

for i in range(hparams.n_iterations):
    # One optimisation step on the next training mini-batch.
    batch_x, batch_y = mnist.train.next_batch(hparams.batch_size)
    feeds = {'inputs:0': batch_x, 'targets:0': batch_y}
    session.run(train_step, feeds)

    if i % hparams.n_print != 0:
        continue

    # Every n_print iterations, evaluate all metrics on the batch that was
    # just trained on (i.e. after the update) and print them.
    train_metrics = session.run(metrics, feeds)
    for key, value in train_metrics.items():
        print(str(key) + ":  ")
        print(str(value))
    print('-')

loss:  
2.2978823
accuracy:  
0.14453125
norms:  
[7245.0977, 2361.9722, 86.69383]
norm_sum:  
9693.764
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.097 0.038 ... 0.    0.    0.12 ]
 [0.054 0.04  0.068 ... 0.005 0.    0.   ]
 [0.    0.    0.021 ... 0.143 0.063 0.   ]
 ...
 [0.    0.    0.047 ... 0.    0.037 0.   ]
 [0.175 0.    0.    ... 0.    0.    0.   ]
 [0.036 0.    0.087 ... 0.    0.    0.021]]
mask_2:  
[[0.    0.    0.077 ... 0.071 0.154 0.   ]
 [0.    0.    0.039 ... 0.    0.    0.04 ]
 [0.066 0.09  0.008 ... 0.104 0.    0.   ]
 ...
 [0.064 0.    0.186 ... 0.049 0.    0.   ]
 [0.053 0.    0.09  ... 0.003 0.028 0.029]
 [0.065 0.    0.    ... 0.    0.    0.074]]
-
loss:  
1.8556762
accuracy:  
0.51953125
norms:  
[9590.967, 3792.6543, 

loss:  
0.25283784
accuracy:  
0.9296875
norms:  
[16371.221, 6305.754, 168.76068]
norm_sum:  
22845.734
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.298 0.    ... 0.    0.    0.4  ]
 [0.191 0.031 0.067 ... 0.053 0.    0.   ]
 [0.    0.    0.    ... 0.396 0.072 0.   ]
 ...
 [0.    0.    0.069 ... 0.    0.22  0.   ]
 [0.502 0.    0.    ... 0.    0.    0.   ]
 [0.216 0.    0.    ... 0.    0.    0.227]]
mask_2:  
[[0.    0.    0.113 ... 0.015 0.296 0.   ]
 [0.    0.    0.049 ... 0.    0.    0.032]
 [0.083 0.089 0.02  ... 0.242 0.    0.   ]
 ...
 [0.101 0.    0.379 ... 0.028 0.    0.   ]
 [0.123 0.    0.226 ... 0.    0.106 0.   ]
 [0.143 0.    0.    ... 0.    0.    0.04 ]]
-
loss:  
0.28649122
accuracy:  
0.921875
norms:  
[16572.328, 6334.974, 

loss:  
0.18845114
accuracy:  
0.94921875
norms:  
[17784.95, 6525.6304, 173.08224]
norm_sum:  
24483.662
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.316 0.    ... 0.    0.    0.421]
 [0.184 0.041 0.069 ... 0.061 0.    0.   ]
 [0.    0.    0.    ... 0.421 0.081 0.   ]
 ...
 [0.    0.    0.074 ... 0.    0.237 0.   ]
 [0.525 0.    0.    ... 0.    0.    0.   ]
 [0.246 0.    0.    ... 0.    0.    0.244]]
mask_2:  
[[0.    0.    0.114 ... 0.008 0.307 0.   ]
 [0.    0.    0.043 ... 0.    0.    0.035]
 [0.075 0.083 0.026 ... 0.263 0.    0.   ]
 ...
 [0.1   0.    0.39  ... 0.042 0.    0.   ]
 [0.135 0.    0.235 ... 0.    0.104 0.   ]
 [0.146 0.    0.    ... 0.    0.    0.039]]
-
loss:  
0.2850386
accuracy:  
0.91015625
norms:  
[17858.562, 6537.655

loss:  
0.23677953
accuracy:  
0.92578125
norms:  
[18671.266, 6635.899, 174.96149]
norm_sum:  
25482.125
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.34  0.    ... 0.    0.    0.447]
 [0.18  0.048 0.068 ... 0.063 0.    0.   ]
 [0.    0.    0.    ... 0.44  0.089 0.   ]
 ...
 [0.    0.    0.079 ... 0.    0.246 0.   ]
 [0.545 0.    0.    ... 0.    0.    0.   ]
 [0.269 0.    0.    ... 0.    0.    0.256]]
mask_2:  
[[0.    0.    0.115 ... 0.004 0.313 0.   ]
 [0.    0.    0.04  ... 0.    0.    0.038]
 [0.068 0.075 0.039 ... 0.27  0.    0.   ]
 ...
 [0.094 0.    0.398 ... 0.049 0.    0.   ]
 [0.14  0.    0.24  ... 0.    0.103 0.   ]
 [0.144 0.    0.    ... 0.    0.    0.042]]
-
loss:  
0.23389672
accuracy:  
0.94140625
norms:  
[18768.809, 6650.57

loss:  
0.20567575
accuracy:  
0.921875
norms:  
[19380.031, 6714.701, 176.21976]
norm_sum:  
26270.953
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.364 0.    ... 0.    0.    0.469]
 [0.181 0.055 0.066 ... 0.063 0.    0.   ]
 [0.    0.    0.    ... 0.45  0.099 0.   ]
 ...
 [0.    0.    0.084 ... 0.    0.251 0.   ]
 [0.554 0.    0.    ... 0.    0.    0.   ]
 [0.281 0.    0.    ... 0.    0.    0.26 ]]
mask_2:  
[[0.    0.    0.117 ... 0.003 0.314 0.   ]
 [0.    0.    0.038 ... 0.    0.    0.038]
 [0.063 0.067 0.05  ... 0.279 0.    0.   ]
 ...
 [0.089 0.    0.404 ... 0.052 0.    0.   ]
 [0.143 0.    0.244 ... 0.    0.101 0.   ]
 [0.147 0.    0.    ... 0.    0.    0.042]]
-
loss:  
0.25095695
accuracy:  
0.9296875
norms:  
[19472.332, 6730.794, 

loss:  
0.3201241
accuracy:  
0.9140625
norms:  
[19986.168, 6776.5527, 177.09837]
norm_sum:  
26939.818
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.382 0.    ... 0.    0.    0.49 ]
 [0.177 0.061 0.067 ... 0.064 0.    0.   ]
 [0.    0.    0.    ... 0.459 0.106 0.   ]
 ...
 [0.    0.    0.088 ... 0.    0.256 0.   ]
 [0.569 0.    0.    ... 0.    0.    0.   ]
 [0.297 0.    0.    ... 0.    0.    0.262]]
mask_2:  
[[0.    0.    0.119 ... 0.002 0.318 0.   ]
 [0.    0.    0.037 ... 0.    0.    0.038]
 [0.06  0.063 0.055 ... 0.285 0.    0.   ]
 ...
 [0.083 0.    0.408 ... 0.054 0.    0.   ]
 [0.143 0.    0.245 ... 0.    0.1   0.   ]
 [0.147 0.    0.    ... 0.    0.    0.044]]
-
loss:  
0.2702574
accuracy:  
0.93359375
norms:  
[20058.68, 6789.6704,

loss:  
0.2894395
accuracy:  
0.91796875
norms:  
[20529.252, 6837.5312, 177.96907]
norm_sum:  
27544.752
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.396 0.    ... 0.    0.    0.505]
 [0.181 0.066 0.066 ... 0.066 0.    0.   ]
 [0.    0.    0.    ... 0.468 0.112 0.   ]
 ...
 [0.    0.    0.093 ... 0.    0.258 0.   ]
 [0.579 0.    0.    ... 0.    0.    0.   ]
 [0.305 0.    0.    ... 0.    0.    0.262]]
mask_2:  
[[0.    0.    0.119 ... 0.003 0.32  0.   ]
 [0.    0.    0.038 ... 0.    0.    0.038]
 [0.06  0.06  0.058 ... 0.287 0.    0.   ]
 ...
 [0.078 0.    0.412 ... 0.056 0.    0.   ]
 [0.144 0.    0.245 ... 0.    0.099 0.   ]
 [0.151 0.    0.    ... 0.    0.    0.045]]
-
loss:  
0.24424222
accuracy:  
0.90625
norms:  
[20560.18, 6836.229, 1

loss:  
0.18521169
accuracy:  
0.94921875
norms:  
[21005.277, 6891.5723, 178.6466]
norm_sum:  
28075.496
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.411 0.    ... 0.    0.    0.521]
 [0.182 0.068 0.066 ... 0.067 0.    0.   ]
 [0.    0.    0.    ... 0.472 0.116 0.   ]
 ...
 [0.    0.    0.096 ... 0.    0.261 0.   ]
 [0.594 0.    0.    ... 0.    0.    0.   ]
 [0.322 0.    0.    ... 0.    0.    0.263]]
mask_2:  
[[0.    0.    0.121 ... 0.003 0.323 0.   ]
 [0.    0.    0.04  ... 0.    0.    0.041]
 [0.059 0.06  0.061 ... 0.29  0.    0.   ]
 ...
 [0.073 0.    0.418 ... 0.058 0.    0.   ]
 [0.145 0.    0.249 ... 0.    0.1   0.   ]
 [0.153 0.    0.    ... 0.    0.    0.046]]
-
loss:  
0.22009775
accuracy:  
0.94140625
norms:  
[21054.99, 6897.26,

loss:  
0.1666013
accuracy:  
0.93359375
norms:  
[21453.531, 6935.8447, 179.25024]
norm_sum:  
28568.625
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.422 0.    ... 0.    0.    0.529]
 [0.181 0.068 0.067 ... 0.07  0.    0.   ]
 [0.    0.    0.    ... 0.474 0.12  0.   ]
 ...
 [0.    0.    0.101 ... 0.    0.265 0.   ]
 [0.607 0.    0.    ... 0.    0.    0.   ]
 [0.333 0.    0.    ... 0.    0.    0.257]]
mask_2:  
[[0.    0.    0.123 ... 0.004 0.324 0.   ]
 [0.    0.    0.041 ... 0.    0.    0.041]
 [0.059 0.06  0.061 ... 0.296 0.    0.   ]
 ...
 [0.069 0.    0.423 ... 0.06  0.    0.   ]
 [0.144 0.    0.25  ... 0.    0.099 0.   ]
 [0.154 0.    0.    ... 0.    0.    0.047]]
-
loss:  
0.29322177
accuracy:  
0.8984375
norms:  
[21500.453, 6946.227

loss:  
0.26385218
accuracy:  
0.9296875
norms:  
[21833.162, 6980.857, 179.79031]
norm_sum:  
28993.81
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.435 0.    ... 0.    0.    0.543]
 [0.183 0.07  0.068 ... 0.071 0.    0.   ]
 [0.    0.    0.    ... 0.475 0.124 0.   ]
 ...
 [0.    0.    0.104 ... 0.    0.269 0.   ]
 [0.617 0.    0.    ... 0.    0.    0.   ]
 [0.338 0.    0.    ... 0.    0.    0.259]]
mask_2:  
[[0.    0.    0.123 ... 0.005 0.326 0.   ]
 [0.    0.    0.04  ... 0.    0.    0.041]
 [0.06  0.059 0.061 ... 0.301 0.    0.   ]
 ...
 [0.067 0.    0.425 ... 0.061 0.    0.   ]
 [0.143 0.    0.251 ... 0.    0.096 0.   ]
 [0.156 0.    0.    ... 0.    0.    0.047]]
-
loss:  
0.30636847
accuracy:  
0.9140625
norms:  
[21863.303, 6982.1826,

loss:  
0.21955419
accuracy:  
0.91796875
norms:  
[22180.752, 7020.597, 180.1895]
norm_sum:  
29381.54
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.449 0.    ... 0.    0.    0.558]
 [0.18  0.073 0.072 ... 0.072 0.    0.   ]
 [0.    0.    0.    ... 0.481 0.125 0.   ]
 ...
 [0.    0.    0.111 ... 0.    0.268 0.   ]
 [0.626 0.    0.    ... 0.    0.    0.   ]
 [0.348 0.    0.    ... 0.    0.    0.255]]
mask_2:  
[[0.    0.    0.123 ... 0.004 0.328 0.   ]
 [0.    0.    0.04  ... 0.    0.    0.04 ]
 [0.062 0.06  0.062 ... 0.303 0.    0.   ]
 ...
 [0.065 0.    0.427 ... 0.061 0.    0.   ]
 [0.141 0.    0.248 ... 0.    0.096 0.   ]
 [0.16  0.    0.    ... 0.    0.    0.046]]
-
loss:  
0.25024116
accuracy:  
0.92578125
norms:  
[22220.512, 7026.2656

loss:  
0.23936945
accuracy:  
0.953125
norms:  
[22514.475, 7060.549, 180.64021]
norm_sum:  
29755.664
mask_0:  
[[0.    0.    0.    ... 0.    0.022 0.   ]
 [0.    0.    0.    ... 0.043 0.039 0.   ]
 [0.    0.087 0.    ... 0.045 0.121 0.   ]
 ...
 [0.    0.2   0.    ... 0.134 0.06  0.092]
 [0.    0.    0.    ... 0.    0.135 0.076]
 [0.    0.016 0.    ... 0.    0.063 0.138]]
mask_1:  
[[0.    0.459 0.    ... 0.    0.    0.564]
 [0.183 0.075 0.074 ... 0.073 0.    0.   ]
 [0.    0.    0.    ... 0.484 0.126 0.   ]
 ...
 [0.    0.    0.113 ... 0.    0.273 0.   ]
 [0.636 0.    0.    ... 0.    0.    0.   ]
 [0.357 0.    0.    ... 0.    0.    0.252]]
mask_2:  
[[0.    0.    0.125 ... 0.005 0.33  0.   ]
 [0.    0.    0.041 ... 0.    0.    0.041]
 [0.062 0.061 0.064 ... 0.305 0.    0.   ]
 ...
 [0.064 0.    0.434 ... 0.063 0.    0.   ]
 [0.138 0.    0.253 ... 0.    0.096 0.   ]
 [0.158 0.    0.    ... 0.    0.    0.049]]
-
