In [1]:
import torch
import argparse
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from argparse import ArgumentParser
import yaml
import os

from transformer.Models import Transformer
from transformer.Translator import Translator
from utils import same_seeds

In [4]:
# parse argument
parser = ArgumentParser()
parser.add_argument("--config_path", dest="config_path",
                    default='../configs/dpng_transformer_bert_tokenizer_bow_indivtopk.yaml')
parser.add_argument("--seed", dest="seed", default=0, type=int)

args = parser.parse_args()
config_path = args.config_path
seed = args.seed
print("config_path:", config_path)
print("seed: ", seed)

usage: ipykernel_launcher.py [-h] [--config_path CONFIG_PATH]
ipykernel_launcher.py: error: unrecognized arguments: -f /shared_home/r08922168/.local/share/jupyter/runtime/kernel-dedc5aa2-89a1-4bfe-8719-a074b031da8f.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# fix seed
same_seeds(seed)

In [5]:
##### Read Arguments from Config File #####

# config_path = '../configs/base_transformer.yaml'
# config_path = '../configs/dpng_transformer.yaml'
# config_path = '../configs/dpng_transformer_bert_tokenizer.yaml'
# config_path = '../configs/dpng_transformer_bert_tokenizer_bow.yaml'
# config_path = '../configs/dpng_transformer_bert_tokenizer_bow_indivtopk.yaml'
# config_path = '../configs/dpng_transformer_bert_tokenizer_bow_indivtopk_onlybow.yaml'
# config_path = '../configs/dpng_transformer_wordnet.yaml'

with open(config_path) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
    print(config)

save_model_path = config['save_model_path']
output_file = config['test_output_file']
use_dataset = config['dataset']

batch_size = 1

d_model = config['d_model']
d_inner_hid = config['d_inner_hid']
d_k = config['d_k']
d_v = config['d_v']

n_head = config['n_head']
n_layers = config['n_layers']
n_warmup_steps = config['n_warmup_steps']

dropout = config['dropout']
embs_share_weight = config['embs_share_weight']
proj_share_weight = config['proj_share_weight']
label_smoothing = config['label_smoothing']

train_size = config['train_size']
val_size = config['val_size']
test_size = config['test_size']

beam_size = 3
max_seq_len = 30

try:
    is_bow = config['is_bow']

    if is_bow:
        bow_strategy = config['bow_strategy']
        topk = config['topk']
        only_bow = config['only_bow']
        if bow_strategy != 'simple_sum':
            indiv_topk = config['indiv_topk']
        else:
            # not used but use default value for simplicity
            indiv_topk = 50

except KeyError:
    is_bow = False
    
try:
    use_wordnet = config['use_wordnet']
    indiv_k = config['indiv_k']
    replace_origin = config['replace_origin']
    append_bow = config['append_bow']
except KeyError:
    use_wordnet = False
    
# ###################

{'save_model_path': '../models/DNPG_base_transformer_wordnet.pth', 'log_file': '../logs/DNPG_base_transformer_wordnet_training.txt', 'test_output_file': '../outputs/test_DNPG_base_transformer_wordnet.txt', 'val_output_file': '../outputs/val_DNPG_base_transformer_wordnet.txt', 'dataset': 'quora_wordnet_dataset', 'num_epochs': 50, 'batch_size': 128, 'd_model': 450, 'd_inner_hid': 512, 'd_k': 50, 'd_v': 50, 'n_head': 9, 'n_layers': 3, 'n_warmup_steps': 12000, 'dropout': 0.1, 'embs_share_weight': True, 'proj_share_weight': True, 'label_smoothing': False, 'train_size': 100000, 'val_size': 4000, 'test_size': 20000, 'is_bow': False, 'lr': '1e-3', 'use_wordnet': True, 'indiv_k': 5, 'replace_origin': False}


In [None]:
# set model and log path
seed_model_root = '../models/fixseed/seed{}/'.format(seed)
seed_output_root = '../outputs/fixseed/seed{}/'.format(seed)

if not os.path.exists(seed_model_root):
    os.makedirs(seed_model_root)

if not os.path.exists(seed_output_root):
    os.makedirs(seed_output_root)

