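# Gaussian Mixture Variational Autoencoder (GMVAE) in TensorFlow

Unsupervised clustering of MNIST: a VAE whose latent prior is a mixture of 10 Gaussians, one per category y. Clustering accuracy on the test set is reported as `t_acc` during training.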

In [1]:
import numpy as np
from scipy.stats import mode
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.layers import xavier_initializer
import os.path

slim = tf.contrib.slim

In [2]:
# Load MNIST from MNIST_data/ with one-hot labels.
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [15]:
def log_bernoulli_with_logits(x, logits, eps=0.0, axis=-1):
    """Log-likelihood of targets ``x`` under Bernoulli(sigmoid(logits)).

    Args:
        x: targets in [0, 1] (the binarized pixels in this notebook).
        logits: unnormalized log-odds, same shape as ``x``.
        eps: when positive, logits are clipped to +/- log((1 - eps) / eps),
            i.e. the implied probabilities are kept inside [eps, 1 - eps].
        axis: axis that is summed over.

    Returns:
        Summed log-likelihood (the negated sigmoid cross-entropy).
    """
    clipped = logits
    if eps > 0.0:
        bound = np.log(1.0 - eps) - np.log(eps)
        clipped = tf.clip_by_value(clipped, -bound, bound,
                                   name='clipped_logit')
    xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=clipped, labels=x)
    return -tf.reduce_sum(xent, axis)

def log_normal(x, mu, var, eps=0.0, axis=-1):
    """Log-density of ``x`` under a diagonal Gaussian N(mu, var).

    Args:
        x: sample tensor.
        mu: mean, broadcastable against ``x``.
        var: VARIANCE (not log-variance), broadcastable against ``x``.
        eps: when positive, added to ``var`` as a numerical floor.
        axis: axis summed over (treated as independent dimensions).

    Returns:
        Summed log-density.
    """
    if eps > 0.0:
        var = tf.add(var, eps, name='clipped_var')
    per_dim = tf.log(2 * np.pi) + tf.log(var) + tf.square(x - mu) / var
    return -0.5 * tf.reduce_sum(per_dim, axis)

In [16]:
def qy_graph(x, k=10):
    """Inference network q(y|x): categorical posterior over k mixture components.

    Two 512-unit ReLU layers followed by a linear layer of width k. Variables
    live in the 'fc1', 'fc2' and 'logit' scopes and are reused on repeated
    calls (tf.AUTO_REUSE).

    Args:
        x: input batch; in this notebook the binarized (batch, 784) images.
        k: number of categories.

    Returns:
        (qy_logit, qy): unnormalized logits and their softmax probabilities,
        each of width k.
    """
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=xavier_initializer(),
                        biases_initializer=tf.zeros_initializer(),
                        reuse=tf.AUTO_REUSE):
        hidden = slim.fully_connected(x, 512, scope='fc1')
        hidden = slim.fully_connected(hidden, 512, scope='fc2')
        # TODO: add an sbp_dropout layer here (left as a note in the original).
        qy_logit = slim.fully_connected(hidden, k, activation_fn=None, scope='logit')
        qy = tf.nn.softmax(qy_logit, name='prob')
    return qy_logit, qy

def qz_graph(x, y):
    """Inference network q(z|x, y) with a reparameterized sample.

    Concatenates ``x`` with the one-hot ``y``, applies two 512-unit ReLU layers
    and a linear 128-unit layer that is split into a 64-d mean and a 64-d
    log-variance. Variables live in 'fc3'..'fc5' and are shared across calls
    (tf.AUTO_REUSE), so the same network serves every category y.

    Args:
        x: input batch (the binarized images in this notebook).
        y: one-hot category batch with the same number of rows as ``x``.

    Returns:
        (z, mu, logvar): z = mu + exp(logvar / 2) * epsilon with
        epsilon ~ N(0, I), the mean, and the LOG-variance (callers must
        exponentiate it to obtain a variance).
    """
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=xavier_initializer(),
                        biases_initializer=tf.zeros_initializer(),
                        reuse=tf.AUTO_REUSE):
        hidden = tf.concat([x, y], 1, name='xy/concat')
        hidden = slim.fully_connected(hidden, 512, scope='fc3')
        hidden = slim.fully_connected(hidden, 512, scope='fc4')
        mu_logvar = slim.fully_connected(hidden, 128, activation_fn=None, scope='fc5')
        mu, logvar = tf.split(mu_logvar, num_or_size_splits=2, axis=1)
        # exp(0.5 * logvar) equals sqrt(exp(logvar)) but never forms
        # exp(logvar), which overflows float32 at half the logvar value.
        stddev = tf.exp(0.5 * logvar)

        # Reparameterization trick: draw z from N(mu, stddev^2).
        epsilon = tf.random_normal(tf.shape(stddev))
        z = mu + stddev * epsilon
    return z, mu, logvar

In [17]:
def decoder(z, y):
    """Generative networks: the prior p(z|y) and the likelihood p(x|z).

    Args:
        z: latent sample batch.
        y: one-hot category batch.

    Returns:
        (mu, logvar, x_logit): prior mean and LOG-variance of z given y
        (64-d each, split from a 128-unit layer), and the 784 Bernoulli
        logits of x given z.
    """
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=xavier_initializer(),
                        biases_initializer=tf.zeros_initializer(),
                        reuse=tf.AUTO_REUSE):
        # Prior p(z|y): one linear map from y to [mean, log-variance].
        prior_params = slim.fully_connected(y, 128, activation_fn=None, scope='de_1')
        mu, logvar = tf.split(prior_params, num_or_size_splits=2, axis=1)
        # Likelihood p(x|z): two ReLU layers, then linear per-pixel logits.
        hidden = slim.fully_connected(z, 512, scope='defc1')
        hidden = slim.fully_connected(hidden, 512, scope='defc2')
        x_logit = slim.fully_connected(hidden, 784, activation_fn=None, scope='defc3')
    return mu, logvar, x_logit

In [18]:
def labeled_loss(x, px_logit, z, zm, zv, zm_prior, zv_prior, k=10):
    """Per-example negative ELBO for one fixed category y.

    Computes, from a single sample z,
        -log p(x|z) + log q(z|x,y) - log p(z|y) - log p(y)
    with a uniform categorical prior p(y) = 1/k.

    Args:
        x: binary targets.
        px_logit: Bernoulli logits of p(x|z).
        z: sample drawn from q(z|x,y).
        zm, zv: mean and VARIANCE of q(z|x,y).
        zm_prior, zv_prior: mean and VARIANCE of p(z|y).
        k: number of categories of the uniform prior; the default 10
            reproduces the original hard-coded -log(0.1).

    Returns:
        Loss tensor summed over the last axis (one value per example).
    """
    xy_loss = -log_bernoulli_with_logits(x, px_logit)
    xy_loss += log_normal(z, zm, zv) - log_normal(z, zm_prior, zv_prior)
    # -log p(y) = log(k) for a uniform prior over k categories.
    return xy_loss + np.log(k)

In [19]:
def test_acc(mnist, sess, qy_logit):
    logits = sess.run(qy_logit, feed_dict={'x:0': mnist.test.images})
    cat_pred = logits.argmax(1)
    real_pred = np.zeros_like(cat_pred)
    for cat in range(logits.shape[1]):
        idx = cat_pred == cat
        lab = mnist.test.labels.argmax(1)[idx]
        if len(lab) == 0:
            continue
        real_pred[cat_pred == cat] = mode(lab).mode[0]
    return np.mean(real_pred == mnist.test.labels.argmax(1))

In [20]:
def stream_print(f, string, pipe_to_file=True):
    """Print ``string`` and, optionally, append it as one line to ``f``.

    Args:
        f: writable text file object, or None to only print.
        string: text to emit; a newline is appended when writing to ``f``.
        pipe_to_file: when False the string is printed but not written.
    """
    print(string)
    should_write = pipe_to_file and f is not None
    if not should_write:
        return
    f.write(string + '\n')
    # Flush so the log is readable while training is still running.
    f.flush()

In [21]:
def open_file(fname):
    """Open a fresh, numbered log file ``<fname>.<i>`` for writing.

    The first ``i`` (starting at 0) whose file does not yet exist is used.
    Missing parent directories are created. Exclusive-create mode ('x') means
    an existing file is never truncated, even if another process creates it
    between attempts.

    Args:
        fname: base path, or None to disable logging.

    Returns:
        An open text file object, or None when ``fname`` is None.
    """
    if fname is None:
        return None
    parent = os.path.dirname(fname)
    if parent:
        os.makedirs(parent, exist_ok=True)
    i = 0
    while True:
        candidate = '{:s}.{:d}'.format(fname, i)
        try:
            return open(candidate, 'x')
        except FileExistsError:
            i += 1

In [22]:
def train(log_file, data, sess_info, epochs, iterep=500, batch_size=100,
          n_train_eval=10000):
    """Train the model and report metrics once every ``iterep`` steps.

    Each report prints, and logs via ``stream_print``:
    - train-subset entropy and loss;
    - test entropy and loss;
    - clustering accuracy (``test_acc``);
    - the period index.

    Args:
        log_file: base path for the metrics log (``open_file`` appends a
            numeric suffix), or None to only print.
        data: MNIST-style dataset with ``train`` and ``test`` splits; used for
            every batch and evaluation.
        sess_info: tuple (sess, qy_logit, nent, loss, train_step).
        epochs: number of reporting periods to run.
        iterep: optimizer steps per reporting period.
        batch_size: training minibatch size.
        n_train_eval: number of training images (sampled with replacement)
            used for the train-set metrics.
    """
    sess, qy_logit, nent, loss, train_step = sess_info
    header = ('{:>10s},{:>10s},{:>10s},{:>10s},{:>10s},{:>10s}'
              .format('tr_ent', 'tr_loss', 't_ent', 't_loss', 't_acc', 'epoch'))
    row_fmt = '{:10.2e},{:10.2e},{:10.2e},{:10.2e},{:10.2e},{:10d}'
    # Sample evaluation indices from the whole training split instead of a
    # hard-coded 50000. With read_data_sets defaults, presumably 55000 images.
    n_train = len(data.train.images)
    f = open_file(log_file)
    try:
        for i in range(iterep * epochs):
            batch = data.train.next_batch(batch_size)[0]
            sess.run(train_step, feed_dict={'x:0': batch})
            if (i + 1) % iterep != 0:
                continue
            train_subset = data.train.images[np.random.choice(n_train, n_train_eval)]
            tr_nent, tr_loss = sess.run([nent, loss], feed_dict={'x:0': train_subset})
            t_nent, t_loss = sess.run([nent, loss], feed_dict={'x:0': data.test.images})
            acc = test_acc(data, sess, qy_logit)
            # The header is echoed to stdout every period but written to the
            # log file only for the first period.
            stream_print(f, header, i + 1 == iterep)
            # nent is the NEGATIVE entropy, so negate it to report entropy.
            stream_print(f, row_fmt.format(-tr_nent.mean(), tr_loss.mean(),
                                           -t_nent.mean(), t_loss.mean(),
                                           acc, (i + 1) // iterep))
    finally:
        # Close the log even when training is interrupted (e.g. KeyboardInterrupt).
        if f is not None:
            f.close()

In [23]:
# Start a fresh graph for this model. Seed TensorFlow's graph-level RNG
# (initializers, binarization and reparameterization noise) and NumPy's
# global RNG so reruns are more reproducible; GPU kernels may still add
# nondeterminism.
SEED = 0
tf.reset_default_graph()
tf.set_random_seed(SEED)
np.random.seed(SEED)
x = tf.placeholder(tf.float32, [None, 784], name='x')

In [24]:
# Stochastic binarization: assuming intensities in [0, 1], each pixel becomes 1
# with probability equal to its intensity, resampled on every run.
with tf.name_scope('x_binarized'):
    uniform_noise = tf.random_uniform(tf.shape(x), 0, 1)
    xb = tf.cast(tf.greater(x, uniform_noise), tf.float32)
# All-zero (batch, 10) tensor. Adding a one-hot constant to it broadcasts a
# fixed category label across the batch.
with tf.name_scope('y_'):
    n_batch = tf.shape(x)[0]
    y_ = tf.fill(tf.stack([n_batch, 10]), 0.0)

In [25]:
qy_logit, qy = qy_graph(xb, k=10)  # q(y|x) over the 10 mixture components

In [26]:
# Build q(z|x,y), p(z|y) and p(x|z) once per category. The networks share
# weights across categories (AUTO_REUSE); only the one-hot y differs.
z, zm, zv, zm_prior, zv_prior, px_logit = ([] for _ in range(6))
for i in range(10):
    with tf.name_scope('graphs/hot_at{:d}'.format(i)):
        onehot = tf.constant(np.eye(10)[i], dtype=tf.float32, name='hot_at_{:d}'.format(i))
        y = tf.add(y_, onehot)
        q_out = qz_graph(xb, y)
        p_out = decoder(q_out[0], y)
    for store, value in zip((z, zm, zv, zm_prior, zv_prior, px_logit), q_out + p_out):
        store.append(value)

In [27]:
# loss = sum_y q(y|x) * L(x, y) + sum_y q(y|x) log q(y|x): the negative ELBO
# with the categorical y summed out over all 10 categories.
with tf.name_scope('loss'):
    with tf.name_scope('neg_entropy'):
        log_qy = tf.nn.log_softmax(qy_logit)
        nent = tf.reduce_sum(qy * log_qy, 1)
    losses = []
    for i in range(10):
        with tf.name_scope('loss_at{:d}'.format(i)):
            # The networks emit log-variances; log_normal expects variances.
            q_var = tf.exp(zv[i])
            p_var = tf.exp(zv_prior[i])
            losses.append(labeled_loss(xb, px_logit[i], z[i], zm[i], q_var,
                                       zm_prior[i], p_var))
    with tf.name_scope('final_loss'):
        weighted = [qy[:, c] * term for c, term in enumerate(losses)]
        loss = tf.add_n([nent] + weighted)

In [28]:
# Adam with its TF1 default learning rate, then initialize all variables.
train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [30]:
# 1000 reporting periods of 500 optimizer steps each; metrics are also
# written to log/gmvae.log.<n>.
sess_info = (sess, qy_logit, nent, loss, train_step)
train(log_file='log/gmvae.log', data=mnist, sess_info=sess_info, epochs=1000)

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.13e+00,  1.26e+02,  2.13e+00,  1.24e+02,  3.88e-01,         1
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.09e+00,  1.12e+02,  2.09e+00,  1.11e+02,  3.64e-01,         2
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.97e+00,  1.07e+02,  1.97e+00,  1.06e+02,  3.78e-01,         3
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.93e+00,  1.04e+02,  1.92e+00,  1.03e+02,  3.87e-01,         4
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.81e+00,  1.02e+02,  1.81e+00,  1.02e+02,  3.69e-01,         5
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.81e+00,  1.01e+02,  1.80e+00,  1.01e+02,  3.51e-01,         6
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.74e+00,  9.96e+01,  1.73e+00,  9.97e+01,  3.83e-01,         7
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.70e+00

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.23e-01,  9.23e+01,  2.22e-01,  9.39e+01,  7.49e-01,        64
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.08e-01,  9.24e+01,  2.17e-01,  9.41e+01,  7.64e-01,        65
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.01e-01,  9.27e+01,  2.05e-01,  9.41e+01,  7.57e-01,        66
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.08e-01,  9.22e+01,  2.14e-01,  9.40e+01,  7.69e-01,        67
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.93e-01,  9.25e+01,  2.04e-01,  9.43e+01,  7.67e-01,        68
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.11e-01,  9.23e+01,  2.19e-01,  9.42e+01,  7.70e-01,        69
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.89e-01,  9.21e+01,  1.96e-01,  9.39e+01,  7.67e-01,        70
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.97e-01

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.87e-02,  9.09e+01,  1.03e-01,  9.30e+01,  8.55e-01,       127
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.18e-02,  9.13e+01,  9.61e-02,  9.34e+01,  8.52e-01,       128
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.48e-02,  9.14e+01,  9.70e-02,  9.34e+01,  8.53e-01,       129
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.99e-02,  9.10e+01,  1.02e-01,  9.37e+01,  8.52e-01,       130
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.67e-02,  9.05e+01,  1.02e-01,  9.34e+01,  8.55e-01,       131
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.25e-02,  9.10e+01,  9.75e-02,  9.32e+01,  8.57e-01,       132
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.77e-02,  9.04e+01,  1.04e-01,  9.33e+01,  8.56e-01,       133
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  8.84e-02

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.32e-02,  8.99e+01,  6.59e-02,  9.29e+01,  8.88e-01,       190
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.31e-02,  9.06e+01,  6.53e-02,  9.31e+01,  8.88e-01,       191
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.66e-02,  9.04e+01,  7.00e-02,  9.32e+01,  8.84e-01,       192
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.45e-02,  9.04e+01,  6.87e-02,  9.29e+01,  8.87e-01,       193
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  5.54e-02,  9.04e+01,  6.00e-02,  9.29e+01,  8.89e-01,       194
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.03e-02,  8.99e+01,  6.43e-02,  9.30e+01,  8.87e-01,       195
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.42e-02,  9.01e+01,  6.66e-02,  9.30e+01,  8.89e-01,       196
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.06e-02

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.85e-02,  8.99e+01,  4.46e-02,  9.28e+01,  9.08e-01,       253
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.72e-02,  8.99e+01,  4.19e-02,  9.29e+01,  9.04e-01,       254
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.57e-02,  8.97e+01,  3.92e-02,  9.28e+01,  9.02e-01,       255
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.72e-02,  8.95e+01,  4.12e-02,  9.27e+01,  9.05e-01,       256
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.75e-02,  8.95e+01,  3.68e-02,  9.27e+01,  9.09e-01,       257
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.89e-02,  8.93e+01,  4.11e-02,  9.26e+01,  9.05e-01,       258
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.62e-02,  8.98e+01,  3.93e-02,  9.30e+01,  9.06e-01,       259
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.46e-02

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.63e-02,  8.94e+01,  2.90e-02,  9.27e+01,  9.14e-01,       316
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.62e-02,  8.96e+01,  2.82e-02,  9.30e+01,  9.22e-01,       317
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.42e-02,  8.94e+01,  3.14e-02,  9.25e+01,  9.21e-01,       318
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.52e-02,  8.94e+01,  2.91e-02,  9.25e+01,  9.17e-01,       319
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.75e-02,  8.96e+01,  2.97e-02,  9.26e+01,  9.20e-01,       320
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.67e-02,  8.92e+01,  3.08e-02,  9.25e+01,  9.20e-01,       321
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.56e-02,  8.93e+01,  2.72e-02,  9.25e+01,  9.20e-01,       322
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.93e-02

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.50e-02,  8.93e+01,  1.79e-02,  9.25e+01,  9.29e-01,       379
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.34e-02,  8.91e+01,  1.84e-02,  9.25e+01,  9.30e-01,       380
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.34e-02,  8.96e+01,  1.63e-02,  9.26e+01,  9.33e-01,       381
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.36e-02,  8.88e+01,  1.68e-02,  9.25e+01,  9.28e-01,       382
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.20e-02,  8.88e+01,  1.59e-02,  9.24e+01,  9.30e-01,       383
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.53e-02,  8.90e+01,  1.75e-02,  9.26e+01,  9.33e-01,       384
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.31e-02,  8.93e+01,  1.71e-02,  9.23e+01,  9.31e-01,       385
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.41e-02

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  8.52e-03,  8.94e+01,  9.64e-03,  9.25e+01,  9.34e-01,       442
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.35e-03,  8.88e+01,  9.58e-03,  9.23e+01,  9.38e-01,       443
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.92e-03,  8.91e+01,  9.86e-03,  9.26e+01,  9.37e-01,       444
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.40e-03,  8.89e+01,  1.03e-02,  9.24e+01,  9.35e-01,       445
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.31e-03,  8.91e+01,  9.54e-03,  9.24e+01,  9.38e-01,       446
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.52e-03,  8.88e+01,  8.93e-03,  9.22e+01,  9.39e-01,       447
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.87e-03,  8.90e+01,  9.88e-03,  9.25e+01,  9.34e-01,       448
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.41e-03

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  4.26e-03,  8.91e+01,  5.00e-03,  9.24e+01,  9.39e-01,       505
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  4.41e-03,  8.94e+01,  5.91e-03,  9.26e+01,  9.43e-01,       506
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  4.22e-03,  8.86e+01,  6.71e-03,  9.24e+01,  9.41e-01,       507
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.90e-03,  8.87e+01,  5.59e-03,  9.24e+01,  9.38e-01,       508
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  4.44e-03,  8.88e+01,  6.07e-03,  9.24e+01,  9.41e-01,       509
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.64e-03,  8.92e+01,  5.71e-03,  9.24e+01,  9.42e-01,       510
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.87e-03,  8.89e+01,  6.07e-03,  9.23e+01,  9.41e-01,       511
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.68e-03

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.63e-03,  8.91e+01,  3.50e-03,  9.22e+01,  9.39e-01,       568
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.75e-03,  8.91e+01,  3.87e-03,  9.24e+01,  9.38e-01,       569
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.69e-03,  8.88e+01,  3.72e-03,  9.21e+01,  9.38e-01,       570
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.11e-03,  8.88e+01,  3.91e-03,  9.22e+01,  9.42e-01,       571
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.75e-03,  8.90e+01,  3.59e-03,  9.23e+01,  9.40e-01,       572
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.63e-03,  8.86e+01,  3.43e-03,  9.23e+01,  9.42e-01,       573
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.54e-03,  8.89e+01,  3.38e-03,  9.22e+01,  9.42e-01,       574
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  2.31e-03

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.24e-03,  8.85e+01,  2.77e-03,  9.23e+01,  9.44e-01,       631
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.69e-03,  8.90e+01,  2.89e-03,  9.24e+01,  9.43e-01,       632
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.87e-03,  8.86e+01,  2.90e-03,  9.25e+01,  9.43e-01,       633
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.37e-03,  8.84e+01,  2.36e-03,  9.22e+01,  9.44e-01,       634
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.80e-03,  8.89e+01,  2.75e-03,  9.24e+01,  9.44e-01,       635
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.77e-03,  8.90e+01,  3.14e-03,  9.25e+01,  9.46e-01,       636
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.45e-03,  8.89e+01,  2.53e-03,  9.23e+01,  9.47e-01,       637
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.87e-03

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.14e-03,  8.86e+01,  1.47e-03,  9.23e+01,  9.45e-01,       694
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.07e-03,  8.84e+01,  2.15e-03,  9.24e+01,  9.44e-01,       695
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.36e-03,  8.86e+01,  1.56e-03,  9.23e+01,  9.41e-01,       696
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.29e-03,  8.87e+01,  1.75e-03,  9.25e+01,  9.47e-01,       697
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.80e-04,  8.88e+01,  2.15e-03,  9.25e+01,  9.46e-01,       698
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.14e-03,  8.90e+01,  1.53e-03,  9.26e+01,  9.48e-01,       699
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.03e-03,  8.85e+01,  1.82e-03,  9.23e+01,  9.48e-01,       700
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.12e-03

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.15e-03,  8.86e+01,  1.53e-03,  9.23e+01,  9.49e-01,       757
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.11e-03,  8.91e+01,  1.66e-03,  9.24e+01,  9.47e-01,       758
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.22e-03,  8.88e+01,  2.25e-03,  9.24e+01,  9.46e-01,       759
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.01e-03,  8.81e+01,  1.96e-03,  9.21e+01,  9.46e-01,       760
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.26e-03,  8.88e+01,  1.74e-03,  9.25e+01,  9.46e-01,       761
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.05e-03,  8.86e+01,  1.67e-03,  9.24e+01,  9.48e-01,       762
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.45e-04,  8.89e+01,  1.26e-03,  9.24e+01,  9.46e-01,       763
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.31e-03

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.07e-03,  8.83e+01,  1.60e-03,  9.23e+01,  9.49e-01,       820
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.14e-03,  8.88e+01,  1.61e-03,  9.23e+01,  9.48e-01,       821
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  1.11e-03,  8.83e+01,  1.21e-03,  9.22e+01,  9.49e-01,       822
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  8.00e-04,  8.80e+01,  1.18e-03,  9.23e+01,  9.45e-01,       823
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.03e-04,  8.90e+01,  1.42e-03,  9.24e+01,  9.49e-01,       824
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.35e-04,  8.82e+01,  1.60e-03,  9.23e+01,  9.48e-01,       825
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  5.94e-04,  8.81e+01,  1.41e-03,  9.24e+01,  9.47e-01,       826
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.87e-04

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.81e-04,  8.89e+01,  1.25e-03,  9.25e+01,  9.49e-01,       883
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.32e-04,  8.81e+01,  1.22e-03,  9.21e+01,  9.49e-01,       884
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.53e-04,  8.81e+01,  9.77e-04,  9.22e+01,  9.49e-01,       885
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  3.85e-04,  8.87e+01,  1.41e-03,  9.22e+01,  9.46e-01,       886
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  8.47e-04,  8.85e+01,  1.42e-03,  9.22e+01,  9.48e-01,       887
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  7.22e-04,  8.84e+01,  9.81e-04,  9.24e+01,  9.48e-01,       888
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.54e-04,  8.82e+01,  9.32e-04,  9.21e+01,  9.49e-01,       889
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  9.21e-04

    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.21e-04,  8.88e+01,  7.99e-04,  9.22e+01,  9.49e-01,       946
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.17e-04,  8.82e+01,  1.12e-03,  9.22e+01,  9.49e-01,       947
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  6.54e-04,  8.84e+01,  6.10e-04,  9.25e+01,  9.48e-01,       948
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  4.68e-04,  8.84e+01,  8.64e-04,  9.22e+01,  9.51e-01,       949
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  8.30e-04,  8.84e+01,  7.81e-04,  9.23e+01,  9.47e-01,       950
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  8.12e-04,  8.81e+01,  9.40e-04,  9.22e+01,  9.49e-01,       951
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  8.46e-04,  8.83e+01,  1.14e-03,  9.23e+01,  9.48e-01,       952
    tr_ent,   tr_loss,     t_ent,    t_loss,     t_acc,     epoch
  4.51e-04