In [17]:
import tensorflow as tf
import os
import json
from transformers import TFBertModel
# Number of answer-type classes predicted by the answer-type head (TFNQModel).
answer_types = 5
# Directory holding the training TFRecords; it is also used as the checkpoint directory.
# Set the BERT_JOINT_MODEL_DIR environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
model_dir = os.environ.get(
    "BERT_JOINT_MODEL_DIR",
    "/Users/aashnabanerjee/Documents/Cortx/inference/code/bert-joint-baseline",
)
# Hugging Face model id passed to `get_pretrained_model`.
MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
from adamw_optimizer import AdamW
# --- Data files -------------------------------------------------------------
train_filename = os.path.join(model_dir, "nq-train.tfrecords-00000-of-00001")
# Alternative validation file, currently unused:
# val_filename = os.path.join(model_dir, "eval.tf_record")
val_filename2 = os.path.join(model_dir, "op.tf_record")

# --- Training loop ----------------------------------------------------------
num_train_examples = 4000          # used to derive the number of optimizer steps
epochs = 2
train_batch_size = 2               # examples per forward/backward pass
batch_accumulation_size = 2        # passes accumulated into one optimizer update
max_seq_length_for_training = 512
shuffle_buffer_size = 100000

# --- Optimizer / learning-rate schedule --------------------------------------
init_learning_rate = 5e-5
cyclic_learning_rate = True
init_weight_decay_rate = 0.01
num_warmup_steps = 0

In [18]:
def get_dataset(tf_record_file, seq_length, batch_size=1, shuffle_buffer_size=0, is_training=False):
    """Build a batched `tf.data.Dataset` from a TFRecord file of NQ features.

    Args:
        tf_record_file: path (or paths) accepted by `tf.data.TFRecordDataset`.
        seq_length: fixed length of `input_ids`, `input_mask`, `segment_ids`
            (and of `token_map` when not training).
        batch_size: number of examples per batch.
        shuffle_buffer_size: shuffle buffer size; 0 (the default) disables shuffling.
        is_training: when True, also parse `start_positions`, `end_positions`
            and `answer_types`, and yield `(features, labels)` tuples; otherwise
            parse `token_map` and yield only the features dict.

    Returns:
        The batched dataset, with every int64 field cast to int32.
    """
    def fixed_int64(shape):
        return tf.io.FixedLenFeature(shape, tf.int64)

    # Fields common to training and inference records.
    features = {
        "unique_ids": fixed_int64([]),
        "input_ids": fixed_int64([seq_length]),
        "input_mask": fixed_int64([seq_length]),
        "segment_ids": fixed_int64([seq_length]),
    }
    if is_training:
        features.update({
            "start_positions": fixed_int64([]),
            "end_positions": fixed_int64([]),
            "answer_types": fixed_int64([]),
        })
    else:
        features["token_map"] = fixed_int64([seq_length])

    input_keys = ('unique_ids', 'input_ids', 'input_mask', 'segment_ids')
    label_keys = ('start_positions', 'end_positions', 'answer_types')

    def decode_record(record, features):
        """Parse one serialized tf.Example and cast its int64 tensors to int32.

        tf.Example only stores int64, while TPUs only support int32.
        """
        parsed = tf.io.parse_single_example(record, features)
        return {
            key: tf.cast(value, tf.int32) if value.dtype == tf.int64 else value
            for key, value in parsed.items()
        }

    def select_data_from_record(record):
        """Split a decoded record into model inputs and, when training, labels."""
        x = {key: record[key] for key in input_keys}
        if not is_training:
            x['token_map'] = record['token_map']
            return x
        y = {key: record[key] for key in label_keys}
        return (x, y)

    dataset = tf.data.TFRecordDataset(tf_record_file)
    dataset = dataset.map(lambda record: decode_record(record, features))
    dataset = dataset.map(select_data_from_record)

    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)

    return dataset.batch(batch_size)