save_model_path = seed_model_root + save_model_path.split('/')[-1]
output_file = seed_output_root + output_file.split('/')[-1]
print('seed: ', seed)
print('save model path: ', save_model_path)
print('output file: ', output_file)

In [3]:
# ############### Arguments ###############
# # The argument is same for DNPG paper
# save_model_path = '../models/DNPG_base_transformer.pth'
# output_file = '../outputs/test_DNPG_transformer_out.txt'
# max_seq_len = 30

# batch_size = 1
# beam_size = 3 # 1 for greedy

# d_model = 450
# d_inner_hid = 512
# d_k = 64
# d_v = 64

# n_head = 9
# n_layers = 3
# dropout = 0.1
# embs_share_weight = True
# proj_share_weight = True

# train_size = 100000
# val_size = 4000
# test_size = 20000
# #######################################

In [4]:
# ##### Arguments #####
# d_model = 512
# save_model_path = '../models/base_transformer.pth'
# batch_size = 1
# beam_size = 3 # 1 for greedy
# max_seq_len = 30
# output_file = '../outputs/val_transformer_out.txt'

# d_model = 512
# d_inner_hid = 512
# d_k = 64
# d_v = 64

# n_head = 8
# n_layers = 6
# dropout = 0.1
# embs_share_weight = True
# proj_share_weight = True
# ###################

In [6]:
# load dataset
from datasets.quora_text_dataset import QuoraTextDataset

if use_dataset == 'quora_dataset':
    from datasets.quora_dataset import QuoraDataset as Dataset
elif use_dataset == 'quora_wordnet_dataset':
    from datasets.quora_wordnet_dataset import QuoraWordnetDataset as Dataset
elif use_dataset == 'quora_wordnet_aug_dataset':
    from datasets.quora_wordnet_aug_dataset import QuoraWordnetAugDataset as Dataset        
else:
    raise NotImplementedError("Dataset is not defined or not implemented: {}".format(use_dataset))

In [7]:
def create_mini_batch(samples):
    seq1_tensors = [s[0] for s in samples]
    seq2_tensors = [s[1] for s in samples]

    # zero pad
    seq1_tensors = pad_sequence(seq1_tensors,
                                  batch_first=True)

    seq2_tensors = pad_sequence(seq2_tensors,
                                  batch_first=True)    
    
    return seq1_tensors, seq2_tensors

if use_wordnet:
    dataset = Dataset("test", train_size, val_size, test_size, indiv_k=indiv_k, replace_origin=False, append_bow=append_bow)
else:
    dataset = Dataset("test", train_size, val_size, test_size)
same_seeds(seed)    
text_dataset = QuoraTextDataset("test", train_size, val_size, test_size)
data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=create_mini_batch, shuffle=False)

100%|██████████| 20000/20000 [00:00<00:00, 86411.37it/s]

[Info] Loading the Dictionary...
[Info] Dictionary Loaded





In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transformer = Transformer(
    dataset.n_words,
    dataset.n_words,
    src_pad_idx=dataset.PAD_token_id,
    trg_pad_idx=dataset.PAD_token_id,
    trg_emb_prj_weight_sharing=proj_share_weight,
    emb_src_trg_weight_sharing=embs_share_weight,
    d_k=d_k,
    d_v=d_v,
    d_model=d_model,
    d_word_vec=d_model,
    d_inner=d_inner_hid,
    n_layers=n_layers,
    n_head=n_head,
    dropout=dropout    
)

model = transformer.to(device)

model.load_state_dict((torch.load(
        save_model_path, map_location=device)))

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 10.91 GiB total capacity; 88.33 MiB already allocated; 3.56 MiB free; 94.00 MiB reserved in total by PyTorch)

In [None]:
src_pad_idx = dataset.PAD_token_id
trg_pad_idx = dataset.PAD_token_id
    
trg_bos_idx = dataset.SOS_token_id
trg_eos_idx = dataset.EOS_token_id
unk_idx = dataset.UNK_token_id

In [None]:
translator = Translator(
        model=model,
        beam_size=beam_size,
        max_seq_len=max_seq_len,
        src_pad_idx=src_pad_idx,
        trg_pad_idx=trg_pad_idx,
        trg_bos_idx=trg_bos_idx,
        trg_eos_idx=trg_eos_idx).to(device)

