In [1]:
import os
import pickle
import argparse
import numpy as np
from model import Options, Seq2SeqAttn


In [2]:
# Command-line options, declared as (flag, type, default, help) rows.
# Row order is the order in which the options are registered on the parser.
def _build_parser():
    """Create the argument parser holding every data, model and checkpoint option."""
    option_table = [
        # Locations of the pre-processed data and auxiliary arrays.
        ('--data_path', str, '../pre-data/',
         'the directory to the data'),
        ('--word_embeddings_path', str, '../pre-data/word_embeddings.npy',
         'the directory to the pre-trained word embeddings'),
        ('--VAD_path', str, '../pre-data/VAD.npy',
         'the directory to VAD'),
        ('--tf_path', str, '../pre-data/tf.npy',
         'the directory to term frequency'),
        ('--VAD_loss_path', str, '../pre-data/VAD_loss.npy',
         'the directory to VAD loss for each word'),
        ('--ti_path', str, '../pre-data/mu_li.npy',
         'the directory to term importance'),
        # Optimisation and model-size hyper-parameters.
        ('--num_epochs', int, 5,
         'the number of epochs to train the data'),
        ('--batch_size', int, 64,
         'the batch size'),
        ('--learning_rate', float, 0.0001,
         'the learning rate'),
        ('--beam_width', int, 32,
         'the beam width when decoding'),
        ('--word_embed_size', int, 256,
         'the size of word embeddings'),
        ('--n_hidden_units_enc', int, 256,
         'the number of hidden units of encoder'),
        ('--n_hidden_units_dec', int, 256,
         'the number of hidden units of decoder'),
        ('--attn_depth', int, 128,
         'attention depth'),
        # Checkpoint locations for the P(T|S) model.
        ('--restore_path_TS', str, '../model_dailydialog_rf/model_TS',
         'the path to restore the trained model'),
        ('--save_path_TS', str, '../model_dailydialog_rf/model_TS',
         'the path to save the trained model to'),
        # Checkpoint locations for the P(S|T) model.
        ('--restore_path_ST', str, '../model_dailydialog_rf/model_ST',
         'the path to restore the trained model'),
        ('--save_path_ST', str, '../model_dailydialog_rf/model_ST',
         'the path to save the trained model to'),
        # Epoch whose checkpoint is restored; 0 means start from scratch.
        ('--restore_epoch', int, 0,
         'the epoch to restore'),
    ]

    built = argparse.ArgumentParser()
    for flag, value_type, default_value, description in option_table:
        built.add_argument(flag, type = value_type, default = default_value,
                           help = description)
    return built


parser = _build_parser()

# parse_known_args (rather than parse_args) keeps going when the process was
# started with arguments this parser does not declare; those are collected in
# `unknown` instead of aborting with a usage error.
args, unknown = parser.parse_known_args()


