In [1]:
import numpy as np
import tensorflow as tf
from tqdm import trange 
import logging
import os
import pandas as pd
%run "./utils.py"

params = {
    "max_query_words": 12,
    "max_passage_words": 50,
    "emb_dim": 50,
    "BATCH_SIZE": 500,
    "EPOCHS" : 200,  #Total number of epochs to run
    "BatchesPerEpoch": 1000,
    "num_classes": 2,
    "save_summary_steps": 100,
    "TEST_BATCH_SIZE": 1000,
    "SHUFFLE_BATCH_SIZE": 10000
}

  from ._conv import register_converters as _register_converters


In [2]:
def cnn_network(queryfeatures, passagefeatures,reuse = False):
    
    #queryfeatures = tf.placeholder(tf.float32,shape=(None,max_query_words * emb_dim))
    #passagefeatures = tf.placeholder(tf.float32,shape=(None,max_passage_words * emb_dim))
    
    #global max_passage_words,max_query_words,emb_dim,num_classes

    def conv2D(x,W,strides):
        return tf.nn.conv2d(x,W,strides=strides,padding='VALID',name="conv2D")

    def maxPooling(x,k):
        return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1],padding='VALID',name = "maxPool")

    def getWeight(shape):
        return tf.get_variable(name = "weight",shape = shape, initializer=tf.initializers.truncated_normal(stddev= 0.1))

    def getBias(shape):
        return tf.get_variable(name = "bias",shape = shape, initializer=tf.initializers.constant(0.1))

    with tf.variable_scope('query/convLayer1'):
        #------Filter: [filter_height, filter_width, in_channels, out_channels]
        with tf.variable_scope('shared',reuse=reuse):
            weight = getWeight([3,10,1,4])
            bias = getBias([4])
        #------Input: [batch, in_height, in_width, in_channels]
        #x_query = tf.reshape(queryfeatures,[-1,max_query_words,emb_dim,1], name="ReshapeInputOp")  #input: ?,12,50,1
        #------Stride [batch,height,width,channel]
        convQuery1 = tf.nn.relu(conv2D(queryfeatures,weight,[1,1,1,1]) + bias)
        convQuery1_max = maxPooling(convQuery1,k = 2)

    with tf.variable_scope('query/convLayer2'):
        with tf.variable_scope('shared',reuse=reuse):
            weight = getWeight([2,4,4,2])
            bias = getBias([2])
        convQuery2 = tf.nn.relu(conv2D(convQuery1_max,weight,[1,1,1,1]) + bias)
        convQuery2_max = maxPooling(convQuery2,k = 2)
        print(convQuery2_max)

    with tf.variable_scope('query/denseLayer3'):
        with tf.variable_scope('shared',reuse=reuse):
            weight = getWeight([2*8*2,10])
            bias = getBias([10])
        dense = tf.reshape(convQuery2_max,[-1,2*8*2])
        denseQuery3 = tf.nn.relu(tf.matmul(dense,weight) + bias)

    with tf.variable_scope('passage/convLayer1'):
        with tf.variable_scope('shared',reuse=reuse):
            weight = getWeight([5,10,1,4])
            bias = getBias([4])
        #x_passage = tf.reshape(passagefeatures,[-1,max_passage_words,emb_dim,1])
        convPassage1 = tf.nn.relu(conv2D(passagefeatures,weight,[1,1,1,1]) + bias)
        convPassage1_max = maxPooling(convPassage1,k = 2)


    with tf.variable_scope('passage/convLayer2'):
        with tf.variable_scope('shared',reuse=reuse):
            weight = getWeight([3,10,4,4])
            bias = getBias([4])
        convPassage2 = tf.nn.relu(conv2D(convPassage1_max,weight,[1,1,1,1]) + bias)
        convPassage2_max = maxPooling(convPassage2,k = 2)

    with tf.variable_scope('passage/denseLayer3'):
        with tf.variable_scope('shared',reuse=reuse):
            weight = getWeight([10*5*4,10])
            bias = getBias([10])
        densePassage = tf.reshape(convPassage2_max,[-1,10*5*4])
        densePassage3 = tf.nn.relu(tf.matmul(densePassage,weight) + bias)

    with tf.variable_scope('mergeQueryPassage'):
        with tf.variable_scope('shared',reuse=reuse):
            weight = getWeight([10,2])
            bias = getBias([2])
        
        mergeOp = tf.multiply(denseQuery3,densePassage3, name = "merge")
        mergeDense = tf.nn.relu(tf.matmul(mergeOp,weight) + bias)

    return mergeDense
    

In [3]:
def modelTest_fn(mode,embeddingsFile,params):    
    #--------Hyper parameters:
    max_query_words = params["max_query_words"]
    max_passage_words = params["max_passage_words"]
    emb_dim = params["emb_dim"]
    BATCH_SIZE = params["TEST_BATCH_SIZE"]
    EPOCHS = params["EPOCHS"]
    BatchesPerEpoch = params["BatchesPerEpoch"]
    num_classes = params["num_classes"]
    
    is_training = (mode == "train")
    
    def testDSParser(example_proto):
        features = {"query": tf.FixedLenFeature((max_query_words,emb_dim,1), tf.float32),
                  "passage": tf.FixedLenFeature((max_passage_words,emb_dim,1), tf.float32),
                  "query_id": tf.FixedLenFeature((1), tf.int64),
                  "passage_id": tf.FixedLenFeature((1), tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, features)
        return parsed_features["query"], parsed_features["passage"],parsed_features["query_id"],parsed_features["passage_id"]
    
    def getDatasetIterator(fileName,batch_size,mode):
        dataset = tf.data.TFRecordDataset(filenames = fileName, compression_type="ZLIB")
        #------Follow this order: map -> prefetch -> batch
        dataset = dataset.map(testDSParser)
        dataset = dataset.batch(batch_size)
        
        iterator = dataset.make_initializable_iterator()
        return iterator
    
    iterator = getDatasetIterator(embeddingsFile,BATCH_SIZE,mode)
    
    queryfeatures,passagefeatures,query_id,passage_id = iterator.get_next()
    
    model_spec =     {
        'queryfeatures': queryfeatures,
        'passagefeatures': passagefeatures,
        'iterator_init_op': iterator.initializer,
        "query_id":query_id,
        "passage_id": passage_id
    }
    
    with tf.variable_scope('model'):
        y_conv = cnn_network(queryfeatures,passagefeatures,reuse = False)
    
    # -----------------------------------------------------------
    # MODEL SPECIFICATION
    # Create the model specification and return it
    # It contains nodes or operations in the graph that will be used for training and evaluation
    variable_init_op = tf.group(*[tf.global_variables_initializer()])
    model_spec['variable_init_op'] = variable_init_op
    model_spec["predictions"] = y_conv
    model_spec['summary_op'] = tf.summary.merge_all()

    
    return model_spec

In [12]:
def model_fn(mode,embeddingsFile,params):    
    #--------Hyper parameters:
    max_query_words = params["max_query_words"]
    max_passage_words = params["max_passage_words"]
    emb_dim = params["emb_dim"]
    BATCH_SIZE = params["BATCH_SIZE"]
    EPOCHS = params["EPOCHS"]
    BatchesPerEpoch = params["BatchesPerEpoch"]
    num_classes = params["num_classes"]
    
    is_training = (mode == "train")
    
    def parser(example_proto):
        features = {"query": tf.FixedLenFeature((max_query_words,emb_dim,1), tf.float32),
                  "passage": tf.FixedLenFeature((max_passage_words,emb_dim,1), tf.float32),
                  "label": tf.FixedLenFeature((num_classes), tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, features)
        return parsed_features["query"], parsed_features["passage"],parsed_features["label"]
    
    def getDatasetIterator(fileName,batch_size,mode):
        buffer_size = 4 * batch_size
        dataset = tf.data.TFRecordDataset(filenames = fileName, compression_type="ZLIB").shuffle(buffer_size = buffer_size)
        #------Follow this order: map -> prefetch -> batch
        dataset = dataset.map(parser)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(1)
        iterator = dataset.make_initializable_iterator()
        return iterator
    
    #queryfeatures = tf.placeholder(tf.float32,shape=(None,max_query_words, emb_dim,1))
    #passagefeatures = tf.placeholder(tf.float32,shape=(None,max_passage_words,emb_dim,1))
    #y = tf.placeholder(tf.int16,shape = (None,num_classes))
    
    iterator = getDatasetIterator(embeddingsFile,BATCH_SIZE,mode)
    
    queryfeatures,passagefeatures,y = iterator.get_next()
    #print(queryfeatures)
    
    model_spec =     {
        'queryfeatures': queryfeatures,
        'passagefeatures': passagefeatures,
        'iterator_init_op': iterator.initializer,
        "y":y
    }
    
    with tf.variable_scope('model'):
        y_conv = cnn_network(queryfeatures,passagefeatures,reuse = not is_training)
    
    with tf.variable_scope('lossPerBatch'):
        cross_entropy = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=y,logits=y_conv))
        tf.summary.scalar('lossPerBatch', cross_entropy)

    if is_training:
        with tf.name_scope('AdamOptim'):
            global_step = tf.train.get_or_create_global_step()
            train_step = tf.train.AdamOptimizer(1e-4).minimize(loss = cross_entropy,global_step=global_step)

    correct_pred = tf.equal(tf.argmax(y_conv,1),tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
    tf.summary.scalar('accuracyPerBatch', accuracy)
    
    
    # -----------------------------------------------------------
    # METRICS AND SUMMARIES
    # Metrics for evaluation using tf.metrics (average over whole dataset)
    with tf.variable_scope("metrics"):
        metrics = {
            'accuracy': tf.metrics.accuracy(labels=tf.argmax(y,-1), predictions=tf.argmax(y_conv,-1)),
            'loss': tf.metrics.mean(cross_entropy)
        }
    
    # Group the update ops for the tf.metrics
    update_metrics_op = tf.group(*[op for _, op in metrics.values()])

    # Get the op to reset the local variables used in tf.metrics
    metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="metrics")
    metrics_init_op = tf.variables_initializer(metric_variables)
    
    
    # -----------------------------------------------------------
    # MODEL SPECIFICATION
    # Create the model specification and return it
    # It contains nodes or operations in the graph that will be used for training and evaluation
    variable_init_op = tf.group(*[tf.global_variables_initializer()])
    model_spec['variable_init_op'] = variable_init_op
    model_spec["predictions"] = y_conv
    model_spec['loss'] = cross_entropy
    model_spec['accuracy'] = accuracy
    model_spec['metrics_init_op'] = metrics_init_op
    model_spec['metrics'] = metrics
    model_spec['update_metrics'] = update_metrics_op
    model_spec['summary_op'] = tf.summary.merge_all()

    if is_training:
        model_spec['train_op'] = train_step
    
    return model_spec