In [8]:
# todo: modify to  Source / Target / Input / Output
idx2word = dataset.idx2word
with open(output_file, 'w') as f:
    for i, (seq1, seq2) in enumerate(tqdm(data_loader)):
        input_seq = seq1.to(device)
        trg_seq = seq2[0].tolist()
        
        pred_seq = translator.translate_sentence(input_seq)
        pred_line = ' '.join(idx2word[idx] for idx in pred_seq)
        pred_line = pred_line.replace(dataset.SOS_token, '').replace(dataset.EOS_token, '')
            
        input_line = ' '.join(idx2word[idx] for idx in input_seq[0].tolist())
        input_line = input_line.replace(dataset.SOS_token, '').replace(dataset.EOS_token, '')

        src_line, trg_line = text_dataset.sentences[i]
        
#         print('*' * 80)
#         print('Source: ', src_line.strip())
#         print('Target: ', trg_line.strip())
#         print('Input: ', input_line.strip())
#         print('Predict: ', pred_line.strip())
        
        f.write('*' * 80)
        f.write('\n')
        f.write('Source: {}\n'.format(src_line.strip()))
        f.write('Target: {}\n'.format(trg_line.strip()))
        f.write('Input: {}\n'.format(input_line.strip()))
        f.write('Predict: {}\n'.format(pred_line.strip()))


  2%|▏         | 368/20000 [00:00<00:10, 1824.22it/s]

********************************************************************************
Source:  What is the food you can eat every day for breakfast lunch and dinner ?
Target:  Is it healthy to eat fish every day ?
Input:  What is the food you can eat every day for breakfast lunch and dinner ?
********************************************************************************
Source:  What are some of the best ways to write in exams ?
Target:  How do we write the exam ?
Input:  What are some of the best ways to write in exams ?
********************************************************************************
Source:  Which are best novels one should must read before die ?
Target:  What are the top novels I should read before I die ? And why ?
Input:  Which are best novels one should must read before die ?
********************************************************************************
Source:  What will be your New Year s resolution for ?
Target:  What is your resolution for ?
Input:  What will 

  4%|▎         | 742/20000 [00:00<00:10, 1846.79it/s]

********************************************************************************
Source:  What is the law of interaction ? What are some examples of it ?
Target:  What is an example of the law of interaction ?
Input:  What is the law of interaction ? What are some examples of it ?
********************************************************************************
Source:  How does a pen work ?
Target:  How do my pen work ? all pens I mean ?
Input:  How does a pen work ?
********************************************************************************
Source:  What is hydrogen bonding ?
Target:  What is hydrogen bond ?
Input:  What is hydrogen bonding ?
********************************************************************************
Source:  How can I lose kg weight ?
Target:  I m fat . How do I lose weight ?
Input:  How can I lose kg weight ?
********************************************************************************
Source:  Which country is most likely to start world war III ?
Targe

  6%|▌         | 1107/20000 [00:00<00:10, 1835.03it/s]

********************************************************************************
Source:  Which Is the best book on psychology ?
Target:  What are the best books on human psychology ?
Input:  Which Is the best book on psychology ?
********************************************************************************
Source:  How do I ask questions on here ?
Target:  How can I ask my question on Quora ?
Input:  How do I ask questions on here ?
********************************************************************************
Source:  Why is talking to girls online about my fetish easier ?
Target:  Why is talking about my fetish online easier ?
Input:  Why is talking to girls online about my fetish easier ?
********************************************************************************
Source:  Who voted for Trump ?
Target:  Who voted for Trump ? Where are these people ?
Input:  Who voted for Trump ?
********************************************************************************
Source:  What 

  7%|▋         | 1480/20000 [00:00<00:10, 1849.63it/s]

********************************************************************************
Source:  Who is the most popular person on Quora ?
Target:  Who is the most popular person answering questions on Quora ?
Input:  Who is the most popular person on Quora ?
********************************************************************************
Source:  Which are the top five biggest scams in India ?
Target:  What are biggest scams in India by fraud saints ?
Input:  Which are the top five biggest scams in India ?
********************************************************************************
Source:  How can I become Prime Minister of the UK ? What are the necessary steps I need to take ?
Target:  How can I become Prime Minister of the UK ?
Input:  How can I become Prime Minister of the UK ? What are the necessary steps I need to take ?
********************************************************************************
Source:  Can I see who viewed my Google profile ?
Target:  How can I see who viewe

  9%|▉         | 1854/20000 [00:01<00:09, 1858.87it/s]