In [3]:
def read_data(data_path, batch_size = None):
    """Load the train/validation arrays and the vocabulary mappings.

    Each split is read from ``<data_path>/train`` and ``<data_path>/validation``
    and, when its sample count is not a multiple of the batch size, extended
    with randomly re-drawn samples so that the last batch is complete.

    Args:
        data_path: directory containing the ``train`` and ``validation``
            sub-directories plus ``token2id.pickle`` and ``id2token.pickle``.
        batch_size: batch size used to complete the last batch. Defaults to
            the module-level ``args.batch_size`` when omitted.

    Returns:
        ``(train_set, valid_set, token2id, id2token)`` where each set is a dict
        with the keys ``enc_input``, ``dec_input``, ``target``,
        ``enc_input_len`` and ``dec_input_len``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size is None:
        batch_size = args.batch_size
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got {}'.format(batch_size))

    array_names = ('enc_input', 'dec_input', 'target', 'enc_input_len', 'dec_input_len')

    def load_np_files(path):
        """Load one split and pad it to a whole number of batches."""
        my_set = {name: np.load(os.path.join(path, name + '.npy')) for name in array_names}

        # Number of samples that do not fill a complete batch.
        n_samples = my_set['dec_input'].shape[0]
        left_samples = n_samples % batch_size
        if left_samples:
            # Re-draw the missing samples from the complete batches; if the
            # split is smaller than one batch, draw from everything available.
            n_full = n_samples - left_samples
            draw_from = n_full if n_full > 0 else n_samples
            last_batch_idx = np.random.randint(0, draw_from, size = batch_size - left_samples)
            idx = np.concatenate([np.arange(n_samples), last_batch_idx])

            # The same index vector is applied to every array so rows stay aligned.
            my_set = {name: values[idx] for name, values in my_set.items()}
        return my_set

    train_set = load_np_files(os.path.join(data_path, 'train'))
    valid_set = load_np_files(os.path.join(data_path, 'validation'))

    # NOTE: pickle.load can execute arbitrary code; only use trusted data directories.
    with open(os.path.join(data_path, 'token2id.pickle'), 'rb') as file:
        token2id = pickle.load(file)
    with open(os.path.join(data_path, 'id2token.pickle'), 'rb') as file:
        id2token = pickle.load(file)

    return train_set, valid_set, token2id, id2token

---
Train model maximizing P(T|S)

In [None]:
if __name__ == '__main__':
    # Data splits and vocabulary mappings.
    train_set, valid_set, token2id, id2token = read_data(args.data_path)

    # The second axis of the input matrices is the (padded) utterance length.
    max_uttr_len_enc = train_set['enc_input'].shape[1]
    max_uttr_len_dec = train_set['dec_input'].shape[1]

    # Auxiliary arrays; the two per-word arrays are reshaped into column vectors.
    word_embeddings = np.load(args.word_embeddings_path)
    VAD = np.load(args.VAD_path)
    termfreq = np.load(args.ti_path).reshape(-1, 1)  # term importance
    VAD_loss = np.load(args.VAD_loss_path).reshape(-1, 1)

    options = Options(
        mode = 'TRAIN',
        num_epochs = args.num_epochs,
        batch_size = args.batch_size,
        learning_rate = args.learning_rate,
        beam_width = args.beam_width,
        corpus_size = len(token2id),
        max_uttr_len_enc = max_uttr_len_enc,
        max_uttr_len_dec = max_uttr_len_dec,
        go_index = token2id['<go>'],
        eos_index = token2id['<eos>'],
        word_embed_size = args.word_embed_size,
        n_hidden_units_enc = args.n_hidden_units_enc,
        n_hidden_units_dec = args.n_hidden_units_dec,
        attn_depth = args.attn_depth,
        word_embeddings = word_embeddings,
    )
    model_TS = Seq2SeqAttn(options)

    # List the variables exposed by the model as `tvars`.
    for var in model_TS.tvars:
        print(var.name)

    # Start from fresh variables unless a positive restore epoch was requested.
    if args.restore_epoch <= 0:
        model_TS.init_tf_vars()
    else:
        model_TS.restore(
            os.path.join(args.restore_path_TS,
                         'model_TS_epoch_{:03d}.ckpt'.format(args.restore_epoch)))

    model_TS.train(train_set, VAD, termfreq, VAD_loss,
                   args.save_path_TS, args.restore_epoch, valid_set)

Building the TensorFlow graph...
embedding/embedding:0
encoding/rnn/gru_cell/gates/kernel:0
encoding/rnn/gru_cell/gates/bias:0
encoding/rnn/gru_cell/candidate/kernel:0
encoding/rnn/gru_cell/candidate/bias:0
decoding/memory_layer/kernel:0
decoding/attention_v:0
decoding/my_bahdanau_attention/query_layer/kernel:0
decoding/my_bahdanau_attention/attention_Wb/kernel:0
decoding/attention_wrapper/gru_cell/gates/kernel:0
decoding/attention_wrapper/gru_cell/gates/bias:0
decoding/attention_wrapper/gru_cell/candidate/kernel:0
decoding/attention_wrapper/gru_cell/candidate/bias:0
decoding/dense/kernel:0
decoding/dense/bias:0
TensorFlow variables initialized.
Start to train the model...
Epoch 001/005, valid ppl = None, batch 0001/0226, train loss = 8.088951110839844
Epoch 001/005, valid ppl = None, batch 0002/0226, train loss = 8.083484649658203
Epoch 001/005, valid ppl = None, batch 0003/0226, train loss = 8.062005043029785
Epoch 001/005, valid ppl = None, batch 0004/0226, train loss = 8.0533742904

Epoch 001/005, valid ppl = None, batch 0094/0226, train loss = 5.35126256942749
Epoch 001/005, valid ppl = None, batch 0095/0226, train loss = 5.356368541717529
Epoch 001/005, valid ppl = None, batch 0096/0226, train loss = 5.20646858215332
Epoch 001/005, valid ppl = None, batch 0097/0226, train loss = 5.4400954246521
Epoch 001/005, valid ppl = None, batch 0098/0226, train loss = 5.313283920288086
Epoch 001/005, valid ppl = None, batch 0099/0226, train loss = 5.256808757781982
Epoch 001/005, valid ppl = None, batch 0100/0226, train loss = 5.280642032623291
Epoch 001/005, valid ppl = None, batch 0101/0226, train loss = 5.141192436218262
Epoch 001/005, valid ppl = None, batch 0102/0226, train loss = 5.347930431365967
Epoch 001/005, valid ppl = None, batch 0103/0226, train loss = 5.280808448791504
Epoch 001/005, valid ppl = None, batch 0104/0226, train loss = 5.230621337890625
Epoch 001/005, valid ppl = None, batch 0105/0226, train loss = 5.293386459350586
Epoch 001/005, valid ppl = None,

Epoch 001/005, valid ppl = None, batch 0196/0226, train loss = 4.8175225257873535
Epoch 001/005, valid ppl = None, batch 0197/0226, train loss = 4.835391044616699
Epoch 001/005, valid ppl = None, batch 0198/0226, train loss = 4.809834003448486
Epoch 001/005, valid ppl = None, batch 0199/0226, train loss = 4.967055797576904
Epoch 001/005, valid ppl = None, batch 0200/0226, train loss = 4.860891819000244
Epoch 001/005, valid ppl = None, batch 0201/0226, train loss = 5.029396057128906
Epoch 001/005, valid ppl = None, batch 0202/0226, train loss = 4.871562480926514
Epoch 001/005, valid ppl = None, batch 0203/0226, train loss = 4.971390247344971
Epoch 001/005, valid ppl = None, batch 0204/0226, train loss = 4.822817325592041
Epoch 001/005, valid ppl = None, batch 0205/0226, train loss = 4.6669087409973145
Epoch 001/005, valid ppl = None, batch 0206/0226, train loss = 4.762619495391846
Epoch 001/005, valid ppl = None, batch 0207/0226, train loss = 4.831693649291992
Epoch 001/005, valid ppl =

Epoch 002/005, valid ppl = 117.95818713098286, batch 0060/0226, train loss = 4.865610122680664
Epoch 002/005, valid ppl = 117.95818713098286, batch 0061/0226, train loss = 4.695831775665283
Epoch 002/005, valid ppl = 117.95818713098286, batch 0062/0226, train loss = 4.681426048278809
Epoch 002/005, valid ppl = 117.95818713098286, batch 0063/0226, train loss = 4.677123546600342
Epoch 002/005, valid ppl = 117.95818713098286, batch 0064/0226, train loss = 4.523500442504883
Epoch 002/005, valid ppl = 117.95818713098286, batch 0065/0226, train loss = 4.644879341125488
Epoch 002/005, valid ppl = 117.95818713098286, batch 0066/0226, train loss = 4.444182872772217
Epoch 002/005, valid ppl = 117.95818713098286, batch 0067/0226, train loss = 4.511797904968262
Epoch 002/005, valid ppl = 117.95818713098286, batch 0068/0226, train loss = 4.646875381469727
Epoch 002/005, valid ppl = 117.95818713098286, batch 0069/0226, train loss = 4.700347900390625
Epoch 002/005, valid ppl = 117.95818713098286, bat

Epoch 002/005, valid ppl = 117.95818713098286, batch 0147/0226, train loss = 4.598510265350342
Epoch 002/005, valid ppl = 117.95818713098286, batch 0148/0226, train loss = 4.59454870223999
Epoch 002/005, valid ppl = 117.95818713098286, batch 0149/0226, train loss = 4.453695297241211
Epoch 002/005, valid ppl = 117.95818713098286, batch 0150/0226, train loss = 4.5502119064331055
Epoch 002/005, valid ppl = 117.95818713098286, batch 0151/0226, train loss = 4.471029281616211
Epoch 002/005, valid ppl = 117.95818713098286, batch 0152/0226, train loss = 4.533938407897949
Epoch 002/005, valid ppl = 117.95818713098286, batch 0153/0226, train loss = 4.411929130554199
Epoch 002/005, valid ppl = 117.95818713098286, batch 0154/0226, train loss = 4.509737968444824
Epoch 002/005, valid ppl = 117.95818713098286, batch 0155/0226, train loss = 4.575893402099609
Epoch 002/005, valid ppl = 117.95818713098286, batch 0156/0226, train loss = 4.5748467445373535
Epoch 002/005, valid ppl = 117.95818713098286, ba

Epoch 003/005, valid ppl = 83.83633605258021, batch 0007/0226, train loss = 4.269045352935791
Epoch 003/005, valid ppl = 83.83633605258021, batch 0008/0226, train loss = 4.268401622772217
Epoch 003/005, valid ppl = 83.83633605258021, batch 0009/0226, train loss = 4.319323539733887
Epoch 003/005, valid ppl = 83.83633605258021, batch 0010/0226, train loss = 4.469695091247559
Epoch 003/005, valid ppl = 83.83633605258021, batch 0011/0226, train loss = 4.298798084259033
Epoch 003/005, valid ppl = 83.83633605258021, batch 0012/0226, train loss = 4.149795055389404
Epoch 003/005, valid ppl = 83.83633605258021, batch 0013/0226, train loss = 4.16184139251709
Epoch 003/005, valid ppl = 83.83633605258021, batch 0014/0226, train loss = 4.354247570037842
Epoch 003/005, valid ppl = 83.83633605258021, batch 0015/0226, train loss = 4.483072280883789
Epoch 003/005, valid ppl = 83.83633605258021, batch 0016/0226, train loss = 4.58226203918457
Epoch 003/005, valid ppl = 83.83633605258021, batch 0017/0226,

Epoch 003/005, valid ppl = 83.83633605258021, batch 0095/0226, train loss = 4.371263027191162
Epoch 003/005, valid ppl = 83.83633605258021, batch 0096/0226, train loss = 4.300204277038574
Epoch 003/005, valid ppl = 83.83633605258021, batch 0097/0226, train loss = 4.212101459503174
Epoch 003/005, valid ppl = 83.83633605258021, batch 0098/0226, train loss = 4.426673412322998
Epoch 003/005, valid ppl = 83.83633605258021, batch 0099/0226, train loss = 4.291828155517578
Epoch 003/005, valid ppl = 83.83633605258021, batch 0100/0226, train loss = 4.411757946014404
Epoch 003/005, valid ppl = 83.83633605258021, batch 0101/0226, train loss = 4.454925537109375
Epoch 003/005, valid ppl = 83.83633605258021, batch 0102/0226, train loss = 4.405818939208984
Epoch 003/005, valid ppl = 83.83633605258021, batch 0103/0226, train loss = 4.184712886810303
Epoch 003/005, valid ppl = 83.83633605258021, batch 0104/0226, train loss = 4.3431620597839355
Epoch 003/005, valid ppl = 83.83633605258021, batch 0105/02

---
Train model maximizing P(S|T)

In [None]:
def revert(myset, go_index = None, eos_index = None):
    """Swap the encoder and decoder sides of a data set.

    The old decoder input (minus its first column) becomes the new encoder
    input, and the old encoder input becomes the new decoder side: prefixed
    with ``<go>`` for the decoder input and terminated with ``<eos>`` for the
    target.

    Args:
        myset: dict with the keys ``enc_input``, ``dec_input``,
            ``enc_input_len`` and ``dec_input_len``; the two inputs are 2-D
            id arrays in which id 0 is treated as padding.
        go_index: id of the ``<go>`` token. Defaults to ``token2id['<go>']``.
        eos_index: id of the ``<eos>`` token. Defaults to ``token2id['<eos>']``.

    Returns:
        A new dict with the keys ``enc_input``, ``dec_input``, ``target``,
        ``enc_input_len`` and ``dec_input_len``.
    """
    if go_index is None:
        go_index = token2id['<go>']
    if eos_index is None:
        eos_index = token2id['<eos>']

    source = myset['enc_input']
    n_rows = source.shape[0]

    # Drop the first decoder column (presumably the <go> token - confirm).
    enc_input = myset['dec_input'][:, 1:]
    # add <go> in the beginning of decoder
    dec_input = np.insert(source, 0, go_index, axis = 1)

    # Append one padding column AFTER the last token. The previous
    # np.insert(..., -1, ...) put it before the last column, so <eos> ended up
    # in front of the final token of every full-length row.
    pad_column = np.zeros((n_rows, 1), dtype = source.dtype)
    target = np.concatenate([source, pad_column], axis = 1)
    # The first padding position of each row receives <eos>; the appended
    # column guarantees that every row has one.
    first_pad = (target == 0).argmax(axis = 1)
    target[np.arange(n_rows), first_pad] = eos_index

    newset = {}
    newset['enc_input'] = enc_input
    newset['dec_input'] = dec_input
    newset['target'] = target
    # NOTE(review): the lengths are swapped unchanged; if dec_input_len counts
    # the <go> token they may be off by one for the new roles - confirm.
    newset['enc_input_len'] = myset['dec_input_len']
    newset['dec_input_len'] = myset['enc_input_len']
    return newset

In [None]:
if __name__ == '__main__':
    # Load the data, then swap the encoder/decoder roles of both splits so the
    # model is trained in the reverse direction.
    train_set, valid_set, token2id,id2token = read_data(args.data_path)
    train_set = revert(train_set)
    valid_set = revert(valid_set)
    
    # The second axis of the input matrices is the (padded) utterance length.
    max_uttr_len_enc = train_set['enc_input'].shape[1]
    max_uttr_len_dec = train_set['dec_input'].shape[1]

    word_embeddings = np.load(args.word_embeddings_path)
    VAD = np.load(args.VAD_path)
    termfreq = np.load(args.ti_path) # term importance
    termfreq = termfreq.reshape(-1,1)
    VAD_loss = np.load(args.VAD_loss_path)
    VAD_loss = VAD_loss.reshape(-1,1)
    
    options = Options(mode = 'TRAIN',
                      num_epochs = args.num_epochs,
                      batch_size = args.batch_size,
                      learning_rate = args.learning_rate,
                      beam_width = args.beam_width,
                      corpus_size = len(token2id),
                      max_uttr_len_enc = max_uttr_len_enc,
                      max_uttr_len_dec = max_uttr_len_dec,
                      go_index = token2id['<go>'],
                      eos_index = token2id['<eos>'],
                      word_embed_size = args.word_embed_size,
                      n_hidden_units_enc = args.n_hidden_units_enc,
                      n_hidden_units_dec = args.n_hidden_units_dec,
                      attn_depth = args.attn_depth,
                      word_embeddings = word_embeddings)
    model_ST = Seq2SeqAttn(options)

    # List the variables of the model built in THIS cell. The previous code
    # iterated model_TS.tvars, which fails on a fresh kernel unless the P(T|S)
    # cell ran first and otherwise prints the other model's variables.
    for var in model_ST.tvars:
        print(var.name)

    if args.restore_epoch > 0:
        # NOTE(review): the checkpoint file name still says 'model_TS' although
        # it is read from restore_path_ST - confirm against the name that
        # Seq2SeqAttn.train() writes before changing it.
        model_ST.restore(os.path.join(args.restore_path_ST, 'model_TS_epoch_{:03d}.ckpt'.format(args.restore_epoch)))
    else:
        model_ST.init_tf_vars()
    model_ST.train(train_set, VAD,termfreq, VAD_loss,args.save_path_ST, args.restore_epoch, valid_set)