In [1]:
import tensorflow as tf
from tensorflow.python.summary.writer.writer import FileWriter
import numpy as np

from models import build_model
# Clear the default graph before building the replicas.
tf.compat.v1.reset_default_graph()

n_replicas = 4
# replica index -> dict holding that replica's placeholders, model and ops
train_dict = {}

# One pair of input/target tensors is shared by every replica.
with tf.name_scope('inputs'):
    inputs = tf.keras.layers.Input(shape=(28, 28, 1), dtype=tf.float32)

with tf.name_scope('targets'):
    targets = tf.keras.layers.Input(shape=(), dtype=tf.int32)

for i in range(n_replicas):
    # Per-replica hyperparameters. `dropout_rate` falls back to 0. when it is
    # not fed; `lr` has no default and must be fed whenever the train op runs.
    dropout_rate = tf.compat.v1.placeholder_with_default(0., shape=())
    lr = tf.compat.v1.placeholder(tf.float32, shape=())

    train_dict[i] = {'lr': lr,
                     'dropout_rate': dropout_rate}

    with tf.name_scope('model_' + str(i)):
        train_dict[i]['model'] = build_model(inputs, dropout_rate)

    with tf.name_scope('loss_' + str(i)):
        xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=targets, logits=train_dict[i]['model'].outputs[0])
        train_dict[i]['loss'] = tf.reduce_mean(xentropy)

    with tf.name_scope('optimizer_' + str(i)):
        # Plain SGD step: w <- w - lr * dL/dw for every trainable variable.
        grads = tf.gradients(train_dict[i]['loss'],
                             train_dict[i]['model'].trainable_variables)
        grads_and_vars = zip(grads, train_dict[i]['model'].trainable_variables)
        # Make every assignment wait until all gradients have been computed,
        # so no gradient is evaluated against an already-updated weight.
        with tf.compat.v1.control_dependencies(grads):
            train_ops = [w.assign(w - lr * g) for g, w in grads_and_vars]
            train_dict[i]['train_op'] = tf.group(train_ops)

    with tf.name_scope('error_' + str(i)):
        # argmax yields int64 by default; request int32 so the prediction can
        # be compared element-wise with the int32 `targets` tensor.
        y_pred = tf.argmax(train_dict[i]['model'].outputs[0], axis=1,
                           output_type=tf.int32)
        equals = tf.cast(tf.math.equal(y_pred, targets), tf.float32)
        # Classification error = 1 - accuracy over the batch.
        train_dict[i]['error'] = 1. - tf.reduce_mean(equals)

# Dump the graph definition for inspection (e.g. in TensorBoard).
FileWriter('logs/train', graph=inputs.graph).close()


W1102 16:02:53.007150 4450002368 deprecation.py:506] From /Users/vpushkarov/anaconda3/envs/tf14/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [2]:
# One learning rate per replica, evenly spaced from 0.01 down to 0.001.
# The count follows `n_replicas` so every replica receives a value.
# 'noise_val' is what train_on_batch feeds to the replica's placeholder and
# what swap_replicas reads as the replica's temperature.
lr_list = np.linspace(0.01, 0.001, n_replicas)
for i, lr_val in enumerate(lr_list):
    train_dict[i]['noise_val'] = lr_val

def train_on_batch(train_dict, data, to_swap='lr'):
    """Run one training step on every replica with a single session call.

    Args:
        train_dict: mapping replica index -> per-replica dict. For each index
            in ``range(n_replicas)`` the entries ``to_swap`` (the placeholder
            to feed), ``'noise_val'`` (the value fed to it), ``'train_op'``,
            ``'loss'`` and ``'error'`` are read.
        data: ``(x_data, y_data)`` pair fed to the shared ``inputs`` and
            ``targets`` tensors.
        to_swap: key of the per-replica placeholder that receives
            ``'noise_val'``.

    Returns:
        ``(losses, errors)``: two lists of length ``n_replicas`` holding the
        evaluated loss and error of each replica, in replica order.

    Relies on the module-level ``sess``, ``inputs``, ``targets`` and
    ``n_replicas``.
    """
    x_data, y_data = data
    feed_dict = {inputs: x_data,
                 targets: y_data}
    # Feed each replica's own noise value. Every replica's train op is
    # fetched below, so every replica's placeholder needs a value.
    feed_dict.update({train_dict[i][to_swap]: train_dict[i]['noise_val']
                      for i in range(n_replicas)})
    train_ops = [train_dict[i]['train_op'] for i in range(n_replicas)]
    losses = [train_dict[i]['loss'] for i in range(n_replicas)]
    errors = [train_dict[i]['error'] for i in range(n_replicas)]
    # Fetch order is losses, errors, train ops; the train-op results are
    # discarded when slicing below.
    evaled = sess.run(losses + errors + train_ops, feed_dict=feed_dict)
    losses = evaled[:n_replicas]
    errors = evaled[n_replicas: 2 * n_replicas]
    return losses, errors

def swap_replicas(train_dict, losses, logs, coeff=1., to_swap='lr'):
    """Attempt one exchange of noise values between two neighbouring replicas.

    Replicas are ordered by an inverse temperature ``beta`` derived from their
    current ``'noise_val'``. One random pair of neighbours in that ordering is
    chosen and their noise values are exchanged when a uniform draw falls
    below ``exp(coeff * (loss_i - loss_j) * (beta_i - beta_j))``.

    Args:
        train_dict: mapping replica index -> per-replica dict. Only
            ``'noise_val'`` is read, and on an accepted swap rewritten in
            place; every other entry of the replica is left untouched.
        losses: sequence indexable by replica index with each replica's loss.
        logs: dict whose ``'swap_success'`` and ``'swap_attempts'`` counters
            are incremented in place.
        coeff: multiplier applied to the exponent of the acceptance ratio.
        to_swap: ``'lr'`` uses ``beta = 1 / noise_val``; any other value uses
            ``beta = noise_val / (1 - noise_val)``.

    Relies on the module-level ``n_replicas``. Returns None.
    """
    temperatures = [train_dict[i]['noise_val'] for i in range(n_replicas)]
    # Each entry is (beta, replica index, noise value).
    if 'lr' == to_swap:
        betas_and_ids = [(1. / b, i, b) for i, b in enumerate(temperatures)]
    else:
        # NOTE(review): formula kept as originally written; confirm that beta
        # is meant to grow with the noise value on this branch.
        betas_and_ids = [(b / (1. - b), i, b) for i, b in enumerate(temperatures)]
    betas_and_ids.sort(key=lambda x: x[0])

    # randint's upper bound is exclusive, so random_pair + 1 is always a
    # valid position in the sorted list.
    random_pair = np.random.randint(low=0, high=n_replicas - 1)
    beta_i, i, temp_i = betas_and_ids[random_pair]
    beta_j, j, temp_j = betas_and_ids[random_pair + 1]
    loss_i, loss_j = losses[i], losses[j]

    proba = np.exp(coeff * (loss_i - loss_j) * (beta_i - beta_j))
    if np.random.uniform() < proba:
        swap_success = 1
        # Exchange only the noise values; the replicas keep their own model,
        # placeholders and ops.
        train_dict[i]['noise_val'] = temp_j
        train_dict[j]['noise_val'] = temp_i
    else:
        swap_success = 0
    logs['swap_success'] += swap_success
    logs['swap_attempts'] += 1
    


def train(train_dict,
          data,
          batch_size=32,
          epochs=32,
          swap_step=100,
          to_swap='lr'):
    (x_train, y_train, x_test, y_test, x_valid, y_valid) = data
    