In [19]:
# Standalone training dataset: sequence length 512, batch size 2, shuffle buffer 10000.
# NOTE(review): the training-loop cell rebuilds `train_dataset` at the start of every
# epoch from the config constants, so this instance is only useful for inspection.
train_dataset = get_dataset(
    train_filename,
    512,
    batch_size=2,
    shuffle_buffer_size=10000,
    is_training=True,
)

In [20]:
from transformers import TFBertMainLayer, TFBertPreTrainedModel, BertTokenizer
from transformers.modeling_tf_utils import get_initializer

class TFNQModel(TFBertPreTrainedModel):
    """BERT encoder with two task heads for Natural Questions.

    * `pos_classifier`: Dense(2) over the sequence output, split into start and
      end position logits.
    * `answer_type_classifier`: Dense(`answer_types`) over the pooled output,
      where `answer_types` is the module-level constant.
    """
    
    def __init__(self, config, *inputs, **kwargs):
        """Create the BERT main layer, the dropout layers and both classifier heads.

        Args:
            config: BERT configuration; `config.initializer_range` sets the kernel
                initializer of both heads.
            *inputs, **kwargs: forwarded unchanged to `TFBertPreTrainedModel.__init__`.
                The optional keys `seq_output_dropout_prob` and
                `pooled_output_dropout_prob` (default 0.05 each) are also read
                from `kwargs`.
        """
        # NOTE(review): the dropout-prob keys are not popped from kwargs, so they are also
        # passed to the parent __init__ here — confirm the parent accepts them before using them.
        TFBertPreTrainedModel.__init__(self, config, *inputs, **kwargs)  # explicit calls without super
        self.bert = TFBertMainLayer(config, name='bert')
        self.backend = None  # not referenced by any method of this class
        
        # Dropout on the BERT outputs; call() enables it only when training=True is passed.
        self.seq_output_dropout = tf.keras.layers.Dropout(kwargs.get('seq_output_dropout_prob', 0.05))
        self.pooled_output_dropout = tf.keras.layers.Dropout(kwargs.get('pooled_output_dropout_prob', 0.05))
        
        # Two logits per token: column 0 is used as the start logit, column 1 as the end logit.
        self.pos_classifier = tf.keras.layers.Dense(2,
                                        kernel_initializer=get_initializer(config.initializer_range),
                                        name='pos_classifier')       

        # One logit per answer type, computed from the pooled output.
        self.answer_type_classifier = tf.keras.layers.Dense(answer_types,
                                        kernel_initializer=get_initializer(config.initializer_range),
                                        name='answer_type_classifier')         
                

    def call(self, inputs, **kwargs):
        """Run BERT and both heads.

        Args:
            inputs: passed straight to `self.bert`; callers in this notebook pass a
                tuple `(input_ids, input_masks, segment_ids)`.
            **kwargs: forwarded to `self.bert`; `training` (default False) also
                controls the dropout layers.

        Returns:
            Tuple `(start_pos_logits, end_pos_logits, answer_type_logits)`.
        """
        
        # sequence / [CLS] outputs from original bert
        outputs = self.bert(inputs, **kwargs)
        sequence_output, pooled_output = outputs[0], outputs[1] 
        
        # dropout
        sequence_output = self.seq_output_dropout(sequence_output, training=kwargs.get('training', False))
        pooled_output = self.pooled_output_dropout(pooled_output, training=kwargs.get('training', False))
        
        pos_logits = self.pos_classifier(sequence_output)  # shape = (batch_size, seq_len, 2)
        start_pos_logits = pos_logits[:, :, 0]  # shape = (batch_size, seq_len)
        end_pos_logits = pos_logits[:, :, 1]  # shape = (batch_size, seq_len)
        
        answer_type_logits = self.answer_type_classifier(pooled_output)  # shape = (batch_size, NB_ANSWER_TYPES)

        outputs = (start_pos_logits, end_pos_logits, answer_type_logits)

        return outputs
        
    