In [19]:
def train_sess(sess, model_spec, num_steps, writer, params):
    """Train the model on `num_steps` batches

    Args:
        sess: (tf.Session) current session
        model_spec: (dict) contains the graph operations or nodes needed for training
        num_steps: (int) train for this number of batches
        writer: (tf.summary.FileWriter) writer for summaries
        params: (Params) hyperparameters
    """
    # Get relevant graph operations or nodes needed for training
    loss = model_spec['loss']
    train_op = model_spec['train_op']
    update_metrics = model_spec['update_metrics']
    metrics = model_spec['metrics']
    summary_op = model_spec['summary_op']
    y_conv = model_spec["predictions"]
    global_step = tf.train.get_global_step()

    # Load the training dataset into the pipeline and initialize the metrics local variables
    sess.run(model_spec['iterator_init_op'])
    sess.run(model_spec['metrics_init_op'])

    # Use tqdm for progress bar
    t = trange(num_steps)
    for i in t:
        # Evaluate summaries for tensorboard only once in a while
        if i % params["save_summary_steps"] == 0:
            # Perform a mini-batch update
            _, _, loss_val, summ, global_step_val = sess.run([train_op, update_metrics, loss,summary_op, global_step])
            # Write summaries for tensorboard
            print("Global Step: ",global_step_val)
            writer.add_summary(summ, global_step_val)
        else:
            _, _, loss_val = sess.run([train_op, update_metrics, loss])
        # Log the loss in the tqdm progress bar
        t.set_postfix(loss='{:05.3f}'.format(loss_val))
        print("Predictions",sess.run([y_conv[0:10,]]))

    metrics_values = {k: v[0] for k, v in metrics.items()}
    metrics_val = sess.run(metrics_values)
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items())
    logging.info("- Train metrics: " + metrics_string)

In [6]:
def evaluate_sess(sess, model_spec, num_steps, writer=None, params=None):
    """Train the model on `num_steps` batches.

    Args:
        sess: (tf.Session) current session
        model_spec: (dict) contains the graph operations or nodes needed for training
        num_steps: (int) train for this number of batches
        writer: (tf.summary.FileWriter) writer for summaries. Is None if we don't log anything
        params: (Params) hyperparameters
    """
    update_metrics = model_spec['update_metrics']
    eval_metrics = model_spec['metrics']
    global_step = tf.train.get_global_step()

    # Load the evaluation dataset into the pipeline and initialize the metrics init op
    sess.run(model_spec['iterator_init_op'])
    sess.run(model_spec['metrics_init_op'])

    # compute metrics over the dataset
    for _ in range(num_steps):
        sess.run(update_metrics)

    # Get the values of the metrics
    metrics_values = {k: v[0] for k, v in eval_metrics.items()}
    metrics_val = sess.run(metrics_values)
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items())
    logging.info("- Eval metrics: " + metrics_string)

    # Add summaries manually to writer at global_step_val
    if writer is not None:
        global_step_val = sess.run(global_step)
        for tag, val in metrics_val.items():
            summ = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=val)])
            writer.add_summary(summ, global_step_val)

    return metrics_val

