In [1]:
import tensorflow as tf
from tensorflow.python.summary.writer.writer import FileWriter
from sklearn.model_selection import train_test_split
import numpy as np

from models import build_simple_conv_net
tf.compat.v1.reset_default_graph()

N_REPLICAS = 6
N_OUTPUTS = 4
LEARNING_RATE = 0.01
BATCH_SIZE = 64
EPOCHS = 10
N_GPUS = 3
train_dict = {}

def gpu_device_name(replica_id):
    if N_GPUS:
        return '/gpu:' + str(replica_id % N_GPUS)
    return '/cpu:0'

with tf.name_scope('inputs'):
    inputs = tf.keras.layers.Input(shape=(28, 28, 1), dtype=tf.float32)

with tf.name_scope('targets'):
    targets = tf.keras.layers.Input(shape=(), dtype=tf.int32)

for i in range(N_REPLICAS):
    dropout_rate = tf.compat.v1.placeholder_with_default(0., shape=())

    train_dict[i] = {'dropout_rate': dropout_rate}
    with tf.device(gpu_device_name(i)):
        with tf.name_scope('model_' + str(i)):
            train_dict[i]['model'] = build_simple_conv_net(inputs,
                                                           dropout_rate,
                                                           n_outputs=N_OUTPUTS)
        with tf.name_scope('loss_' + str(i)):
            xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=targets, logits=train_dict[i]['model'].outputs[0])
            train_dict[i]['loss'] = tf.reduce_mean(xentropy)

        with tf.name_scope('optimizer_' + str(i)):
            grads = tf.gradients(train_dict[i]['loss'],
                                 train_dict[i]['model'].trainable_variables)
            grads_and_vars = zip(grads, train_dict[i]['model'].trainable_variables)

            train_ops = [w.assign(w - LEARNING_RATE * g) for g, w in grads_and_vars]
            train_dict[i]['train_op'] = tf.group(train_ops)

        with tf.name_scope('error_' + str(i)):
            y_pred = tf.argmax(train_dict[i]['model'].outputs[0], axis=1)
            y_pred = tf.cast(y_pred, tf.int32)
            y_true = tf.cast(targets, tf.int32)
            equals = tf.cast(tf.math.equal(y_pred, y_true), tf.float32)
            train_dict[i]['error'] = 1. - tf.reduce_mean(equals)


FileWriter('logs/train', graph=inputs.graph).close()


W1106 17:48:09.891010 4548449728 deprecation.py:506] From /Users/vpushkarov/anaconda3/envs/tf14/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [2]:
dropout_list = np.linspace(0., 0.5, N_REPLICAS)
for i, val in enumerate(dropout_list):
    train_dict[i]['current_temperature'] = val

    
def create_empty_logs_dict():
    logs = {'error_' + str(i): [] for i in range(N_REPLICAS)}
    logs.update({'loss_' + str(i): [] for i in range(N_REPLICAS)})
    logs.update({'temperature_' + str(i): [] for i in range(N_REPLICAS)})
    logs['swap_success'] = 0
    logs['swap_attempts'] = 0
    return logs

def append_to_log_dict(from_dict, to_dict):
    
    for metric, values in from_dict.items():
        if not isinstance(values, list):
            to_dict[metric] += values
        elif 'temperature_' not in metric:
            to_dict[metric].append(np.mean(values))
        else:
            to_dict[metric] += values
    return to_dict
    
def iteritems(items, batch_size):
    start_idx = 0
    while start_idx < items.shape[0]:
        yield items[start_idx: start_idx + batch_size]
        start_idx += batch_size



def dataset_splits():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train[..., None] / 255., x_test[..., None] / 255.
    train_indices = np.argwhere(y_train < N_OUTPUTS).squeeze()
    test_indices = np.argwhere(y_test < N_OUTPUTS).squeeze()
    x_train, y_train = x_train[train_indices], y_train[train_indices]

    x_test, y_test = x_test[test_indices], y_test[test_indices]
    x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train)
    return (x_train, y_train), (x_test, y_test), (x_valid, y_valid)
(x_train, y_train), (x_test, y_test), (x_valid, y_valid) = dataset_splits()

datasets_dict = {
    'x_train': x_train,
    'y_train': y_train,
    'x_test': x_test,
    'y_test': y_test,
    'x_valid': x_valid,
    'y_valid': y_valid
}

print('train data size:', x_train.shape[0])
print('test data size:', x_test.shape[0])
print('validation data size', x_valid.shape[0])

    


train data size: 18565
test data size: 4157
validation data size 6189