def get_pretrained_model(MODEL_NAME):
    """Load the BERT tokenizer and the NQ model for one pretrained checkpoint.

    Args:
        MODEL_NAME: identifier passed to both `from_pretrained` calls (a Hugging
            Face model id, or a local directory path).

    Returns:
        Tuple `(tokenizer, model)`: a `BertTokenizer` and a `TFNQModel`, both
        created with `from_pretrained(MODEL_NAME)`.
    """
    # Both objects are resolved from MODEL_NAME alone; nothing under `model_dir`
    # is consulted here.
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    model = TFNQModel.from_pretrained(MODEL_NAME)
    return tokenizer, model


bert_tokenizer, bert_nq = get_pretrained_model(MODEL_NAME)

In [21]:
# Smoke test: run one tokenized sentence through the model and inspect output shapes.
input_ids = tf.constant(bert_tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
# Attention mask uses 1 for "attend to this token". Every token here is real (no
# padding), so the mask must be all ones; an all-zero mask marks every token as padding.
input_masks = tf.ones_like(input_ids)
# Single-segment input: every token belongs to segment 0.
segment_ids = tf.zeros_like(input_ids)

# Actual inputs to model, in the order (input_ids, attention_mask, token_type_ids)
inputs = (input_ids, input_masks, segment_ids)

outputs = bert_nq(inputs)
(start_pos_logits, end_pos_logits, answer_type_logits) = outputs
print(start_pos_logits.shape)
print(end_pos_logits.shape)
print(answer_type_logits.shape)

len(bert_nq.trainable_variables)

(1, 8)
(1, 8)
(1, 5)


395

In [22]:
def get_metrics(name):
    """Create the running loss and accuracy metrics for one split.

    Args:
        name: prefix for every metric name, e.g. "train" or "valid".

    Returns:
        An 8-tuple `(loss, loss_start_pos, loss_end_pos, loss_ans_type,
        acc, acc_start_pos, acc_end_pos, acc_ans_type)`. The losses are `Mean`
        metrics and the accuracies are `SparseCategoricalAccuracy` metrics,
        named f"{name}_loss<suffix>" / f"{name}_acc<suffix>".
    """
    suffixes = ('', '_start_pos', '_end_pos', '_ans_type')
    loss_metrics = [tf.keras.metrics.Mean(name=f'{name}_loss{suffix}') for suffix in suffixes]
    acc_metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name=f'{name}_acc{suffix}') for suffix in suffixes]
    return tuple(loss_metrics + acc_metrics)


(train_loss, train_loss_start_pos, train_loss_end_pos, train_loss_ans_type,
 train_acc, train_acc_start_pos, train_acc_end_pos, train_acc_ans_type) = get_metrics("train")
(valid_loss, valid_loss_start_pos, valid_loss_end_pos, valid_loss_ans_type,
 valid_acc, valid_acc_start_pos, valid_acc_end_pos, valid_acc_ans_type) = get_metrics("valid")

In [23]:
# Per-example cross-entropy on raw logits; the summing is done in loss_function.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def loss_function(nq_labels, nq_logits):
    """Combined NQ loss over the three heads.

    Args:
        nq_labels: `(start_pos_labels, end_pos_labels, answer_type_labels)`.
        nq_logits: `(start_pos_logits, end_pos_logits, answer_type_logits)`.

    Returns:
        `(loss, loss_start_pos, loss_end_pos, loss_ans_type)`. Each head loss is
        a SUM over the batch (not a mean), and `loss` is their plain average.
        Callers divide by the number of examples themselves.
    """
    start_pos_labels, end_pos_labels, answer_type_labels = nq_labels
    start_pos_logits, end_pos_logits, answer_type_logits = nq_logits

    head_pairs = (
        (start_pos_labels, start_pos_logits),
        (end_pos_labels, end_pos_logits),
        (answer_type_labels, answer_type_logits),
    )
    loss_start_pos, loss_end_pos, loss_ans_type = [
        tf.math.reduce_sum(loss_object(labels, logits)) for labels, logits in head_pairs
    ]

    loss = (loss_start_pos + loss_end_pos + loss_ans_type) / 3.0
    return loss, loss_start_pos, loss_end_pos, loss_ans_type