In [7]:
def evaluate(model_spec, model_dir, params, restore_from):
    import pandas as pd
    """Evaluate the model

    Args:
        model_spec: (dict) contains the graph operations or nodes needed for evaluation
        model_dir: (string) directory containing config, weights and log
        params: (Params) contains hyperparameters of the model.
                Must define: num_epochs, train_size, batch_size, eval_size, save_summary_steps
        restore_from: (string) directory or file containing weights to restore the graph
    """
    # Initialize tf.Saver
    saver = tf.train.Saver()
    
    df = pd.DataFrame()
    with tf.Session() as sess:
        # Initialize the lookup table
        #sess.run(model_spec['variable_init_op'])
        # Reload weights from the weights subdirectory
        save_path = os.path.join(model_dir, restore_from)
        if os.path.isdir(save_path):
            save_path = tf.train.latest_checkpoint(save_path)
        saver.restore(sess, save_path)
        
        sess.run(model_spec['iterator_init_op'])
        totalBatches = (104170//params["TEST_BATCH_SIZE"]) + 1
        
        for index in range(totalBatches):        
            predictions,query_id,passage_id = sess.run([model_spec["predictions"],model_spec["query_id"],model_spec["passage_id"]])
            
            tmp = pd.DataFrame({"query_id":query_id[:,0],"passage_id":passage_id[:,0],"predictions":predictions[:,1]})
            df = pd.concat([df,tmp],axis = 0)
            #print(query_id[0:100])
    
    return df

In [8]:
def train_and_evaluate(train_model_spec, eval_model_spec, model_dir, params):
    """Train the model and evaluate every epoch.

    Args:
        train_model_spec: (dict) contains the graph operations or nodes needed for training
        eval_model_spec: (dict) contains the graph operations or nodes needed for evaluation
        model_dir: (string) directory containing config, weights and log
        params: (Params) contains hyperparameters of the model.
                Must define: num_epochs, train_size, batch_size, eval_size, save_summary_steps
        restore_from: (string) directory or file containing weights to restore the graph
    """
    
    max_query_words = params["max_query_words"]
    max_passage_words = params["max_passage_words"]
    emb_dim = params["emb_dim"]
    BATCH_SIZE = params["BATCH_SIZE"]
    EPOCHS = params["EPOCHS"]
    BatchesPerEpoch = params["BatchesPerEpoch"]
    num_classes = params["num_classes"]
    
    # Initialize tf.Saver instances to save weights during training
    last_saver = tf.train.Saver() # will keep last 5 epochs
    best_saver = tf.train.Saver(max_to_keep=1)  # only keep 1 best checkpoint (best on eval)
    begin_at_epoch = 0
    
    with tf.Session() as sess:
        # Initialize model variables
        sess.run(train_model_spec['variable_init_op'])

        # For tensorboard (takes care of writing summaries to files)
        train_writer = tf.summary.FileWriter(os.path.join(model_dir, 'train_summaries'), sess.graph)
        eval_writer = tf.summary.FileWriter(os.path.join(model_dir, 'eval_summaries'), sess.graph)

        best_eval_acc = 0.0
        for epoch in range(EPOCHS):
            # Run one epoch
            logging.info("Epoch {}/{}".format(epoch + 1, EPOCHS))
            
            train_sess(sess, train_model_spec, BatchesPerEpoch, train_writer, params)

            # Save weights
            last_save_path = os.path.join(model_dir, 'last_weights', 'after-epoch')
            last_saver.save(sess, last_save_path, global_step=epoch + 1)

            metrics = evaluate_sess(sess, eval_model_spec, 1, eval_writer)

            # If best_eval, best_save_path
            eval_acc = metrics['accuracy']
            if eval_acc >= best_eval_acc:
                # Store new best accuracy
                best_eval_acc = eval_acc
                # Save weights
                best_save_path = os.path.join(model_dir, 'best_weights', 'after-epoch')
                best_save_path = best_saver.save(sess, best_save_path, global_step=epoch + 1)
                logging.info("- Found new best accuracy, saving in {}".format(best_save_path))
                # Save best eval metrics in a json file in the model directory
                best_json_path = os.path.join(model_dir, "metrics_eval_best_weights.json")
                save_dict_to_json(metrics, best_json_path)

            # Save latest eval metrics in a json file in the model directory
            last_json_path = os.path.join(model_dir, "metrics_eval_last_weights.json")
            save_dict_to_json(metrics, last_json_path)

In [20]:
#-----------------Main function for training

tf.reset_default_graph()
train_model_spec = model_fn("train","./trainEmbeddings.tfrecords",params)
eval_model_spec = model_fn("eval","./validationEmbeddings.tfrecords",params)
logging.info("Starting training for {} epoch(s)".format(params["EPOCHS"]))
train_and_evaluate(train_model_spec, eval_model_spec, "./ModelLogs", params)

Tensor("model/query/convLayer2/maxPool:0", shape=(?, 2, 8, 2), dtype=float32)
Tensor("model_1/query/convLayer2/maxPool:0", shape=(?, 2, 8, 2), dtype=float32)







  0%|                                                                                            | 0/1000 [00:00<?, ?it/s]

Global Step:  0







  0%|                                                                                | 0/1000 [00:00<?, ?it/s, loss=0.671]

Predictions [array([[0.15584895, 0.08165452],
       [0.15936944, 0.08062461],
       [0.14658819, 0.10190961],
       [0.14653747, 0.09895856],
       [0.16198692, 0.10846816],
       [0.1523965 , 0.10335493],
       [0.17963877, 0.12931609],
       [0.15435189, 0.08586035],
       [0.15851013, 0.08492937],
       [0.13694662, 0.0864806 ]], dtype=float32)]







  0%|                                                                        | 1/1000 [00:00<11:14,  1.48it/s, loss=0.671]




  0%|                                                                        | 1/1000 [00:00<11:14,  1.48it/s, loss=0.672]

Predictions [array([[0.15600099, 0.09677782],
       [0.1474099 , 0.10524456],
       [0.15525709, 0.10395895],
       [0.14389274, 0.10041074],
       [0.14651003, 0.10773011],
       [0.16337986, 0.10982665],
       [0.14586236, 0.1106291 ],
       [0.1532989 , 0.11444686],
       [0.13993207, 0.09834953],
       [0.1491893 , 0.08395373]], dtype=float32)]







  0%|▏                                                                       | 2/1000 [00:01<09:38,  1.73it/s, loss=0.672]




  0%|▏                                                                       | 2/1000 [00:01<09:38,  1.73it/s, loss=0.672]

Predictions [array([[0.14791816, 0.1078626 ],
       [0.15716034, 0.073115  ],
       [0.14750873, 0.08798581],
       [0.1509644 , 0.08523634],
       [0.14832833, 0.09813641],
       [0.16134638, 0.11041585],
       [0.15059066, 0.09807662],
       [0.12994878, 0.10044149],
       [0.15060467, 0.07686725],
       [0.15486771, 0.09723409]], dtype=float32)]







  0%|▏                                                                       | 3/1000 [00:01<07:56,  2.09it/s, loss=0.672]




  0%|▏                                                                       | 3/1000 [00:01<07:56,  2.09it/s, loss=0.671]

Predictions [array([[0.18012866, 0.06728661],
       [0.1527764 , 0.08019216],
       [0.1501334 , 0.09108946],
       [0.1358874 , 0.10165055],
       [0.17506629, 0.08878846],
       [0.14089076, 0.07801491],
       [0.14580289, 0.09564903],
       [0.1515643 , 0.08827799],
       [0.16764686, 0.08066549],
       [0.1372332 , 0.11001568]], dtype=float32)]







  0%|▎                                                                       | 4/1000 [00:01<06:45,  2.46it/s, loss=0.671]




  0%|▎                                                                       | 4/1000 [00:01<06:45,  2.46it/s, loss=0.669]

Predictions [array([[0.1507433 , 0.09722503],
       [0.19314983, 0.0790092 ],
       [0.1328191 , 0.10159724],
       [0.14437705, 0.10994571],
       [0.15085524, 0.08930016],
       [0.1469161 , 0.10963883],
       [0.14210692, 0.10173015],
       [0.15942813, 0.08088921],
       [0.15704474, 0.12655531],
       [0.14831825, 0.09567196]], dtype=float32)]







  0%|▎                                                                       | 5/1000 [00:01<05:53,  2.81it/s, loss=0.669]




  0%|▎                                                                       | 5/1000 [00:01<05:53,  2.81it/s, loss=0.670]

Predictions [array([[0.17079948, 0.07054257],
       [0.16511372, 0.09811972],
       [0.16178843, 0.08111774],
       [0.1453031 , 0.09072288],
       [0.16648988, 0.10842337],
       [0.13489386, 0.09201527],
       [0.16857183, 0.09032601],
       [0.16972789, 0.07795252],
       [0.18231405, 0.09801817],
       [0.16818503, 0.08949476]], dtype=float32)]







  1%|▍                                                                       | 6/1000 [00:01<05:20,  3.11it/s, loss=0.670]




  1%|▍                                                                       | 6/1000 [00:02<05:20,  3.11it/s, loss=0.667]

Predictions [array([[0.13312566, 0.10365091],
       [0.14456522, 0.08635337],
       [0.12781087, 0.11441205],
       [0.1445146 , 0.10141145],
       [0.14622894, 0.10107593],
       [0.18708709, 0.07598588],
       [0.15447983, 0.08618411],
       [0.144087  , 0.11043936],
       [0.17747214, 0.06692062],
       [0.15559258, 0.10201465]], dtype=float32)]







  1%|▌                                                                       | 7/1000 [00:02<04:56,  3.35it/s, loss=0.667]




  1%|▌                                                                       | 7/1000 [00:02<04:56,  3.35it/s, loss=0.668]

Predictions [array([[0.17175409, 0.0826891 ],
       [0.17651159, 0.06717139],
       [0.14558436, 0.09998165],
       [0.17239329, 0.03920001],
       [0.14102139, 0.0876037 ],
       [0.17568505, 0.09041683],
       [0.16168651, 0.06468926],
       [0.16110831, 0.09693792],
       [0.14782444, 0.08898228],
       [0.1627143 , 0.07499076]], dtype=float32)]







  1%|▌                                                                       | 8/1000 [00:02<04:37,  3.57it/s, loss=0.668]




  1%|▌                                                                       | 8/1000 [00:02<04:37,  3.57it/s, loss=0.666]

Predictions [array([[0.1672619 , 0.05907745],
       [0.1727514 , 0.07443193],
       [0.15752336, 0.08662499],
       [0.14499196, 0.09563858],
       [0.16712148, 0.06693304],
       [0.16073158, 0.07897437],
       [0.14777538, 0.09552154],
       [0.13393685, 0.09713801],
       [0.14969781, 0.08858058],
       [0.14263704, 0.09668446]], dtype=float32)]







  1%|▋                                                                       | 9/1000 [00:02<04:25,  3.73it/s, loss=0.666]




  1%|▋                                                                       | 9/1000 [00:02<04:25,  3.73it/s, loss=0.666]

Predictions [array([[0.15222779, 0.08716072],
       [0.15383726, 0.08310167],
       [0.13759787, 0.10210659],
       [0.14479674, 0.08569366],
       [0.13807982, 0.09786367],
       [0.15600383, 0.09167369],
       [0.15580104, 0.07800698],
       [0.14925322, 0.09738678],
       [0.1533069 , 0.09588999],
       [0.15778482, 0.08294959]], dtype=float32)]







  1%|▋                                                                      | 10/1000 [00:02<04:15,  3.87it/s, loss=0.666]




  1%|▋                                                                      | 10/1000 [00:03<04:15,  3.87it/s, loss=0.665]

Predictions [array([[0.1582675 , 0.07067578],
       [0.13785915, 0.09465964],
       [0.16361327, 0.08965725],
       [0.16067511, 0.09846386],
       [0.1650528 , 0.06992573],
       [0.18348956, 0.07144377],
       [0.16112   , 0.08600493],
       [0.1478455 , 0.10737809],
       [0.16276792, 0.08939245],
       [0.16641332, 0.0687777 ]], dtype=float32)]







  1%|▊                                                                      | 11/1000 [00:03<04:10,  3.94it/s, loss=0.665]




  1%|▊                                                                      | 11/1000 [00:03<04:10,  3.94it/s, loss=0.664]

Predictions [array([[0.18866219, 0.10560255],
       [0.15108551, 0.08917923],
       [0.15205827, 0.07276332],
       [0.1635969 , 0.08082429],
       [0.14039022, 0.08697218],
       [0.14967835, 0.08188566],
       [0.1627855 , 0.08008137],
       [0.14375801, 0.07828249],
       [0.16106309, 0.08509877],
       [0.17038536, 0.06646393]], dtype=float32)]







  1%|▊                                                                      | 12/1000 [00:03<04:04,  4.05it/s, loss=0.664]




  1%|▊                                                                      | 12/1000 [00:03<04:04,  4.05it/s, loss=0.664]

Predictions [array([[0.189556  , 0.07863694],
       [0.16474766, 0.07742608],
       [0.1483771 , 0.07958992],
       [0.163244  , 0.09504572],
       [0.13980186, 0.07771353],
       [0.17441009, 0.08381106],
       [0.16631112, 0.05845559],
       [0.14228848, 0.08928844],
       [0.16055916, 0.08545348],
       [0.16061112, 0.07122615]], dtype=float32)]







  1%|▉                                                                      | 13/1000 [00:03<04:00,  4.10it/s, loss=0.664]




  1%|▉                                                                      | 13/1000 [00:03<04:00,  4.10it/s, loss=0.663]

Predictions [array([[0.16376296, 0.0702396 ],
       [0.18783653, 0.06914584],
       [0.1668402 , 0.08293086],
       [0.15823998, 0.06555255],
       [0.15167812, 0.09196582],
       [0.16467762, 0.08106106],
       [0.1476421 , 0.0863941 ],
       [0.14165585, 0.08444986],
       [0.16068825, 0.0855427 ],
       [0.1668719 , 0.07936931]], dtype=float32)]







  1%|▉                                                                      | 14/1000 [00:03<03:57,  4.15it/s, loss=0.663]




  1%|▉                                                                      | 14/1000 [00:04<03:57,  4.15it/s, loss=0.660]

Predictions [array([[0.15494879, 0.07269666],
       [0.17722112, 0.05068044],
       [0.13401133, 0.10398397],
       [0.17219403, 0.08374222],
       [0.16108224, 0.06387514],
       [0.16293411, 0.08623251],
       [0.14948991, 0.07197954],
       [0.14651851, 0.08237629],
       [0.14625478, 0.08705733],
       [0.17700744, 0.06151406]], dtype=float32)]







  2%|█                                                                      | 15/1000 [00:04<03:55,  4.18it/s, loss=0.660]




  2%|█                                                                      | 15/1000 [00:04<03:55,  4.18it/s, loss=0.660]

Predictions [array([[0.17072701, 0.08581219],
       [0.1803658 , 0.05042694],
       [0.17192969, 0.08198053],
       [0.1469752 , 0.08690222],
       [0.1661271 , 0.06937692],
       [0.12556778, 0.09863593],
       [0.16221547, 0.08099775],
       [0.15206814, 0.10903771],
       [0.17694083, 0.09316398],
       [0.17636612, 0.07089383]], dtype=float32)]







  2%|█▏                                                                     | 16/1000 [00:04<03:57,  4.14it/s, loss=0.660]




  2%|█▏                                                                     | 16/1000 [00:04<03:57,  4.14it/s, loss=0.661]

Predictions [array([[0.14734372, 0.08767136],
       [0.16242337, 0.04815438],
       [0.15484706, 0.08511129],
       [0.1746882 , 0.07853513],
       [0.16314471, 0.0838529 ],
       [0.1732827 , 0.07039986],
       [0.15226927, 0.09136797],
       [0.1760398 , 0.0484066 ],
       [0.16436036, 0.0711878 ],
       [0.1765329 , 0.06892842]], dtype=float32)]







  2%|█▏                                                                     | 17/1000 [00:04<03:57,  4.14it/s, loss=0.661]




  2%|█▏                                                                     | 17/1000 [00:04<03:57,  4.14it/s, loss=0.657]

Predictions [array([[0.167512  , 0.05809071],
       [0.19234326, 0.04387375],
       [0.17457736, 0.0458654 ],
       [0.15589347, 0.09019522],
       [0.14692073, 0.08703193],
       [0.15268612, 0.07554448],
       [0.15948471, 0.077322  ],
       [0.18062711, 0.07096893],
       [0.17559563, 0.07540622],
       [0.1625894 , 0.06499524]], dtype=float32)]







  2%|█▎                                                                     | 18/1000 [00:04<04:00,  4.08it/s, loss=0.657]




  2%|█▎                                                                     | 18/1000 [00:05<04:00,  4.08it/s, loss=0.657]

Predictions [array([[0.1548459 , 0.07395636],
       [0.1837359 , 0.07629288],
       [0.16266102, 0.08709838],
       [0.15161723, 0.08216409],
       [0.16037896, 0.06268208],
       [0.16776875, 0.06675877],
       [0.16791774, 0.08108006],
       [0.14936595, 0.08983956],
       [0.17592868, 0.09190871],
       [0.18414503, 0.03525507]], dtype=float32)]







  2%|█▎                                                                     | 19/1000 [00:05<03:58,  4.12it/s, loss=0.657]




  2%|█▎                                                                     | 19/1000 [00:05<03:58,  4.12it/s, loss=0.657]

Predictions [array([[0.14538282, 0.07906367],
       [0.17574918, 0.05330563],
       [0.16890493, 0.08076347],
       [0.18000285, 0.0670749 ],
       [0.18502942, 0.0609296 ],
       [0.15655771, 0.08394519],
       [0.16048653, 0.08745901],
       [0.18375435, 0.04352146],
       [0.16400807, 0.06074132],
       [0.15596855, 0.0831672 ]], dtype=float32)]







  2%|█▍                                                                     | 20/1000 [00:05<03:58,  4.10it/s, loss=0.657]




  2%|█▍                                                                     | 20/1000 [00:05<03:58,  4.10it/s, loss=0.655]

Predictions [array([[0.16886333, 0.05034536],
       [0.13510859, 0.08489757],
       [0.16468334, 0.05750967],
       [0.1576252 , 0.07938311],
       [0.17849323, 0.0536757 ],
       [0.14203928, 0.08891089],
       [0.17579994, 0.05703929],
       [0.1659516 , 0.07305752],
       [0.1711686 , 0.09383829],
       [0.14628887, 0.0891908 ]], dtype=float32)]







  2%|█▍                                                                     | 21/1000 [00:05<03:58,  4.10it/s, loss=0.655]




  2%|█▍                                                                     | 21/1000 [00:05<03:58,  4.10it/s, loss=0.656]

Predictions [array([[0.16517267, 0.08564322],
       [0.14699107, 0.07788668],
       [0.18139276, 0.06428631],
       [0.14765456, 0.09756557],
       [0.16360351, 0.08277209],
       [0.15303908, 0.06355473],
       [0.16999674, 0.06990914],
       [0.20091492, 0.05068766],
       [0.17616038, 0.08050944],
       [0.16507724, 0.05809193]], dtype=float32)]







  2%|█▌                                                                     | 22/1000 [00:05<03:58,  4.10it/s, loss=0.656]




  2%|█▌                                                                     | 22/1000 [00:05<03:58,  4.10it/s, loss=0.654]

Predictions [array([[0.1450266 , 0.08227263],
       [0.17955199, 0.05700043],
       [0.15781498, 0.0744656 ],
       [0.18010257, 0.06471202],
       [0.15768257, 0.08101453],
       [0.1548934 , 0.06864237],
       [0.15911244, 0.07048847],
       [0.16102037, 0.07244182],
       [0.16238372, 0.07352129],
       [0.19378453, 0.05730362]], dtype=float32)]







  2%|█▋                                                                     | 23/1000 [00:06<03:57,  4.12it/s, loss=0.654]




  2%|█▋                                                                     | 23/1000 [00:06<03:57,  4.12it/s, loss=0.654]

Predictions [array([[0.17717625, 0.06117734],
       [0.19478184, 0.03502943],
       [0.17670211, 0.07401466],
       [0.16533618, 0.07112256],
       [0.16830286, 0.06691159],
       [0.1628541 , 0.06107841],
       [0.18045379, 0.06329277],
       [0.18534939, 0.03682413],
       [0.16906184, 0.07764465],
       [0.19314902, 0.05712269]], dtype=float32)]







  2%|█▋                                                                     | 24/1000 [00:06<03:56,  4.12it/s, loss=0.654]




  2%|█▋                                                                     | 24/1000 [00:06<03:56,  4.12it/s, loss=0.651]

Predictions [array([[0.1679537 , 0.06869584],
       [0.17755115, 0.0787029 ],
       [0.15146229, 0.07246248],
       [0.16595757, 0.05575796],
       [0.21272475, 0.03286026],
       [0.17234977, 0.06917316],
       [0.18941572, 0.05859779],
       [0.17725824, 0.05069565],
       [0.17660397, 0.06673913],
       [0.16908008, 0.06176525]], dtype=float32)]







  2%|█▊                                                                     | 25/1000 [00:06<03:56,  4.12it/s, loss=0.651]




  2%|█▊                                                                     | 25/1000 [00:06<03:56,  4.12it/s, loss=0.650]

Predictions [array([[0.17117476, 0.05303575],
       [0.20292774, 0.07907753],
       [0.15067956, 0.07829007],
       [0.15693524, 0.08697898],
       [0.18547389, 0.05417927],
       [0.2152258 , 0.0348862 ],
       [0.18723822, 0.05731801],
       [0.16408147, 0.07174781],
       [0.16250584, 0.06410415],
       [0.15153338, 0.07770504]], dtype=float32)]







  3%|█▊                                                                     | 26/1000 [00:06<03:56,  4.12it/s, loss=0.650]




  3%|█▊                                                                     | 26/1000 [00:06<03:56,  4.12it/s, loss=0.650]

Predictions [array([[0.19965981, 0.06921583],
       [0.20020813, 0.06592317],
       [0.16529535, 0.0597638 ],
       [0.17119855, 0.07401762],
       [0.18678428, 0.04625631],
       [0.16062126, 0.08345464],
       [0.21558906, 0.02886958],
       [0.1670813 , 0.0580518 ],
       [0.16546333, 0.09527578],
       [0.1713318 , 0.06101586]], dtype=float32)]







  3%|█▉                                                                     | 27/1000 [00:07<03:55,  4.14it/s, loss=0.650]




  3%|█▉                                                                     | 27/1000 [00:07<03:55,  4.14it/s, loss=0.649]

Predictions [array([[0.20038046, 0.04718938],
       [0.15824838, 0.04835046],
       [0.18859793, 0.06273668],
       [0.16475798, 0.07168855],
       [0.16646332, 0.05973516],
       [0.18786623, 0.0365007 ],
       [0.1599891 , 0.07290268],
       [0.22555289, 0.00138823],
       [0.20105655, 0.02690266],
       [0.18242814, 0.07002564]], dtype=float32)]







  3%|█▉                                                                     | 28/1000 [00:07<03:55,  4.13it/s, loss=0.649]




  3%|█▉                                                                     | 28/1000 [00:07<03:55,  4.13it/s, loss=0.647]

Predictions [array([[0.2223852 , 0.03698841],
       [0.1803675 , 0.04205073],
       [0.16848937, 0.09475011],
       [0.18357435, 0.08522324],
       [0.18301968, 0.04637526],
       [0.17333475, 0.06179508],
       [0.17533113, 0.06235737],
       [0.14692071, 0.08081521],
       [0.1736229 , 0.05257995],
       [0.17759018, 0.05493119]], dtype=float32)]







  3%|██                                                                     | 29/1000 [00:07<03:58,  4.07it/s, loss=0.647]




  3%|██                                                                     | 29/1000 [00:07<03:58,  4.07it/s, loss=0.644]

Predictions [array([[0.18765265, 0.03818561],
       [0.17274082, 0.07125996],
       [0.17142436, 0.07395056],
       [0.18148354, 0.05473583],
       [0.22966573, 0.00951851],
       [0.17461109, 0.06219175],
       [0.15668675, 0.0759279 ],
       [0.14691058, 0.06935469],
       [0.17282423, 0.05050053],
       [0.21189174, 0.01222945]], dtype=float32)]







  3%|██▏                                                                    | 30/1000 [00:07<04:00,  4.03it/s, loss=0.644]




  3%|██▏                                                                    | 30/1000 [00:07<04:00,  4.03it/s, loss=0.644]

Predictions [array([[0.1774967 , 0.0537224 ],
       [0.191021  , 0.04552934],
       [0.1771599 , 0.03656948],
       [0.18190539, 0.03746055],
       [0.17350739, 0.06094775],
       [0.18794557, 0.03655447],
       [0.19689818, 0.03039227],
       [0.17007564, 0.05520978],
       [0.15854746, 0.08085708],
       [0.16582851, 0.08241974]], dtype=float32)]







  3%|██▏                                                                    | 31/1000 [00:08<04:01,  4.02it/s, loss=0.644]




  3%|██▏                                                                    | 31/1000 [00:08<04:01,  4.02it/s, loss=0.644]

Predictions [array([[0.23181775, 0.0250089 ],
       [0.17731914, 0.0508498 ],
       [0.18513384, 0.06467164],
       [0.17127493, 0.04329123],
       [0.19357905, 0.02040204],
       [0.16803151, 0.08003096],
       [0.18770936, 0.03437266],
       [0.16696519, 0.06275494],
       [0.1951553 , 0.03473065],
       [0.15825894, 0.06888057]], dtype=float32)]







  3%|██▎                                                                    | 32/1000 [00:08<04:01,  4.01it/s, loss=0.644]




  3%|██▎                                                                    | 32/1000 [00:08<04:01,  4.01it/s, loss=0.644]

Predictions [array([[0.18741547, 0.06233414],
       [0.15849519, 0.07981901],
       [0.19043994, 0.03054236],
       [0.19734803, 0.0549518 ],
       [0.17094925, 0.05295574],
       [0.14889441, 0.09594366],
       [0.21450971, 0.03176168],
       [0.17680407, 0.04775209],
       [0.17504817, 0.05542174],
       [0.1911445 , 0.04790166]], dtype=float32)]







  3%|██▎                                                                    | 33/1000 [00:08<04:02,  3.99it/s, loss=0.644]




  3%|██▎                                                                    | 33/1000 [00:08<04:02,  3.99it/s, loss=0.641]

Predictions [array([[0.19525774, 0.02373859],
       [0.17250785, 0.05561931],
       [0.16625708, 0.05496047],
       [0.21226037, 0.01478221],
       [0.18929723, 0.04745264],
       [0.18344474, 0.04348673],
       [0.14614706, 0.0817276 ],
       [0.16958308, 0.05704132],
       [0.15237245, 0.07215253],
       [0.19977756, 0.03505866]], dtype=float32)]







  3%|██▍                                                                    | 34/1000 [00:08<04:01,  4.01it/s, loss=0.641]




  3%|██▍                                                                    | 34/1000 [00:08<04:01,  4.01it/s, loss=0.639]

Predictions [array([[0.17918082, 0.03825881],
       [0.18964335, 0.05666814],
       [0.15482087, 0.06136347],
       [0.16643998, 0.05691007],
       [0.18025172, 0.06023004],
       [0.17884141, 0.04717766],
       [0.17940386, 0.04784603],
       [0.20616943, 0.04771876],
       [0.19646204, 0.02471203],
       [0.17509362, 0.04811073]], dtype=float32)]







  4%|██▍                                                                    | 35/1000 [00:09<04:02,  3.98it/s, loss=0.639]




  4%|██▍                                                                    | 35/1000 [00:09<04:02,  3.98it/s, loss=0.639]

Predictions [array([[0.21981113, 0.02029688],
       [0.16783968, 0.06913278],
       [0.19968525, 0.05066539],
       [0.16564074, 0.06066566],
       [0.18062913, 0.04226548],
       [0.1951522 , 0.04293375],
       [0.18544573, 0.04579512],
       [0.17133784, 0.0554913 ],
       [0.1469432 , 0.07938989],
       [0.17772633, 0.04600785]], dtype=float32)]







  4%|██▌                                                                    | 36/1000 [00:09<04:01,  3.99it/s, loss=0.639]




  4%|██▌                                                                    | 36/1000 [00:09<04:01,  3.99it/s, loss=0.637]

Predictions [array([[0.18754378, 0.03556842],
       [0.16687253, 0.06163403],
       [0.17577365, 0.0580076 ],
       [0.21461093, 0.01022788],
       [0.17959787, 0.05045068],
       [0.18374182, 0.03731588],
       [0.15900113, 0.07086707],
       [0.20639856, 0.04297313],
       [0.17364267, 0.05170234],
       [0.21669665, 0.00753628]], dtype=float32)]







  4%|██▋                                                                    | 37/1000 [00:09<04:00,  4.00it/s, loss=0.637]




  4%|██▋                                                                    | 37/1000 [00:09<04:00,  4.00it/s, loss=0.637]

Predictions [array([[0.1806686 , 0.05626615],
       [0.17752934, 0.04971116],
       [0.18535712, 0.04812601],
       [0.18675008, 0.04518079],
       [0.21337716, 0.01551212],
       [0.20699672, 0.02502761],
       [0.19502582, 0.04681251],
       [0.19245782, 0.0539336 ],
       [0.19316024, 0.03373343],
       [0.1390315 , 0.08399318]], dtype=float32)]







  4%|██▋                                                                    | 38/1000 [00:09<04:00,  3.99it/s, loss=0.637]




  4%|██▋                                                                    | 38/1000 [00:09<04:00,  3.99it/s, loss=0.634]

Predictions [array([[0.15369418, 0.06051218],
       [0.2063193 , 0.02743642],
       [0.21539938, 0.0006148 ],
       [0.19744019, 0.06691222],
       [0.17735952, 0.04659497],
       [0.19963798, 0.04770693],
       [0.19256608, 0.04782926],
       [0.18226868, 0.0737427 ],
       [0.1560591 , 0.06935105],
       [0.18653432, 0.02953625]], dtype=float32)]







  4%|██▊                                                                    | 39/1000 [00:10<04:03,  3.94it/s, loss=0.634]




  4%|██▊                                                                    | 39/1000 [00:10<04:03,  3.94it/s, loss=0.629]

Predictions [array([[0.1831092 , 0.05047566],
       [0.2010462 , 0.0335299 ],
       [0.20507808, 0.04240687],
       [0.19366282, 0.05291554],
       [0.23166859, 0.00125403],
       [0.19161588, 0.05555228],
       [0.19123688, 0.05312099],
       [0.14481708, 0.078499  ],
       [0.16766837, 0.05717826],
       [0.17502387, 0.05975003]], dtype=float32)]







  4%|██▊                                                                    | 40/1000 [00:10<03:59,  4.00it/s, loss=0.629]




  4%|██▊                                                                    | 40/1000 [00:10<03:59,  4.00it/s, loss=0.628]

Predictions [array([[0.19151479, 0.04235055],
       [0.18931963, 0.04756609],
       [0.17699645, 0.05085784],
       [0.18621576, 0.05216084],
       [0.16958334, 0.06224712],
       [0.18321773, 0.06309032],
       [0.2005846 , 0.04609501],
       [0.23415448, 0.01632417],
       [0.2054331 , 0.03258395],
       [0.18984619, 0.03007628]], dtype=float32)]







  4%|██▉                                                                    | 41/1000 [00:10<03:57,  4.03it/s, loss=0.628]




  4%|██▉                                                                    | 41/1000 [00:10<03:57,  4.03it/s, loss=0.628]

Predictions [array([[0.20765004, 0.04944865],
       [0.16672498, 0.05735799],
       [0.18604162, 0.05018784],
       [0.19948938, 0.03015935],
       [0.22571775, 0.00673459],
       [0.18983878, 0.05786799],
       [0.19688907, 0.03255329],
       [0.2173653 , 0.02405858],
       [0.17268318, 0.04672997],
       [0.24047965, 0.        ]], dtype=float32)]







  4%|██▉                                                                    | 42/1000 [00:10<03:57,  4.03it/s, loss=0.628]




  4%|██▉                                                                    | 42/1000 [00:10<03:57,  4.03it/s, loss=0.629]

Predictions [array([[0.20324515, 0.0305748 ],
       [0.20773697, 0.02356122],
       [0.17705381, 0.05994546],
       [0.2096991 , 0.02511935],
       [0.18205684, 0.05141997],
       [0.20518678, 0.01725023],
       [0.19440746, 0.04482951],
       [0.24045673, 0.00105328],
       [0.21441391, 0.01400394],
       [0.22302485, 0.02467843]], dtype=float32)]







  4%|███                                                                    | 43/1000 [00:11<03:57,  4.03it/s, loss=0.629]




  4%|███                                                                    | 43/1000 [00:11<03:57,  4.03it/s, loss=0.631]

Predictions [array([[0.22952327, 0.        ],
       [0.23158252, 0.00715255],
       [0.20778444, 0.04159182],
       [0.21082869, 0.03113466],
       [0.19160348, 0.03253926],
       [0.20814288, 0.02471363],
       [0.22887143, 0.00207461],
       [0.20450652, 0.02626532],
       [0.19151556, 0.04188021],
       [0.14500204, 0.08318122]], dtype=float32)]







  4%|███                                                                    | 44/1000 [00:11<03:56,  4.05it/s, loss=0.631]




  4%|███                                                                    | 44/1000 [00:11<03:56,  4.05it/s, loss=0.625]

Predictions [array([[0.18295792, 0.06571619],
       [0.2645113 , 0.        ],
       [0.22984438, 0.01600785],
       [0.23131602, 0.        ],
       [0.24246904, 0.02600651],
       [0.1902352 , 0.04788935],
       [0.2004819 , 0.02408946],
       [0.25213712, 0.        ],
       [0.207633  , 0.02144042],
       [0.23229925, 0.00609189]], dtype=float32)]







  4%|███▏                                                                   | 45/1000 [00:11<03:56,  4.03it/s, loss=0.625]




  4%|███▏                                                                   | 45/1000 [00:11<03:56,  4.03it/s, loss=0.624]

Predictions [array([[0.21561974, 0.00192595],
       [0.2089073 , 0.04457742],
       [0.24381423, 0.02088942],
       [0.25008127, 0.        ],
       [0.24811292, 0.        ],
       [0.23736668, 0.        ],
       [0.19527182, 0.05696023],
       [0.19556126, 0.04561047],
       [0.22244307, 0.00235085],
       [0.23789775, 0.        ]], dtype=float32)]







  5%|███▎                                                                   | 46/1000 [00:11<03:56,  4.03it/s, loss=0.624]




  5%|███▎                                                                   | 46/1000 [00:11<03:56,  4.03it/s, loss=0.622]

Predictions [array([[0.21394403, 0.00623463],
       [0.18548211, 0.04111941],
       [0.28701907, 0.        ],
       [0.2271133 , 0.00499964],
       [0.22701   , 0.00895926],
       [0.21595979, 0.01731163],
       [0.24387625, 0.0093732 ],
       [0.17529038, 0.03485074],
       [0.20521495, 0.02731302],
       [0.22063172, 0.02080866]], dtype=float32)]







  5%|███▎                                                                   | 47/1000 [00:12<03:58,  4.00it/s, loss=0.622]




  5%|███▎                                                                   | 47/1000 [00:12<03:58,  4.00it/s, loss=0.625]

Predictions [array([[0.20562439, 0.03641549],
       [0.26627237, 0.        ],
       [0.299851  , 0.        ],
       [0.20671351, 0.01480275],
       [0.247895  , 0.        ],
       [0.23328403, 0.03553353],
       [0.25622556, 0.00117179],
       [0.191107  , 0.03040285],
       [0.189538  , 0.03494997],
       [0.20718451, 0.03419941]], dtype=float32)]







  5%|███▍                                                                   | 48/1000 [00:12<03:55,  4.04it/s, loss=0.625]




  5%|███▍                                                                   | 48/1000 [00:12<03:55,  4.04it/s, loss=0.623]

Predictions [array([[2.5371510e-01, 0.0000000e+00],
       [2.2062844e-01, 6.2413514e-05],
       [2.0579441e-01, 5.6413144e-02],
       [1.9486067e-01, 3.5917133e-02],
       [2.0604214e-01, 4.5116071e-02],
       [1.9466239e-01, 4.2283807e-02],
       [2.3717387e-01, 2.6306368e-02],
       [1.9336283e-01, 4.9644053e-02],
       [2.1790864e-01, 3.9840087e-02],
       [2.0843358e-01, 3.5965614e-02]], dtype=float32)]







  5%|███▍                                                                   | 49/1000 [00:12<03:53,  4.07it/s, loss=0.623]




  5%|███▍                                                                   | 49/1000 [00:12<03:53,  4.07it/s, loss=0.618]

Predictions [array([[0.1850736 , 0.05134701],
       [0.22888198, 0.00404824],
       [0.20906626, 0.0241264 ],
       [0.20401737, 0.02297084],
       [0.23466067, 0.02509531],
       [0.23261786, 0.        ],
       [0.26505417, 0.        ],
       [0.22548899, 0.01343465],
       [0.21043487, 0.01664931],
       [0.24629459, 0.00580369]], dtype=float32)]







  5%|███▌                                                                   | 50/1000 [00:12<03:55,  4.03it/s, loss=0.618]




  5%|███▌                                                                   | 50/1000 [00:12<03:55,  4.03it/s, loss=0.616]

Predictions [array([[0.25407457, 0.        ],
       [0.23783606, 0.01550887],
       [0.21876454, 0.02530044],
       [0.19282863, 0.0445768 ],
       [0.3112209 , 0.        ],
       [0.20525163, 0.03100695],
       [0.22809395, 0.00976359],
       [0.25509986, 0.        ],
       [0.29358837, 0.        ],
       [0.26189235, 0.        ]], dtype=float32)]







  5%|███▌                                                                   | 51/1000 [00:13<03:54,  4.04it/s, loss=0.616]




  5%|███▌                                                                   | 51/1000 [00:13<03:54,  4.04it/s, loss=0.617]

Predictions [array([[0.21621561, 0.03349172],
       [0.22745588, 0.02874295],
       [0.22901702, 0.01539008],
       [0.2242176 , 0.01855373],
       [0.23007601, 0.00440319],
       [0.22479811, 0.01468084],
       [0.25021994, 0.        ],
       [0.24432018, 0.01786467],
       [0.19675007, 0.03461337],
       [0.2223272 , 0.01300382]], dtype=float32)]







  5%|███▋                                                                   | 52/1000 [00:13<03:55,  4.03it/s, loss=0.617]




  5%|███▋                                                                   | 52/1000 [00:13<03:55,  4.03it/s, loss=0.612]

Predictions [array([[0.2236189 , 0.02998483],
       [0.23761413, 0.        ],
       [0.21233632, 0.02608682],
       [0.22629717, 0.02238779],
       [0.30360535, 0.        ],
       [0.21661538, 0.02558269],
       [0.24711016, 0.        ],
       [0.22931224, 0.03245757],
       [0.2327441 , 0.00861204],
       [0.25619233, 0.        ]], dtype=float32)]







  5%|███▊                                                                   | 53/1000 [00:13<03:55,  4.02it/s, loss=0.612]




  5%|███▊                                                                   | 53/1000 [00:13<03:55,  4.02it/s, loss=0.615]

Predictions [array([[0.2685213 , 0.0019227 ],
       [0.3119251 , 0.        ],
       [0.20986557, 0.02969667],
       [0.23887877, 0.        ],
       [0.27881214, 0.        ],
       [0.25211504, 0.        ],
       [0.2199564 , 0.02178267],
       [0.18522093, 0.05776822],
       [0.22024322, 0.01855642],
       [0.19679978, 0.03732394]], dtype=float32)]







  5%|███▊                                                                   | 54/1000 [00:13<03:55,  4.01it/s, loss=0.615]




  5%|███▊                                                                   | 54/1000 [00:13<03:55,  4.01it/s, loss=0.606]

Predictions [array([[0.21844812, 0.00850947],
       [0.23751876, 0.01489154],
       [0.20592225, 0.03619512],
       [0.25258803, 0.        ],
       [0.26204032, 0.        ],
       [0.26485455, 0.        ],
       [0.25214928, 0.        ],
       [0.24536303, 0.00460686],
       [0.22255188, 0.01153825],
       [0.2227089 , 0.02965534]], dtype=float32)]







  6%|███▉                                                                   | 55/1000 [00:14<03:55,  4.01it/s, loss=0.606]




  6%|███▉                                                                   | 55/1000 [00:14<03:55,  4.01it/s, loss=0.608]

Predictions [array([[0.24263021, 0.        ],
       [0.23333639, 0.03933057],
       [0.2222805 , 0.01102145],
       [0.24868765, 0.00478259],
       [0.20471874, 0.03235626],
       [0.27980113, 0.0138619 ],
       [0.22245646, 0.02685147],
       [0.21129438, 0.0420179 ],
       [0.20436764, 0.04486816],
       [0.20235467, 0.02578346]], dtype=float32)]







  6%|███▉                                                                   | 56/1000 [00:14<03:56,  3.99it/s, loss=0.608]




  6%|███▉                                                                   | 56/1000 [00:14<03:56,  3.99it/s, loss=0.611]

Predictions [array([[0.2632423 , 0.        ],
       [0.26458117, 0.        ],
       [0.2680001 , 0.        ],
       [0.2553093 , 0.        ],
       [0.2414099 , 0.01037101],
       [0.29223478, 0.        ],
       [0.21116552, 0.0176089 ],
       [0.25618365, 0.        ],
       [0.20349087, 0.02920515],
       [0.32809356, 0.        ]], dtype=float32)]







  6%|████                                                                   | 57/1000 [00:14<03:57,  3.97it/s, loss=0.611]




  6%|████                                                                   | 57/1000 [00:14<03:57,  3.97it/s, loss=0.605]

Predictions [array([[0.23072186, 0.02996601],
       [0.25852025, 0.        ],
       [0.25950164, 0.        ],
       [0.31131276, 0.        ],
       [0.23485462, 0.0230843 ],
       [0.26939896, 0.        ],
       [0.29952163, 0.02427979],
       [0.27021444, 0.        ],
       [0.23169492, 0.        ],
       [0.25925943, 0.        ]], dtype=float32)]







  6%|████                                                                   | 58/1000 [00:14<03:58,  3.95it/s, loss=0.605]




  6%|████                                                                   | 58/1000 [00:14<03:58,  3.95it/s, loss=0.605]

Predictions [array([[0.28154388, 0.        ],
       [0.29243436, 0.        ],
       [0.23995581, 0.00152265],
       [0.20101058, 0.02291851],
       [0.20279342, 0.03458726],
       [0.2330249 , 0.0037958 ],
       [0.28877035, 0.        ],
       [0.2811996 , 0.        ],
       [0.25683382, 0.        ],
       [0.2879458 , 0.        ]], dtype=float32)]







  6%|████▏                                                                  | 59/1000 [00:15<03:59,  3.92it/s, loss=0.605]




  6%|████▏                                                                  | 59/1000 [00:15<03:59,  3.92it/s, loss=0.604]

Predictions [array([[0.31676942, 0.        ],
       [0.23287563, 0.0272053 ],
       [0.22498426, 0.        ],
       [0.2793728 , 0.        ],
       [0.22705668, 0.01844557],
       [0.24233708, 0.01166736],
       [0.24884163, 0.00756293],
       [0.28299797, 0.        ],
       [0.2489463 , 0.00043449],
       [0.23050407, 0.        ]], dtype=float32)]







  6%|████▎                                                                  | 60/1000 [00:15<03:57,  3.96it/s, loss=0.604]




  6%|████▎                                                                  | 60/1000 [00:15<03:57,  3.96it/s, loss=0.608]

Predictions [array([[0.24783787, 0.01434073],
       [0.24624577, 0.01926347],
       [0.2703619 , 0.00190991],
       [0.21461903, 0.02819016],
       [0.2436942 , 0.00626449],
       [0.27013844, 0.00980914],
       [0.29090038, 0.        ],
       [0.32849392, 0.        ],
       [0.25463814, 0.        ],
       [0.25033477, 0.00559039]], dtype=float32)]







  6%|████▎                                                                  | 61/1000 [00:15<04:00,  3.91it/s, loss=0.608]




  6%|████▎                                                                  | 61/1000 [00:15<04:00,  3.91it/s, loss=0.596]

Predictions [array([[0.27765423, 0.        ],
       [0.27964354, 0.        ],
       [0.26639628, 0.        ],
       [0.33859903, 0.        ],
       [0.2888524 , 0.        ],
       [0.25210154, 0.0229313 ],
       [0.3027352 , 0.        ],
       [0.26323372, 0.00039578],
       [0.22244288, 0.02838356],
       [0.21854228, 0.02555859]], dtype=float32)]







  6%|████▍                                                                  | 62/1000 [00:15<03:58,  3.93it/s, loss=0.596]




  6%|████▍                                                                  | 62/1000 [00:15<03:58,  3.93it/s, loss=0.598]

Predictions [array([[0.32648644, 0.        ],
       [0.31086603, 0.        ],
       [0.24091971, 0.01256839],
       [0.2803912 , 0.        ],
       [0.31659046, 0.        ],
       [0.31692484, 0.        ],
       [0.26497594, 0.01354978],
       [0.27417165, 0.        ],
       [0.2573095 , 0.        ],
       [0.23470747, 0.02996776]], dtype=float32)]







  6%|████▍                                                                  | 63/1000 [00:16<03:56,  3.97it/s, loss=0.598]




  6%|████▍                                                                  | 63/1000 [00:16<03:56,  3.97it/s, loss=0.597]

Predictions [array([[0.2663852 , 0.03732617],
       [0.28753245, 0.        ],
       [0.22906624, 0.02630646],
       [0.29920602, 0.        ],
       [0.2724723 , 0.        ],
       [0.2522061 , 0.        ],
       [0.32652062, 0.        ],
       [0.22139847, 0.04453403],
       [0.28544378, 0.        ],
       [0.31757283, 0.        ]], dtype=float32)]







  6%|████▌                                                                  | 64/1000 [00:16<03:54,  3.99it/s, loss=0.597]




  6%|████▌                                                                  | 64/1000 [00:16<03:54,  3.99it/s, loss=0.599]

Predictions [array([[0.24865043, 0.02267911],
       [0.27751967, 0.        ],
       [0.27355942, 0.        ],
       [0.28139603, 0.        ],
       [0.24423411, 0.01690558],
       [0.28695968, 0.        ],
       [0.24464105, 0.04429917],
       [0.2282978 , 0.03791159],
       [0.2838968 , 0.01226396],
       [0.2417076 , 0.02731138]], dtype=float32)]







  6%|████▌                                                                  | 65/1000 [00:16<03:54,  3.98it/s, loss=0.599]




  6%|████▌                                                                  | 65/1000 [00:16<03:54,  3.98it/s, loss=0.592]

Predictions [array([[0.3059768 , 0.        ],
       [0.39360255, 0.        ],
       [0.25448915, 0.00510871],
       [0.26208884, 0.02473585],
       [0.2326585 , 0.04408923],
       [0.27099574, 0.        ],
       [0.24453972, 0.01103938],
       [0.31381842, 0.        ],
       [0.3107799 , 0.        ],
       [0.36664516, 0.        ]], dtype=float32)]







  7%|████▋                                                                  | 66/1000 [00:16<03:56,  3.95it/s, loss=0.592]




  7%|████▋                                                                  | 66/1000 [00:16<03:56,  3.95it/s, loss=0.590]

Predictions [array([[0.26980028, 0.        ],
       [0.31067616, 0.        ],
       [0.35051847, 0.        ],
       [0.27039018, 0.00260112],
       [0.296324  , 0.        ],
       [0.27736318, 0.        ],
       [0.23544905, 0.02825163],
       [0.340401  , 0.        ],
       [0.34258655, 0.        ],
       [0.2662024 , 0.        ]], dtype=float32)]







  7%|████▊                                                                  | 67/1000 [00:17<03:58,  3.92it/s, loss=0.590]




  7%|████▊                                                                  | 67/1000 [00:17<03:58,  3.92it/s, loss=0.589]

Predictions [array([[0.2861545 , 0.        ],
       [0.27404606, 0.00210324],
       [0.32581484, 0.        ],
       [0.29074758, 0.        ],
       [0.2960712 , 0.        ],
       [0.3390612 , 0.        ],
       [0.29975313, 0.00176708],
       [0.34558922, 0.        ],
       [0.2911339 , 0.        ],
       [0.3206093 , 0.        ]], dtype=float32)]







  7%|████▊                                                                  | 68/1000 [00:17<03:54,  3.97it/s, loss=0.589]




  7%|████▊                                                                  | 68/1000 [00:17<03:54,  3.97it/s, loss=0.587]

Predictions [array([[0.27401692, 0.        ],
       [0.34857798, 0.        ],
       [0.2658667 , 0.02700671],
       [0.31079075, 0.        ],
       [0.3043474 , 0.        ],
       [0.35455337, 0.        ],
       [0.32701373, 0.        ],
       [0.2830756 , 0.        ],
       [0.2740509 , 0.        ],
       [0.38393664, 0.        ]], dtype=float32)]







  7%|████▉                                                                  | 69/1000 [00:17<03:54,  3.97it/s, loss=0.587]




  7%|████▉                                                                  | 69/1000 [00:17<03:54,  3.97it/s, loss=0.590]

Predictions [array([[0.2755202 , 0.        ],
       [0.29348278, 0.        ],
       [0.39872265, 0.        ],
       [0.30471626, 0.        ],
       [0.21913871, 0.02582741],
       [0.2763331 , 0.04509306],
       [0.24460855, 0.01698097],
       [0.29622072, 0.        ],
       [0.24391659, 0.02543663],
       [0.32175106, 0.        ]], dtype=float32)]







  7%|████▉                                                                  | 70/1000 [00:17<03:53,  3.98it/s, loss=0.590]




  7%|████▉                                                                  | 70/1000 [00:17<03:53,  3.98it/s, loss=0.578]

Predictions [array([[0.29921436, 0.        ],
       [0.3016637 , 0.        ],
       [0.24094808, 0.02804405],
       [0.29105544, 0.        ],
       [0.32643348, 0.        ],
       [0.30985656, 0.00169771],
       [0.23102455, 0.01558714],
       [0.25852084, 0.01258781],
       [0.33309212, 0.        ],
       [0.29093444, 0.        ]], dtype=float32)]







  7%|█████                                                                  | 71/1000 [00:18<03:54,  3.97it/s, loss=0.578]




  7%|█████                                                                  | 71/1000 [00:18<03:54,  3.97it/s, loss=0.582]

Predictions [array([[0.27076724, 0.        ],
       [0.31621128, 0.        ],
       [0.27017757, 0.00877883],
       [0.29370794, 0.        ],
       [0.2857883 , 0.        ],
       [0.28115782, 0.        ],
       [0.31940544, 0.        ],
       [0.2782238 , 0.        ],
       [0.34073278, 0.        ],
       [0.37597397, 0.        ]], dtype=float32)]







  7%|█████                                                                  | 72/1000 [00:18<03:53,  3.98it/s, loss=0.582]




  7%|█████                                                                  | 72/1000 [00:18<03:53,  3.98it/s, loss=0.578]

Predictions [array([[3.7915263e-01, 0.0000000e+00],
       [2.6006487e-01, 0.0000000e+00],
       [3.1059375e-01, 0.0000000e+00],
       [3.4530053e-01, 0.0000000e+00],
       [3.0342254e-01, 3.2828748e-04],
       [2.9762971e-01, 0.0000000e+00],
       [2.8972396e-01, 0.0000000e+00],
       [2.6414406e-01, 6.9369972e-03],
       [3.7369546e-01, 0.0000000e+00],
       [2.8961200e-01, 0.0000000e+00]], dtype=float32)]







  7%|█████▏                                                                 | 73/1000 [00:18<04:07,  3.74it/s, loss=0.578]




  7%|█████▏                                                                 | 73/1000 [00:18<04:07,  3.74it/s, loss=0.578]

Predictions [array([[0.42705753, 0.        ],
       [0.3445332 , 0.        ],
       [0.3241655 , 0.        ],
       [0.36583892, 0.        ],
       [0.36231244, 0.        ],
       [0.34600008, 0.        ],
       [0.327605  , 0.        ],
       [0.3010012 , 0.        ],
       [0.33853495, 0.        ],
       [0.2995871 , 0.        ]], dtype=float32)]







  7%|█████▎                                                                 | 74/1000 [00:18<04:06,  3.75it/s, loss=0.578]




  7%|█████▎                                                                 | 74/1000 [00:19<04:06,  3.75it/s, loss=0.573]

Predictions [array([[0.33347157, 0.        ],
       [0.29243618, 0.        ],
       [0.2709846 , 0.        ],
       [0.39594   , 0.        ],
       [0.25125843, 0.01280931],
       [0.38402444, 0.        ],
       [0.29871237, 0.        ],
       [0.30515277, 0.01320011],
       [0.3537882 , 0.        ],
       [0.32623968, 0.        ]], dtype=float32)]







  8%|█████▎                                                                 | 75/1000 [00:19<04:08,  3.72it/s, loss=0.573]




  8%|█████▎                                                                 | 75/1000 [00:19<04:08,  3.72it/s, loss=0.575]

Predictions [array([[0.39641577, 0.        ],
       [0.37053317, 0.        ],
       [0.2721934 , 0.        ],
       [0.37044683, 0.        ],
       [0.2547342 , 0.02028164],
       [0.37386218, 0.        ],
       [0.33743104, 0.        ],
       [0.38231716, 0.        ],
       [0.35307115, 0.        ],
       [0.31094337, 0.        ]], dtype=float32)]







  8%|█████▍                                                                 | 76/1000 [00:19<04:06,  3.75it/s, loss=0.575]

KeyboardInterrupt: 

In [None]:
#-----------------Main function for testing

tf.reset_default_graph()
eval_model_spec = modelTest_fn("test","./evalUnlabelledEmbeddings.tfrecords",params)
logging.info("Starting training for {} epoch(s)".format(params["EPOCHS"]))
df = evaluate(eval_model_spec, "./ModelLogs", params,"best_weights")

import csv
tmp = (df.groupby('query_id')['predictions']
       .apply(lambda x: "\t".join([format(val, "0.2f") for val in x]))
       .reset_index())

tmp.to_csv("./answer.tsv",index=False,sep= "\t",header=None,quoting=csv.QUOTE_NONE,quotechar="",  escapechar="\\")

In [None]:
#df.head(100)
#map(str,x)
import csv
tmp = (df.groupby('query_id')['predictions']
       .apply(lambda x: "\t".join([format(val, "0.2f") for val in x]))
       .reset_index())

tmp.to_csv("./answer.tsv",index=False,sep= "\t",header=None,quoting=csv.QUOTE_NONE,quotechar="",  escapechar="\\")

In [None]:
df[df["query_id"] == 89]["predictions"].apply(lambda x: "\t".join(format(x, "0.2f")))

In [None]:
tmp.head()