In [3]:
def maybe_swap_replicas(train_dict, logs, coeff=1000.):
    def _compute_losses():
        data = datasets_dict['x_valid'], datasets_dict['y_valid']
        logs = evaluate_on_epoch(train_dict, data)
        losses = [logs['loss_' + str(i)] for i in range(N_REPLICAS)]
        return losses
    
    losses = _compute_losses()
    temperatures = [train_dict[i]['current_temperature'] for i in range(N_REPLICAS)]
    
    betas_ids_temperatures = [((1. - t) / t, review_id, t)
                              for review_id, t in enumerate(temperatures)]
    betas_ids_temperatures.sort(key=lambda x: x[0])
    
    picked_pair = np.random.randint(low=0, high=N_REPLICAS - 1)

    replica_i = betas_ids_temperatures[picked_pair][1]
    replica_j = betas_ids_temperatures[picked_pair + 1][1]

    temperature_i = betas_ids_temperatures[picked_pair][2]
    temperature_j = betas_ids_temperatures[picked_pair + 1][2]
    
    beta_i = betas_ids_temperatures[picked_pair][0]
    beta_j = betas_ids_temperatures[picked_pair + 1][0]
    
    loss_i = losses[replica_i], 
    loss_j = losses[replica_j]

    proba = np.exp(coeff * (loss_i - loss_j) * (beta_i - beta_j))
    if np.random.uniform() < proba:
        swap_success = 1
        train_dict[replica_i]['current_temperature'] = temperature_j
        train_dict[replica_j]['current_temperature'] = temperature_i
    else:
        swap_success = 0
    logs['swap_success'] += swap_success
    logs['swap_attempts'] += 1
    logs['validation']
    return logs

In [4]:
def train_on_batch(train_dict, data):

    x_data, y_data = data
    feed_dict = {inputs: x_data,
                 targets: y_data}
    feed_dict.update({train_dict[i]['dropout_rate']: train_dict[i]['current_temperature']
                      for i in range(N_REPLICAS)})
    train_ops = [train_dict[i]['train_op'] for i in range(N_REPLICAS)]
    losses = [train_dict[i]['loss'] for i in range(N_REPLICAS)]
    errors = [train_dict[i]['error'] for i in range(N_REPLICAS)]
    sess = tf.compat.v1.get_default_session()
    evaled = sess.run(losses + errors + train_ops, feed_dict=feed_dict)
    losses = evaled[:N_REPLICAS]
    errors = evaled[N_REPLICAS: 2 * N_REPLICAS]

    return losses, errors

def evaluate_on_batch(train_dict, data):
    x_data, y_data = data
    feed_dict = {inputs: x_data,
                 targets: y_data}
    feed_dict.update({train_dict[i]['dropout_rate']: train_dict[i]['current_temperature']
                      for i in range(N_REPLICAS)})
    losses = [train_dict[i]['loss'] for i in range(N_REPLICAS)]
    errors = [train_dict[i]['error'] for i in range(N_REPLICAS)]
    sess = tf.compat.v1.get_default_session()
    evaled = sess.run(losses + errors, feed_dict=feed_dict)
    losses = evaled[:N_REPLICAS]
    errors = evaled[N_REPLICAS: 2 * N_REPLICAS]
    return losses, errors

def train_on_epoch(train_dict, data, step, swap_step):
    (x_train, y_train) = datasets_dict['x_train'], datasets_dict['y_train']
    
    zipped_batches = zip(iteritems(x_train, BATCH_SIZE),
                         iteritems(y_train, BATCH_SIZE))
    logs = create_empty_logs_dict()

    for (x_batch, y_batch) in zipped_batches:
        losses, errors = train_on_batch(train_dict,
                                        (x_batch, y_batch))

        for i in range(N_REPLICAS):
            logs['error_' + str(i)].append(errors[i])
        for i in range(N_REPLICAS):
            logs['loss_' + str(i)].append(losses[i])
        for i in range(N_REPLICAS):
            logs['temperature_' + str(i)].append(train_dict[i]['current_temperature'])
        
        step += 1
        if step % swap_step == 0:
            logs = maybe_swap_replicas(train_dict, logs)
    return logs, step

def evaluate_on_epoch(train_dict, data):
    (x_train, y_train) = data
    
    zipped_batches = zip(iteritems(x_train, BATCH_SIZE),
                         iteritems(y_train, BATCH_SIZE))
    logs = create_empty_logs_dict()
    for (x_batch, y_batch) in zipped_batches:
        losses, errors = train_on_batch(train_dict,
                                        (x_batch, y_batch))
        for i in range(N_REPLICAS):
            logs['error_' + str(i)].append(errors[i])
        for i in range(N_REPLICAS):
            logs['loss_' + str(i)].append(losses[i])

    for metric, values in logs.items():
        if isinstance(values, list):
            logs[metric] = np.mean(values)

    return logs

In [5]:
def train(train_dict,
          data,
          swap_step=100):
    (x_train, y_train, x_test, y_test, x_valid, y_valid) = data

    current_step = 0
    all_logs = {
        'train': create_empty_logs_dict(),
        'test': create_empty_logs_dict(),
        'validation': create_empty_logs_dict()
    }

    graph = train_dict[0]['model'].inputs[0].graph
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(graph=graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(EPOCHS):
            logs, step = train_on_epoch(train_dict,
                                        (x_train, y_train),
                                        current_step,
                                        swap_step)
            all_logs['train'] = append_to_log_dict(logs, all_logs['train'])

            logs = evaluate_on_epoch(train_dict,
                                     (x_test, y_test))
            all_logs['test'] = append_to_log_dict(logs, all_logs['test'])

            logs = evaluate_on_epoch(train_dict,
                                     (x_valid, y_valid))
            all_logs['validation'] = append_to_log_dict(logs, all_logs['validation'])

            current_step += step
    return all_logs

(x_train, y_train), (x_test, y_test), (x_valid, y_valid) = dataset_splits()

logs = train(train_dict,
             (x_train, y_train, x_test, y_test, x_valid, y_valid),
             swap_step=100)


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  if sys.path[0] == '':


KeyError: 'validation'