In [24]:
class CustomSchedule(tf.keras.optimizers.schedules.PolynomialDecay):
    """Polynomial-decay learning-rate schedule with optional cycling and linear warmup.

    Extends `PolynomialDecay` (which is always built with `cycle=False`) with:
      * cyclic behavior, implemented here as `step % decay_steps`; and
      * a linear warmup over the first `num_warmup_steps` steps. Warmup is applied
        after cycling, so with `cycle=True` it restarts at every cycle.

    Args:
        initial_learning_rate: learning rate at the start of decay (the warmup target).
        decay_steps: length of one decay period, in steps.
        end_learning_rate: learning rate reached at the end of each decay period.
        power: polynomial power passed to `PolynomialDecay`.
        cycle: if True, restart the decay every `decay_steps` steps.
        name: optional name passed to `PolynomialDecay`.
        num_warmup_steps: number of linear-warmup steps; `<= 0` disables warmup.
    """

    def __init__(self,
      initial_learning_rate,
      decay_steps,
      end_learning_rate=0.0001,
      power=1.0,
      cycle=False,
      name=None,
      num_warmup_steps=1000):

        # The parent always gets cycle=False: cycling is done in __call__ via
        # `step % decay_steps`. The user's flag is kept in a *separate* attribute
        # because PolynomialDecay.__call__ and get_config both read `self.cycle`.
        # Overwriting it would re-enable the parent's own cycling branch and put a
        # tensor into the serialized config.
        super(CustomSchedule, self).__init__(initial_learning_rate, decay_steps, end_learning_rate, power, cycle=False, name=name)

        self.num_warmup_steps = num_warmup_steps
        self.cyclic = bool(cycle)

    def __call__(self, step):
        """Return the learning rate for `step` (a step index starting at 0)."""

        # Cyclic behavior. For 0 <= step < decay_steps the modulo is the identity,
        # so steps are only wrapped once they reach decay_steps.
        if self.cyclic:
            step = step % self.decay_steps

        learning_rate = super(CustomSchedule, self).__call__(step)

        # Copy (including the comments) from original bert optimizer with minor change.
        # Ref: https://github.com/google-research/bert/blob/master/optimization.py#L25

        # Implements linear warmup: if global_step < num_warmup_steps, the
        # learning rate will be `global_step / num_warmup_steps * init_lr`.
        if self.num_warmup_steps > 0:

            steps_int = tf.cast(step, tf.int32)
            warmup_steps_int = tf.constant(self.num_warmup_steps, dtype=tf.int32)

            steps_float = tf.cast(steps_int, tf.float32)
            warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

            # The first training step has index (`step`) 0.
            # The original code uses `steps_float / warmup_steps_float`, which gives `warmup_percent_done` = 0
            # and therefore `learning_rate` = 0 on the first step, which is undesired.
            # For this reason, we use `(steps_float + 1) / warmup_steps_float`.
            # At `step = warmup_steps_float - 1`, i.e. at the `warmup_steps_float`-th step,
            # `learning_rate` is `self.initial_learning_rate`.
            warmup_percent_done = (steps_float + 1) / warmup_steps_float

            warmup_learning_rate = self.initial_learning_rate * warmup_percent_done

            is_warmup = tf.cast(steps_int < warmup_steps_int, tf.float32)
            learning_rate = ((1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

        return learning_rate

    def get_config(self):
        """Serializable config: the parent's fields plus this class's extra arguments."""
        config = super(CustomSchedule, self).get_config()
        # Report the user's cycle flag (the parent stores False) and keep warmup,
        # so `CustomSchedule.from_config(schedule.get_config())` round-trips.
        config.update({
            "cycle": self.cyclic,
            "num_warmup_steps": self.num_warmup_steps,
        })
        return config
 

In [25]:
   
# One optimizer update consumes train_batch_size * batch_accumulation_size examples.
num_train_steps = int(epochs * num_train_examples / train_batch_size / batch_accumulation_size)

# NOTE(review): end_learning_rate equals the initial rate, so the polynomial decay is
# flat (constant LR apart from warmup) and cycling has no visible effect — confirm intended.
schedule_kwargs = dict(
    initial_learning_rate=init_learning_rate,
    decay_steps=num_train_steps,
    end_learning_rate=init_learning_rate,
    power=1.0,
    cycle=cyclic_learning_rate,
    num_warmup_steps=num_warmup_steps,
)
learning_rate = CustomSchedule(**schedule_kwargs)


In [26]:
# Names of LayerNorm and bias variables. The reference BERT optimizer lists exactly these
# patterns in `exclude_from_weight_decay`.
# NOTE(review): whether adamw_optimizer.AdamW treats `decay_var_list` as the variables
# to decay or to skip is defined in that module — verify it is not inverted.
decay_var_list = [
    variable.name
    for variable in bert_nq.trainable_variables
    if any(pattern in variable.name for pattern in ["LayerNorm", "layer_norm", "bias"])
]

decay_var_list

['tfnq_model_1/bert/embeddings/LayerNorm/gamma:0',
 'tfnq_model_1/bert/embeddings/LayerNorm/beta:0',
 'tfnq_model_1/bert/encoder/layer_._0/attention/self/query/bias:0',
 'tfnq_model_1/bert/encoder/layer_._0/attention/self/key/bias:0',
 'tfnq_model_1/bert/encoder/layer_._0/attention/self/value/bias:0',
 'tfnq_model_1/bert/encoder/layer_._0/attention/output/dense/bias:0',
 'tfnq_model_1/bert/encoder/layer_._0/attention/output/LayerNorm/gamma:0',
 'tfnq_model_1/bert/encoder/layer_._0/attention/output/LayerNorm/beta:0',
 'tfnq_model_1/bert/encoder/layer_._0/intermediate/dense/bias:0',
 'tfnq_model_1/bert/encoder/layer_._0/output/dense/bias:0',
 'tfnq_model_1/bert/encoder/layer_._0/output/LayerNorm/gamma:0',
 'tfnq_model_1/bert/encoder/layer_._0/output/LayerNorm/beta:0',
 'tfnq_model_1/bert/encoder/layer_._1/attention/self/query/bias:0',
 'tfnq_model_1/bert/encoder/layer_._1/attention/self/key/bias:0',
 'tfnq_model_1/bert/encoder/layer_._1/attention/self/value/bias:0',
 'tfnq_model_1/bert/e

In [27]:
from tensorflow.keras.optimizers import Adam
# Betas/epsilon follow AdamWeightDecayOptimizer from the original BERT code:
# https://github.com/google-research/bert/blob/master/optimization.py#L25
optimizer = AdamW(
    weight_decay=init_weight_decay_rate,
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-6,
    decay_var_list=decay_var_list,
)

In [28]:
# Signature shared by both tf.function train steps, all int32:
#   three (batch, seq_len) feature matrices: input ids, input masks, segment ids;
#   three (batch,) label vectors: start positions, end positions, answer types.
input_signature = (
    [tf.TensorSpec(shape=(None, None), dtype=tf.int32) for _ in range(3)]
    + [tf.TensorSpec(shape=(None,), dtype=tf.int32) for _ in range(3)]
)

In [29]:
def get_loss_and_gradients(input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels):
    """Forward pass, loss and gradients for one (mini-)batch; updates train accuracies.

    Returns:
        `(loss, gradients, loss_start_pos, loss_end_pos, loss_ans_type)`. The
        losses are the batch SUMS produced by `loss_function`, and `gradients`
        is aligned with `bert_nq.trainable_variables`.
    """
    labels = (start_pos_labels, end_pos_labels, answer_type_labels)

    with tf.GradientTape() as tape:
        logits = bert_nq((input_ids, input_masks, segment_ids), training=True)
        loss, loss_start_pos, loss_end_pos, loss_ans_type = loss_function(labels, logits)

    gradients = tape.gradient(loss, bert_nq.trainable_variables)

    start_pos_logits, end_pos_logits, answer_type_logits = logits
    head_pairs = (
        (start_pos_labels, start_pos_logits),
        (end_pos_labels, end_pos_logits),
        (answer_type_labels, answer_type_logits),
    )

    # The combined accuracy metric sees all three heads...
    for head_labels, head_logits in head_pairs:
        train_acc.update_state(head_labels, head_logits)
    # ...and each head also has its own accuracy metric.
    per_head_metrics = (train_acc_start_pos, train_acc_end_pos, train_acc_ans_type)
    for metric, (head_labels, head_logits) in zip(per_head_metrics, head_pairs):
        metric.update_state(head_labels, head_logits)

    return loss, gradients, loss_start_pos, loss_end_pos, loss_ans_type


@tf.function(input_signature=input_signature)
def train_step_simple(input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels):
    """Run one optimizer update on a single batch (no gradient accumulation).

    `loss_function` returns sums over the batch. The loss and the gradients are
    therefore divided by the number of examples before use, matching
    `train_step_with_batch_accumulation`. The train loss metrics are updated
    with the per-example averages.
    """
    # Dynamic batch size: `input_ids.shape[0]` is None under this input_signature.
    # NOTE(review): counts labels != -2; presumably no label is ever -2, so this is
    # the batch size — confirm.
    nb_examples = tf.math.reduce_sum(tf.cast(tf.math.not_equal(start_pos_labels, -2), tf.int32))
    nb_examples_float = tf.cast(nb_examples, tf.float32)

    loss, gradients, loss_start_pos, loss_end_pos, loss_ans_type = get_loss_and_gradients(input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels)

    average_loss = tf.math.divide(loss, nb_examples_float)
    average_gradients = [tf.divide(x, nb_examples_float) for x in gradients]

    # Apply the averaged gradients. The summed ones would scale the effective
    # learning rate with the batch size.
    optimizer.apply_gradients(zip(average_gradients, bert_nq.trainable_variables))

    average_loss_start_pos = tf.math.divide(loss_start_pos, nb_examples_float)
    average_loss_end_pos = tf.math.divide(loss_end_pos, nb_examples_float)
    average_loss_ans_type = tf.math.divide(loss_ans_type, nb_examples_float)

    train_loss(average_loss)
    train_loss_start_pos(average_loss_start_pos)
    train_loss_end_pos(average_loss_end_pos)
    train_loss_ans_type(average_loss_ans_type)

In [30]:
@tf.function(input_signature=input_signature)
def train_step_with_batch_accumulation(input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels):
    """One optimizer update from a batch split into accumulated mini-batches.

    The incoming batch is sliced into up to `batch_accumulation_size` chunks of
    `train_batch_size` rows. For each chunk, `get_loss_and_gradients` returns
    batch-summed losses and gradients (it also updates the train accuracy
    metrics). These are accumulated, divided by the total number of examples,
    and the averaged gradients are applied once. The train loss metrics receive
    the averaged losses.
    """

    # This gets None! (probably due to input_signature)
    # batch_size = input_ids.shape[0]
    
    # Try this.
    # Dynamic example count. NOTE(review): assumes no start-position label equals -2 — confirm.
    nb_examples = tf.math.reduce_sum(tf.cast(tf.math.not_equal(start_pos_labels, -2), tf.int32))

    # Running sums across mini-batches (loop-carried values in the AutoGraph loop below).
    total_loss = 0.0
    total_loss_start_pos = 0.0
    total_loss_end_pos = 0.0
    total_loss_ans_type = 0.0
    
    # One dense zero accumulator per trainable variable. Adding sparse (IndexedSlices)
    # gradients to these densifies them — presumably the source of the
    # "Converting sparse IndexedSlices" warning seen in the training output.
    total_gradients = [tf.constant(0, shape=x.shape, dtype=tf.float32) for x in bert_nq.trainable_variables]        
    ### total_gradients_sparse = [tf.IndexedSlices(values=tf.constant(0.0, shape=[1] + x.shape.as_list()[1:]), indices=tf.constant([0], dtype=tf.int32), dense_shape=x.shape.as_list()) for x in bert_nq.trainable_variables]        

    # tf.range makes AutoGraph convert this into a graph loop (the `break` included).
    for idx in tf.range(batch_accumulation_size):   
        start_idx = train_batch_size * idx
        end_idx = train_batch_size * (idx + 1)
        
        # Stop early when the batch is short (e.g. the last batch of an epoch).
        if start_idx >= nb_examples:
            break

        (input_ids_mini, input_masks_mini, segment_ids_mini) = (input_ids[start_idx:end_idx], input_masks[start_idx:end_idx], segment_ids[start_idx:end_idx])
        (start_pos_labels_mini, end_pos_labels_mini, answer_type_labels_mini) = (start_pos_labels[start_idx:end_idx], end_pos_labels[start_idx:end_idx], answer_type_labels[start_idx:end_idx])
        
        loss, gradients, loss_start_pos, loss_end_pos, loss_ans_type = get_loss_and_gradients(input_ids_mini, input_masks_mini, segment_ids_mini, start_pos_labels_mini, end_pos_labels_mini, answer_type_labels_mini)
        
        total_loss += loss
        total_loss_start_pos += loss_start_pos
        total_loss_end_pos += loss_end_pos
        total_loss_ans_type += loss_ans_type
        
        total_gradients = [x + y for x, y in zip(total_gradients, gradients)]  
    
    # Convert the summed loss/gradients into per-example averages before applying.
    average_loss = tf.math.divide(total_loss, tf.cast(nb_examples, tf.float32))        
    average_gradients = [tf.divide(x, tf.cast(nb_examples, tf.float32)) for x in total_gradients]
    ### average_gradients_sparse = [tf.scalar_mul(tf.divide(1.0, tf.cast(nb_examples, tf.float32)), x) for x in total_gradients_sparse]
    
    optimizer.apply_gradients(zip(average_gradients, bert_nq.trainable_variables))
    ### optimizer.apply_gradients(zip(average_gradients_sparse, bert_nq.trainable_variables))

    average_loss_start_pos = tf.math.divide(total_loss_start_pos, tf.cast(nb_examples, tf.float32))
    average_loss_end_pos = tf.math.divide(total_loss_end_pos, tf.cast(nb_examples, tf.float32))
    average_loss_ans_type = tf.math.divide(total_loss_ans_type, tf.cast(nb_examples, tf.float32))    
    
    train_loss(average_loss)
    train_loss_start_pos(average_loss_start_pos)
    train_loss_end_pos(average_loss_end_pos)
    train_loss_ans_type(average_loss_ans_type)


In [31]:
# Checkpoints track only the model weights and are written into model_dir.
checkpoint_path = model_dir
ckpt = tf.train.Checkpoint(model=bert_nq)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=10000)

# Restore the newest checkpoint if there is one. Its numeric suffix (one save per
# epoch) is taken as the number of epochs already trained.
latest_ckpt_path = ckpt_manager.latest_checkpoint
if not latest_ckpt_path:
    print('Checkpoint not found. Train BertNQ from scratch')
    last_epoch = 0
else:
    ckpt.restore(latest_ckpt_path)
    last_epoch = int(latest_ckpt_path.split("-")[-1])
    print(f'Latest BertNQ checkpoint restored -- Model trained for {last_epoch} epochs')
    

Checkpoint not found. Train BertNQ from scratch


In [None]:
import datetime
# Choose the per-batch update. With accumulation, each fetched batch holds
# batch_accumulation_size * train_batch_size examples and is split inside the step;
# otherwise one plain batch is used. Without the `else`, `train_step` would be
# undefined whenever batch_accumulation_size <= 1.
if batch_accumulation_size > 1:
    train_step = train_step_with_batch_accumulation
else:
    train_step = train_step_simple

# Print running metrics every this many batches.
log_every_n_batches = 1

train_start_time = datetime.datetime.now()

# Resume after the epochs recorded by a restored checkpoint (`last_epoch` is 0 when
# training from scratch). Exactly one checkpoint is saved per epoch below, so the
# checkpoint number keeps matching the epoch count that `last_epoch` is read from.
for epoch in range(last_epoch, epochs):

    # The dataset is rebuilt at the start of every epoch.
    train_dataset = get_dataset(
        train_filename,
        max_seq_length_for_training,
        batch_accumulation_size * train_batch_size,
        shuffle_buffer_size,
        is_training=True
    )

    # Metrics report per-epoch running values.
    for metric in (train_loss, train_loss_start_pos, train_loss_end_pos, train_loss_ans_type,
                   train_acc, train_acc_start_pos, train_acc_end_pos, train_acc_ans_type):
        metric.reset_states()

    epoch_start_time = datetime.datetime.now()

    for (batch_idx, (features, targets)) in enumerate(train_dataset):

        (input_ids, input_masks, segment_ids) = (features['input_ids'], features['input_mask'], features['segment_ids'])
        (start_pos_labels, end_pos_labels, answer_type_labels) = (targets['start_positions'], targets['end_positions'], targets['answer_types'])

        batch_start_time = datetime.datetime.now()
        train_step(input_ids, input_masks, segment_ids, start_pos_labels, end_pos_labels, answer_type_labels)
        batch_elapsed_time = (datetime.datetime.now() - batch_start_time).total_seconds()

        if (batch_idx + 1) % log_every_n_batches == 0:
            print('Epoch {} | Batch {} | Elapsed Time {}'.format(
                epoch + 1,
                batch_idx + 1,
                batch_elapsed_time
            ))
            print('Loss {:.6f} | Loss_S {:.6f} | Loss_E {:.6f} | Loss_T {:.6f}'.format(
                train_loss.result(),
                train_loss_start_pos.result(),
                train_loss_end_pos.result(),
                train_loss_ans_type.result()
            ))
            print(' Acc {:.6f} |  Acc_S {:.6f} |  Acc_E {:.6f} |  Acc_T {:.6f}'.format(
                train_acc.result(),
                train_acc_start_pos.result(),
                train_acc_end_pos.result(),
                train_acc_ans_type.result()
            ))
            print("-" * 100)

    epoch_elapsed_time = (datetime.datetime.now() - epoch_start_time).total_seconds()

    # Save once per epoch (see the resume comment above).
    ckpt_save_path = ckpt_manager.save()
    print ('\nSaving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))

    print('\nEpoch {}'.format(epoch + 1))
    print('Loss {:.6f} | Loss_S {:.6f} | Loss_E {:.6f} | Loss_T {:.6f}'.format(
        train_loss.result(),
        train_loss_start_pos.result(),
        train_loss_end_pos.result(),
        train_loss_ans_type.result()
    ))
    print(' Acc {:.6f} |  Acc_S {:.6f} |  Acc_E {:.6f} |  Acc_T {:.6f}'.format(
        train_acc.result(),
        train_acc_start_pos.result(),
        train_acc_end_pos.result(),
        train_acc_ans_type.result()
    ))

    print('\nTime taken for 1 epoch: {} secs\n'.format(epoch_elapsed_time))
    print("-" * 80 + "\n")

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Epoch 1 | Batch 1 | Elapsed Time 108.805175
Loss 4.870275 | Loss_S 6.320004 | Loss_E 6.687675 | Loss_T 1.603146
 Acc 0.000000 |  Acc_S 0.000000 |  Acc_E 0.000000 |  Acc_T 0.000000
----------------------------------------------------------------------------------------------------
Epoch 1 | Batch 2 | Elapsed Time 62.339668
Loss 4.520985 | Loss_S 5.797989 | Loss_E 6.140835 | Loss_T 1.624130
 Acc 0.041667 |  Acc_S 0.125000 |  Acc_E 0.000000 |  Acc_T 0.000000
----------------------------------------------------------------------------------------------------
Epoch 1 | Batch 3 | Elapsed Time 65.431683
Loss 4.309756 | Loss_S 5.542434 | Loss_E 5.868500 | Loss_T 1.518332
 Acc 0.138889 |  Acc_S 0.166667 |  Acc_E 0.000000 |  Acc_T 0.250000
----------------------------------------------------------------------------------------------------
Epoch 1 | Batch 4 | Elapsed Time 82.820423
Loss 3.970361 | Loss_S 5.062921 | Loss_E 5.401593 | Loss_T 1.446568
 Acc 0.166667 |  Acc_S 0.187500 |  Acc_E 0.00000