## Load Data

We have already downloaded the open-source Empathetic-dialogue dataset and stored it as .npy files in ./empathetic-dialogue. Importing prepare_data_seq from ./utils/data_loader.py gives us the training, validation, and test data loaders, the vocabulary, and program_number (passed to the model as decoder_number). Sample data is shown below:

In [1]:
from utils.data_loader import prepare_data_seq
data_loader_tra, data_loader_val, data_loader_tst, vocab, program_number = prepare_data_seq(batch_size=16)

                                      Opts                                      
--------------------------------------------------------------------------------
                                dataset: empathetic                             
                             hidden_dim: 100                                    
                                emb_dim: 100                                    
                             batch_size: 16                                     
                                     lr: 0.0001                                 
                          max_grad_norm: 2.0                                    
                              beam_size: 5                                      
                                  model: experts                                
                        act_loss_weight: 0.001                                  
                                    hop: 6                                      
                            

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/wangshihang/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Building dataset...
Saved PICKLE
LOADING empathetic_dialogue
[situation]: i remember going to the fireworks with my best friend . there was a lot of people , but it only felt like us in the world .
[emotion]: sentimental
[context]: ['i remember going to see the fireworks with my best friend . it was the first time we ever spent time alone together . although there was a lot of people , we felt like the only people in the world .']
[target]: was this a friend you were in love with , or just a best friend ?
 
[situation]: i remember going to the fireworks with my best friend . there was a lot of people , but it only felt like us in the world .
[emotion]: sentimental
[context]: ['i remember going to see the fireworks with my best friend . it was the first time we ever spent time alone together . although there was a lot of people , we felt like the only people in the world .', 'was this a friend you were in love with , or just a best friend ?', 'this was a best friend . i miss her .']
[ta
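The two samples above come from the same conversation. The context of each example holds every earlier turn, the target is the next listener turn, and the situation and emotion are shared across the conversation. A minimal sketch of that expansion, assuming the turns strictly alternate speaker/listener starting with the speaker; `make_examples` is a hypothetical helper for illustration, not the code in utils/data_loader.py:

```python
def make_examples(turns):
    """Expand one dialogue into (context, target) pairs.

    Assumes `turns` alternates speaker, listener, speaker, ... so listener
    responses sit at odd indices; each one becomes a target, and all turns
    before it form its context.
    """
    examples = []
    for i in range(1, len(turns), 2):
        examples.append((turns[:i], turns[i]))
    return examples


dialogue = ["speaker 1", "listener 1", "speaker 2", "listener 2"]
for context, target in make_examples(dialogue):
    print(context, "->", target)
```

Running it pairs `['speaker 1']` with `listener 1` and `['speaker 1', 'listener 1', 'speaker 2']` with `listener 2`, the same layout as the two samples printed above.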

## Model Initialization
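Next, we initialize the model. The class is selected by config.model: "experts" builds Transformer_experts (the multi-expert model in model/transformer_mulexpert.py), while "trs" builds the Transformer baseline.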

In [2]:
from utils import config
from model.transformer_mulexpert import Transformer_experts
from model.common_layer import make_infinite, get_input_from_batch, get_output_from_batch
import tensorflow as tf
from copy import deepcopy
from tqdm import tqdm
import os
import math
import time
import datetime
import numpy as np 

np.random.seed(0)

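best_ppl = 1000    # best validation perplexity seen so far
check_iter = 2000  # evaluate on the validation set every check_iter iterations
patient = 0        # consecutive validation checks without improvement (patience counter)

data_iter = make_infinite(data_loader_tra)  # endless iterator over training batches
save_model_path = './saved_model/moel'

if config.model == "trs":
    # NOTE: the baseline `Transformer` class lives in ./model and must be imported
    # before this branch can run; this cell does not import it.
    model = Transformer(vocab, decoder_number=program_number)
elif config.model == "experts":
    model = Transformer_experts(vocab, decoder_number=program_number)
print("MODEL USED:", config.model)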

MODEL USED: experts
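`make_infinite` is imported from model/common_layer.py, and its body is not shown here. The training loop relies on it to return a batch on every `next(data_iter)` call, however many iterations run. Assuming it simply starts a new pass over the loader whenever the previous one ends, a minimal equivalent is a generator like the hypothetical `make_infinite_sketch` below:

```python
def make_infinite_sketch(loader):
    """Yield batches from `loader` forever, restarting it after each full pass.

    `loader` must be re-iterable (a list, a tf.data.Dataset in eager mode, ...)
    and non-empty; an empty loader would make this loop spin without yielding.
    """
    while True:
        for batch in loader:
            yield batch


it = make_infinite_sketch([1, 2, 3])
print([next(it) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]
```

Because the generator calls `iter(loader)` afresh at the start of every pass, a loader that reshuffles on each iteration would also be reshuffled between passes.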


## Define Training Step

This is a multi-task learning problem. The encoder (1) classifies the input context into one of the 32 emotion labels, and (2) produces a context representation that the decoders use to generate the response, which is itself a classification over `vocab_size` classes at every token position. The loss is therefore the sum of two **cross-entropy losses**, one per task. Because both targets are given as integer class indices (not one-hot vectors) and the model outputs unnormalized logits, we use `SparseCategoricalCrossentropy(from_logits=True)`.
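As a plain-NumPy sketch of what the two loss terms compute: with its default reduction, `SparseCategoricalCrossentropy(from_logits=True)` returns the batch mean of -log softmax(logits)[label]. The function name and the toy shapes below (batch of 2, response length 3, vocabulary of 5, 32 emotions) are illustrative assumptions; the reshape mirrors how `train_step` flattens the decoder output so that every token position becomes one classification example.

```python
import numpy as np


def sparse_ce_from_logits(labels, logits):
    """Mean of -log softmax(logits)[label] over the rows of `logits`."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # numerically stable log-softmax: subtract the row max before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


# hypothetical toy tensors
rng = np.random.default_rng(0)
dec_targets = rng.integers(0, 5, size=(2, 3))  # gold response token ids
dec_logits = rng.normal(size=(2, 3, 5))        # decoder logits over the vocabulary
emo_labels = np.array([4, 17])                 # gold emotion ids
emo_logits = rng.normal(size=(2, 32))          # emotion-classifier logits

# response generation: flatten (batch, length) into one axis of token examples
gen_loss = sparse_ce_from_logits(dec_targets.reshape(-1), dec_logits.reshape(-1, 5))
# emotion classification: one example per dialogue
emo_loss = sparse_ce_from_logits(emo_labels, emo_logits)
total_loss = gen_loss + emo_loss
print(gen_loss, emo_loss, total_loss)
```

With uniform logits the loss equals ln K for K classes, which is a quick sanity check: ln 32 ≈ 3.47 for an emotion classifier that has learned nothing.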

In [3]:
uniq_cfg_name = datetime.datetime.now().strftime("%Y")
checkpoint_prefix = os.path.join(os.getcwd(), "checkpoints")
if not os.path.exists(checkpoint_prefix):#create a path to save the model checkpoint
    print("create model dir: %s" % checkpoint_prefix)
    os.mkdir(checkpoint_prefix)

checkpoint_path = os.path.join(checkpoint_prefix, uniq_cfg_name)
# save_weights() without an .h5 suffix writes TensorFlow-format files (<path>.index, <path>.data-*),
# so check for the .index file rather than the bare prefix
if os.path.exists(checkpoint_path + ".index"):  # if a checkpoint exists, restore it
    model.load_weights(checkpoint_path)
    print("load weight from: %s" % checkpoint_path)

criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)  # integer labels, unnormalized logits
optimizer = tf.keras.optimizers.Adam(learning_rate=config.lr)  # optimizer for the model parameters

training_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)  # running mean of the training loss
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')  # emotion accuracy (train)
testing_loss = tf.keras.metrics.Mean('validation_loss', dtype=tf.float32)  # running mean of the validation loss
testing_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('validation_accuracy')  # emotion accuracy (validation)

def compute_loss(input_x, dec_batch, logit, logit_prob):
    # response generation: token-level cross entropy over the vocabulary
    gen_loss = criterion(tf.reshape(dec_batch, [-1]), tf.reshape(logit, [-1, logit.shape[-1]]))
    # emotion classification: input_x[6] holds the emotion label of each example
    emo_loss = criterion(tf.cast(input_x[6], dtype=tf.int32), logit_prob)
    return gen_loss + emo_loss

def train_step(input_x, training=True):
    dec_batch, _, _, _, _ = get_output_from_batch(input_x)  # ground-truth response tokens
    if training:
        with tf.GradientTape() as tape:
            # forward pass: logit = response-generation logits, logit_prob = emotion logits
            logit, logit_prob = model(input_x, training)
            train_loss = compute_loss(input_x, dec_batch, logit, logit_prob)
        gradients = tape.gradient(train_loss, model.trainable_variables)  # compute gradients
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))  # update parameters
        training_loss.update_state(train_loss)
        training_accuracy.update_state(input_x[6], logit_prob)
    else:
        logit, logit_prob = model(input_x, training)
        test_loss = compute_loss(input_x, dec_batch, logit, logit_prob)
        testing_loss.update_state(test_loss)
        testing_accuracy.update_state(input_x[6], logit_prob)

create model dir: /Users/wangshihang/Documents/Columbia/E4040_Neural_Network/e4040-2021Fall-Project-MOEL-sl4640-sw3275-tz2372/checkpoints
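The options printed at the start list `max_grad_norm: 2.0`, but `train_step` as written applies the raw gradients without clipping. If clipping is intended, TensorFlow's `tf.clip_by_global_norm(gradients, clip_norm)` rescales all gradients jointly by `clip_norm / max(global_norm, clip_norm)`. The NumPy sketch below reproduces that rule so it can be checked in isolation; `clip_by_global_norm_np` is a hypothetical name.

```python
import numpy as np


def clip_by_global_norm_np(grads, clip_norm):
    """NumPy version of the tf.clip_by_global_norm rule.

    Every array is multiplied by clip_norm / max(global_norm, clip_norm), so the
    joint L2 norm of the list is at most clip_norm and gradient directions are kept.
    Returns (clipped_grads, global_norm).
    """
    global_norm = float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm


grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]  # joint L2 norm = 5
clipped, norm = clip_by_global_norm_np(grads, clip_norm=2.0)  # 2.0 = max_grad_norm
print(norm, clipped[0], clipped[1].ravel())
```

In `train_step`, the TensorFlow equivalent would sit between `tape.gradient` and `optimizer.apply_gradients`: `gradients, _ = tf.clip_by_global_norm(gradients, config.max_grad_norm)`. The training results reported below were produced without it.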


## Model Training

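We now train the model by calling `train_step` in a loop. Every `check_iter` (2000) iterations the loop prints the training metrics, evaluates on the validation set, and, once past the first 10000 iterations for the experts model, saves the weights whenever the validation perplexity improves. Training stops early after four consecutive checks without improvement; the run below was instead interrupted manually at iteration 22004.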

In [4]:
try:
    training_loss.reset_states()
    training_accuracy.reset_states()
    # tqdm tracks training progress; training stops early once patient > 3 (four checks in a row without improvement)
    for n_iter in tqdm(range(1000000)):
        train_step(next(data_iter), training = True)        
        if((n_iter+1)%check_iter==0):#print the cur loss and accuracy after <check_iter> of runs
            print("[train iter %d] [%s]: %0.3f [%s]: %0.3f  [%s]: %0.3f" %  (n_iter+1, "loss", training_loss.result(), "ppl", math.exp(training_loss.result()), "emo_acc", training_accuracy.result()))
            training_loss.reset_states()
            training_accuracy.reset_states()

            testing_loss.reset_states()
            testing_accuracy.reset_states()
            pbar = tqdm(enumerate(data_loader_val), total=179)# valid size: 5734
            #validate the model performance on validation set
            for j, test_batch in pbar:
                train_step(test_batch, training = False)
            print("[test iter %d] [%s]: %0.3f [%s]: %0.3f  [%s]: %0.3f" %  (n_iter+1, "loss", testing_loss.result(), "ppl", math.exp(testing_loss.result()), "emo_acc", testing_accuracy.result()))
            print('-' * 89)
            
            # for the experts model, skip checkpointing and early stopping during the first 10000 iterations
            if (config.model == "experts" and n_iter<10000):
                continue
            if math.exp(testing_loss.result()) < best_ppl:  # new best validation ppl: save weights and reset patience
                model.save_weights(save_model_path, overwrite=True)
                patient = 0
                best_ppl = math.exp(testing_loss.result())
            else:  # no improvement on this check
                patient += 1
            if patient > 3:  # stop after four consecutive checks without improvement
                break
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

  0%|          | 1999/1000000 [1:36:34<762:01:50,  2.75s/it] 

[train iter 2000] [loss]: 6.651 [ppl]: 773.367  [emo_acc]: 0.047

[test iter 2000] [loss]: 5.426 [ppl]: 227.278  [emo_acc]: 0.050
-----------------------------------------------------------------------------------------


  0%|          | 3999/1000000 [3:15:56<895:15:49,  3.24s/it]  

[train iter 4000] [loss]: 5.193 [ppl]: 180.038  [emo_acc]: 0.055

[test iter 4000] [loss]: 5.208 [ppl]: 182.784  [emo_acc]: 0.050
-----------------------------------------------------------------------------------------


  1%|          | 5999/1000000 [4:55:43<699:41:21,  2.53s/it]  

[train iter 6000] [loss]: 5.049 [ppl]: 155.863  [emo_acc]: 0.059

[test iter 6000] [loss]: 5.146 [ppl]: 171.762  [emo_acc]: 0.045
-----------------------------------------------------------------------------------------


  1%|          | 7999/1000000 [6:41:22<854:55:54,  3.10s/it]  

[train iter 8000] [loss]: 4.964 [ppl]: 143.137  [emo_acc]: 0.062

[test iter 8000] [loss]: 5.111 [ppl]: 165.838  [emo_acc]: 0.051
-----------------------------------------------------------------------------------------


  1%|          | 9999/1000000 [8:20:41<760:20:06,  2.76s/it]  

[train iter 10000] [loss]: 4.899 [ppl]: 134.127  [emo_acc]: 0.061

[test iter 10000] [loss]: 5.085 [ppl]: 161.511  [emo_acc]: 0.053
-----------------------------------------------------------------------------------------


  1%|          | 11999/1000000 [10:03:01<826:23:35,  3.01s/it] 

[train iter 12000] [loss]: 4.858 [ppl]: 128.800  [emo_acc]: 0.068

[test iter 12000] [loss]: 5.061 [ppl]: 157.810  [emo_acc]: 0.055
-----------------------------------------------------------------------------------------


  1%|▏         | 13999/1000000 [11:44:22<970:26:10,  3.54s/it]  

[train iter 14000] [loss]: 4.813 [ppl]: 123.081  [emo_acc]: 0.066

[test iter 14000] [loss]: 5.061 [ppl]: 157.726  [emo_acc]: 0.056
-----------------------------------------------------------------------------------------


  2%|▏         | 15999/1000000 [13:26:12<739:03:54,  2.70s/it]  

[train iter 16000] [loss]: 4.784 [ppl]: 119.620  [emo_acc]: 0.071

[test iter 16000] [loss]: 5.053 [ppl]: 156.430  [emo_acc]: 0.054
-----------------------------------------------------------------------------------------


  2%|▏         | 17999/1000000 [15:07:07<802:59:09,  2.94s/it]  

[train iter 18000] [loss]: 4.756 [ppl]: 116.334  [emo_acc]: 0.071

[test iter 18000] [loss]: 5.060 [ppl]: 157.567  [emo_acc]: 0.057
-----------------------------------------------------------------------------------------


  2%|▏         | 19999/1000000 [16:48:17<754:00:37,  2.77s/it]  

[train iter 20000] [loss]: 4.729 [ppl]: 113.208  [emo_acc]: 0.070

[test iter 20000] [loss]: 5.061 [ppl]: 157.779  [emo_acc]: 0.059
-----------------------------------------------------------------------------------------


  2%|▏         | 21999/1000000 [18:33:50<843:20:11,  3.10s/it]  

[train iter 22000] [loss]: 4.697 [ppl]: 109.580  [emo_acc]: 0.076

[test iter 22000] [loss]: 5.060 [ppl]: 157.655  [emo_acc]: 0.055
-----------------------------------------------------------------------------------------


  2%|▏         | 22004/1000000 [18:38:14<828:21:52,  3.05s/it]  

-----------------------------------------------------------------------------------------
Exiting from training early
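The checkpointing and early-stopping rule of the loop above can be replayed offline on the validation ppl values it printed. The sketch below (hypothetical function name) applies the same warm-up, which the loop uses only when config.model == "experts" (pass `warmup=0` otherwise), and the same `patient > 3` stopping test. On the logged run it shows that the weights in save_model_path come from iteration 16000 and that the patience counter had reached 3, one short of the stopping condition, when training was interrupted.

```python
def replay_early_stopping(val_ppl, warmup=10000, max_patience=3, best_ppl=1000.0):
    """Replay the loop's save/early-stop rule on (iteration, validation_ppl) pairs.

    Returns (best_iter, best_ppl, patience, stopped_at), where stopped_at is None
    if the stopping condition was never met.
    """
    best_iter, patience = None, 0
    for it, ppl in val_ppl:
        if it - 1 < warmup:  # the loop skips checks while n_iter < 10000, with n_iter = it - 1
            continue
        if ppl < best_ppl:
            best_iter, best_ppl, patience = it, ppl, 0  # the loop saves weights here
        else:
            patience += 1
        if patience > max_patience:
            return best_iter, best_ppl, patience, it
    return best_iter, best_ppl, patience, None


# validation ppl values copied from the log above
logged = [(2000, 227.278), (4000, 182.784), (6000, 171.762), (8000, 165.838),
          (10000, 161.511), (12000, 157.810), (14000, 157.726), (16000, 156.430),
          (18000, 157.567), (20000, 157.779), (22000, 157.655)]
print(replay_early_stopping(logged))  # (16000, 156.43, 3, None)
```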





## Model Testing

Finally, we rebuild the model, load the weights saved at the best validation checkpoint (save_model_path), and evaluate on the test set:
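`SparseCategoricalAccuracy` compares the argmax of `logit_prob` with the integer emotion label, so it can be fed logits directly, and the reported emo_acc is top-1 accuracy over the 32 emotions. Guessing uniformly at random would score 1/32 ≈ 0.031 in expectation. A NumPy sketch of that rule, with a hypothetical function name and toy logits:

```python
import numpy as np


def sparse_top1_accuracy(labels, logits):
    """Fraction of rows whose argmax matches the integer label."""
    preds = np.argmax(np.asarray(logits), axis=-1)
    return float(np.mean(preds == np.asarray(labels)))


emo_logits = np.array([[0.1, 2.0, -1.0],
                       [3.0, 0.0,  0.5],
                       [0.0, 0.2,  0.1]])
print(sparse_top1_accuracy([1, 2, 1], emo_logits))  # 2 of 3 rows correct
```

Since softmax preserves the order of the logits, the argmax is the same whether the metric sees logits or probabilities.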

In [5]:
# Initialize model
if config.model == "trs":
    # NOTE: requires the baseline `Transformer` class from ./model to be imported first
    model = Transformer(vocab, decoder_number=program_number)
elif config.model == "experts":
    model = Transformer_experts(vocab, decoder_number=program_number)
# load the best weights saved during training
model.load_weights(save_model_path)
print("Testing MODEL:", config.model)

# metrics for evaluation: mean cross-entropy loss and emotion-classification accuracy
testing_loss = tf.keras.metrics.Mean('validation_loss', dtype=tf.float32)
testing_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('validation_accuracy')
criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# tqdm tracks testing progress
pbar = tqdm(enumerate(data_loader_tst), total=164)  # test size: 5255
try:
    for j, test_batch in pbar:
        dec_batch, _, _, _, _ = get_output_from_batch(test_batch)  # ground-truth response tokens
        logit, logit_prob = model(test_batch, training=False)
        test_loss = criterion(tf.reshape(dec_batch, [-1]), tf.reshape(logit, [-1, logit.shape[-1]])) + criterion(tf.cast(test_batch[6], dtype=tf.int32), logit_prob)
        testing_loss.update_state(test_loss)
        testing_accuracy.update_state(test_batch[6], logit_prob)
    print("[test metrics] [%s]: %0.3f [%s]: %0.3f  [%s]: %0.3f" %  ("loss", testing_loss.result(), "ppl", math.exp(testing_loss.result()), "emo_acc", testing_accuracy.result()))
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from testing early')

Testing MODEL: experts


100%|██████████| 164/164 [03:35<00:00,  1.32s/it]

[test metrics] [loss]: 4.990 [ppl]: 146.904  [emo_acc]: 0.051
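The ppl values printed in this notebook are `math.exp` of the combined loss, so they fold in the emotion-classification term: exp(gen + emo) = exp(gen) * exp(emo). They are therefore not the token-level response perplexity on its own, and if `dec_batch` contains padding, those positions also enter the token term because the criterion is applied without a mask. The helper below (hypothetical name, toy numbers) makes the decomposition explicit; reporting the two parts separately in the notebook would need one `tf.keras.metrics.Mean` per term.

```python
import math


def split_perplexity(gen_loss, emo_loss):
    """Split exp(gen_loss + emo_loss) into its two multiplicative factors.

    Returns (token_ppl, emotion_factor, combined_ppl), where
    combined_ppl == token_ppl * emotion_factor up to floating point.
    """
    token_ppl = math.exp(gen_loss)
    emotion_factor = math.exp(emo_loss)
    return token_ppl, emotion_factor, math.exp(gen_loss + emo_loss)


# hypothetical values: a token loss of 2.0 and an emotion classifier whose
# logits are uniform over the 32 labels (cross entropy = ln 32)
tok, emo, combined = split_perplexity(gen_loss=2.0, emo_loss=math.log(32))
print(round(tok, 2), round(emo, 2), round(combined, 2))  # 7.39 32.0 236.45
```

An emotion head at chance level alone multiplies the reported ppl by about 32, so these numbers should not be compared directly with token-level perplexities.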





## Baseline Model

We also implemented a Transformer baseline in ./model. To train and evaluate it, set **config.model** in ./utils/config.py to "trs" and rerun the cells above (the baseline `Transformer` class must be imported as well).
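Both model-construction cells pick the class with an `if/elif` on config.model, so any other value leaves `model` undefined, or silently reuses a model from an earlier cell. A hypothetical dictionary-based helper makes the switch explicit and fails with a clear message. The registry contents in the comment are an assumption about how it would be wired up; the runnable part uses a stand-in class so it works without TensorFlow.

```python
def build_model(name, registry, *args, **kwargs):
    """Instantiate registry[name](*args, **kwargs), failing loudly on unknown names."""
    if name not in registry:
        raise ValueError(
            f"unknown config.model {name!r}; expected one of {sorted(registry)}"
        )
    return registry[name](*args, **kwargs)


# In the notebook this might look like (Transformer must be imported first):
#   registry = {"trs": Transformer, "experts": Transformer_experts}
#   model = build_model(config.model, registry, vocab, decoder_number=program_number)

class DummyModel:  # stand-in class so this example runs on its own
    def __init__(self, vocab, decoder_number):
        self.vocab = vocab
        self.decoder_number = decoder_number


m = build_model("experts", {"experts": DummyModel}, ["<pad>"], decoder_number=32)
print(m.decoder_number)  # 32
```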