********************************************************************************
Source:  What is a turbine ?
Target:  What is meant by turbine ?
Input:  What is a turbine ?
********************************************************************************
Source:  What mistakes you should avoid while optimizing the website ?
Target:  What mistakes you should avoid while optimizing website ?
Input:  What mistakes you should avoid while optimizing the website ?
********************************************************************************
Source:  What do guys find most attractive about a girl ?
Target:  What things do guys find attractive in girls ?
Input:  What do guys find most attractive about a girl ?
********************************************************************************
Source:  How should I improve my English speaking and writing skills ?
Target:  How can I improve my English writing skills ?
Input:  How should I improve my English speaking and writing skills ?
********

 10%|█         | 2035/20000 [00:01<00:09, 1801.53it/s]

********************************************************************************
Source:  How can I do self study effectively ?
Target:  How do you study effectively ?
Input:  How can I do self study effectively ?
********************************************************************************
Source:  What is the best way to store my photos digitally ?
Target:  What is the best way to save store and organize digital photos ?
Input:  What is the best way to store my photos digitally ?
********************************************************************************
Source:  Do people really believe you can sell your soul to the devil ?
Target:  Can we really sell our soul to a devil ?
Input:  Do people really believe you can sell your soul to the devil ?
********************************************************************************
Source:  How can you fix a Nook that is not charging ?
Target:  How can I fix my Nook if it s not charging ?
Input:  How can you fix a Nook that is not cha

 12%|█▏        | 2413/20000 [00:01<00:09, 1840.70it/s]

********************************************************************************
Source:  What is the song Hallelujah about ?
Target:  What is the Leonard Cohen song Hallelujah about ?
Input:  What is the song Hallelujah about ?
********************************************************************************
Source:  How do you do a startup ?
Target:  What should one do to do a startup ?
Input:  How do you do a startup ?
********************************************************************************
Source:  How can I improve my communication and verbal skills ?
Target:  How can I improve my communication skills specially pronunciation skill ?
Input:  How can I improve my communication and verbal skills ?
********************************************************************************
Source:  When will the Effiel Tower collapse ?
Target:  Will the Effiel Tower collapse ?
Input:  When will the Effiel Tower collapse ?
********************************************************************

 14%|█▍        | 2786/20000 [00:01<00:09, 1852.81it/s]

********************************************************************************
Source:  What are some ways to cheer up a depressed person ?
Target:  How do you cheer up a depressed person ?
Input:  What are some ways to cheer up a depressed person ?
********************************************************************************
Source:  What is the difference between a treaty and a convention ?
Target:  What is the difference between international treaty and convention ?
Input:  What is the difference between a treaty and a convention ?
********************************************************************************
Source:  Should I apply for pan card online or offline ?
Target:  How do I apply and receive pan card in hours with new facilities ?
Input:  Should I apply for pan card online or offline ?
********************************************************************************
Source:  Why do women moan during sex ?
Target:  Why do women moan while making love ?
Input:  Why do w

 16%|█▋        | 3274/20000 [00:01<00:09, 1844.03it/s]

********************************************************************************
Source:  When and by whom was the camera invented ?
Target:  When were cameras invented ? How were they invented ?
Input:  When and by whom was the camera invented ?
********************************************************************************
Source:  How do I stop people from messing with my question ?
Target:  How do I stop people from editing my question to a different one or reverse the changes ?
Input:  How do I stop people from messing with my question ?
********************************************************************************
Source:  Can Donald Trump still win the U .S . Presidential Election ?
Target:  Do you think Trump could win the presidency ?
Input:  Can Donald Trump still win the U .S . Presidential Election ?
********************************************************************************
Source:  Why do people ask questions on Quora that can easily be answered by Google ?
Target




KeyboardInterrupt: 

[28, 154, 76, 1703, 221, 365, 38, 1391, 1265, 22, 30, 2]