In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import torchfly
torchfly.set_random_seed(1)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import regex as re
import random
import itertools
import tqdm
import time
import json

from torch.utils.tensorboard import SummaryWriter
from apex import amp
from allennlp.training.checkpointer import Checkpointer
from pytorch_transformers import AdamW, WarmupLinearSchedule, GPT2Tokenizer

from torchfly.criterions import SequenceFocalLoss, SequenceCrossEntropyLoss
from torchfly.decode import top_filtering
from gpt_model import GPT2SimpleLM
from text_utils import recoverText, normalize

In [4]:
# set tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.sep_token = "None"
# add speicial tokens in the same order as Roberta
tokenizer.add_tokens(["<s>", "<pad>", "</s>", "<unk>", "<mask>"])

5

In [5]:
class GPT2SmallConfig:
    vocab_size = 50257 + len(tokenizer.added_tokens_encoder)
    n_special = len(tokenizer.added_tokens_encoder)
    n_positions = 1024
    n_ctx = 1024
    n_embd = 768
    n_layer = 12
    n_head = 12
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02
    gradient_checkpointing = False
    
class GPT2MediumConfig:
    vocab_size = len(tokenizer.added_tokens_encoder)
    n_special = len(tokenizer.added_tokens_encoder)
    n_positions = 1024
    n_ctx = 1024
    n_embd = 1024
    n_layer = 24
    n_head = 16
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02
    gradient_checkpointing = True

In [6]:
model_A = GPT2SimpleLM(GPT2SmallConfig)
model_B = GPT2SimpleLM(GPT2SmallConfig)
model_A_states, model_B_states = torch.load("../Checkpoint (copy)/model_state_epoch_5.th")
model_A.load_state_dict(model_A_states)
model_B.load_state_dict(model_B_states)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [7]:
def align_keep_indices(batch_keep_indices):
    prev = batch_keep_indices[1]
    new_batch_keep_indices = [prev]

    for i in range(1, len(batch_keep_indices)):
        curr = batch_keep_indices[i]
        new = []

        for idx in curr:
            new.append(prev.index(idx))

        new_batch_keep_indices.append(new)
        prev = curr
        
    return new_batch_keep_indices


class MultiWOZDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.file_list = list(self.data.keys())
        self.tokenizer = tokenizer
        self.bos = tokenizer.encode("<s>")
        self.user_bos = tokenizer.encode("A:")
        self.system_bos = tokenizer.encode("B:")
        
        self.eos = [628, 198]
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        file_name = self.file_list[index]
        full_dialog = self.data[file_name]['log']
        
        full_dialog_tokens = []
        cur_pos = 0
        
        for turn_dialog in full_dialog:
            # cur_pos = 0
            
            # user
            user = recoverText(turn_dialog['user_delex'])
            # user = recoverText(turn_dialog['user_delex'])
            user_tokens = self.user_bos + tokenizer.encode(user) + self.eos

            user_pos = torch.arange(cur_pos, cur_pos + len(user_tokens))
            cur_pos = user_pos[-1] + 1
            
            # belief span
#             belief_tokens = self.bos + \
#                             tokenizer.encode(";".join(turn_dialog['bspan_inform'][1:])) + \
#                             self.eos
#             belief_pos = torch.arange(cur_pos, cur_pos + len(belief_tokens))
#             cur_pos = belief_pos[-1]


            # Database
            if eval(turn_dialog['pointer'])[-2:] == (1, 0):
                booked = "book"
            elif eval(turn_dialog['pointer'])[-2:] == (0, 1):
                booked = "fail"
            else:
                booked = "none"
            
            if len(turn_dialog['match']) > 0:
                num_match = int(turn_dialog['match']) if int(turn_dialog['match']) < 4 else 4
            else:
                num_match = 0
                
            database = str(num_match) + ";" + booked + ";" + turn_dialog['turn_domain'].strip("[]") + ";"
            database_tokens = tokenizer.encode(database)
            database_pos = torch.arange(cur_pos, cur_pos + len(database_tokens))
            cur_pos = database_pos[-1] + 1
            
            # System
            system = recoverText(process_text(turn_dialog['resp'], turn_dialog['turn_domain'].strip("[]")))
            system_tokens = self.system_bos + tokenizer.encode(system) + self.eos
            system_pos = torch.arange(cur_pos, cur_pos + len(system_tokens))
            cur_pos = system_pos[-1] + 1
            
            user_tokens = torch.LongTensor(user_tokens)
            system_tokens = torch.LongTensor(system_tokens)
            database_tokens = torch.LongTensor(database_tokens)
            
            full_dialog_tokens.append((user_tokens, 
                                       user_pos,
                                       system_tokens, 
                                       system_pos,
                                       database_tokens,
                                       database_pos))
#             if system_pos[-1] > 1:
#                 break

        return full_dialog_tokens, file_name


def calculate_length(dialogs):
    total_sum = 0
    for turn_num in range(len(dialogs)):
        total_sum += len(dialogs[turn_num][1]) + \
                        len(dialogs[turn_num][3]) + \
                        len(dialogs[turn_num][5])        
    return total_sum


class Collate_Function:
    """This function handles batch collate.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad = self.tokenizer.encode("<pad>")[0]

    def __call__(self, unpacked_data):
        unpacked_data, file_names = zip(*unpacked_data)
        
        keep_indices = []
        for i, dialog in enumerate(unpacked_data):
            length = calculate_length(dialog)
            if length < 900:
                keep_indices.append(i)
        
        unpacked_data = [unpacked_data[idx] for idx in keep_indices]
        file_names = [file_names[idx] for idx in keep_indices]
    
        if len(unpacked_data) == 0:
            return None, None, None
        
        max_turn_len = max([len(item) for item in unpacked_data])
    
        batch_dialogs = []
        batch_keep_indices = []

        for turn_num in range(max_turn_len):

            keep_indices = []

            for batch_idx in range(len(unpacked_data)):
                if turn_num < len(unpacked_data[batch_idx]):
                    keep_indices.append(batch_idx)

            user_tokens = pad_sequence([unpacked_data[idx][turn_num][0] for idx in keep_indices], 
                                        batch_first=True,
                                        padding_value=self.pad)
            user_pos = pad_sequence([unpacked_data[idx][turn_num][1] for idx in keep_indices], 
                                        batch_first=True,
                                        padding_value=0)
            system_tokens = pad_sequence([unpacked_data[idx][turn_num][2] for idx in keep_indices], 
                                        batch_first=True,
                                        padding_value=self.pad)
            system_pos = pad_sequence([unpacked_data[idx][turn_num][3] for idx in keep_indices], 
                                        batch_first=True,
                                        padding_value=0)
            database_tokens = pad_sequence([unpacked_data[idx][turn_num][4] for idx in keep_indices], 
                                        batch_first=True,
                                        padding_value=self.pad)
            database_pos = pad_sequence([unpacked_data[idx][turn_num][5] for idx in keep_indices], 
                                        batch_first=True,
                                        padding_value=0)  

            user_mask = (user_tokens != self.pad).byte()
            system_mask = (system_tokens != self.pad).byte()
            database_mask = (database_tokens != self.pad).byte()


            batch_dialogs.append((user_tokens, user_pos, user_mask, 
                                  system_tokens, system_pos, system_mask, 
                                  database_tokens, database_pos, database_mask))
            batch_keep_indices.append(keep_indices)
            
        return batch_dialogs, batch_keep_indices, file_names

In [8]:
def filter_past(past, keep_indices):
    past = [item[:, keep_indices] for item in past]
    return past

def replace_punc(x):
    x = x.replace("<", "").replace(">", "")
    return x.replace(".", " .").replace(",", " .").replace("?", " ?").replace("?", " ?")

In [9]:
with open("../../yichi_data/val_data_dict.json") as f:
    val_data = json.load(f)
    
with open("../../yichi_data/test_data_dict.json") as f:
    test_data = json.load(f)

collate_func = Collate_Function(tokenizer)

val_dataset = MultiWOZDataset(val_data, tokenizer)
test_dataset = MultiWOZDataset(test_data, tokenizer)

eval_batch_size = 16

val_dataloader = DataLoader(dataset=val_dataset, 
                          shuffle=False,
                          batch_size=eval_batch_size, 
                          collate_fn=collate_func)

test_dataloader = DataLoader(dataset=test_dataset, 
                          shuffle=False,
                          batch_size=eval_batch_size, 
                          collate_fn=collate_func)

In [10]:
device = torch.device("cuda")
model_A = model_A.to(device)
model_B = model_B.to(device)

In [11]:
def process_text(text, domain):
    text = text.replace("[value_choice]", "[value_count]")
    text = text.replace("[value_people]", "[value_count]")
    text = text.replace("[value_starts]", "[value_count]")
    
    text = text.replace("[value_car]", '[taxi_type]')
    text = text.replace("[value_leave]", "[value_time]")
    text = text.replace("[value_arrive]", "[value_time]")
    text = text.replace("[value_price]", "[value_pricerange]")

    text = text.replace('[value_postcode]', f'[{domain}_postcode]')
    text = text.replace('[value_reference]', f'[{domain}_reference]')
    text = text.replace('[value_address]', f'[{domain}_address]')
    text = text.replace('[value_phone]', f'[{domain}_phone]')
    text = text.replace('[value_name]', f'[{domain}_name]')
    text = text.replace('[value_id]', f'[{domain}_id]')
    
    return text

## Generation

In [12]:
temperature = 0.7

In [13]:
def generate(batch_dialogs, batch_keep_indices):
    
    aligned_batch_keep_indices = align_keep_indices(batch_keep_indices)
    past = None
    generated_responses = [[] for i in range(batch_dialogs[0][0].shape[0])]

    mask = torch.ByteTensor([]).to(device)
    prev_batch_size = batch_dialogs[0][0].shape[0]

    with torch.no_grad():
        for turn_num in range(len(batch_keep_indices)):
            # data send to gpu
            dialogs = batch_dialogs[turn_num]
            dialogs = [item.to(device) for item in dialogs]

            user_tokens, user_pos, user_mask, \
                system_tokens, system_pos, system_mask, \
                database_tokens, database_pos, database_mask = dialogs

            # batch filtering algorithm
            keep_indices = aligned_batch_keep_indices[turn_num]

            if len(keep_indices) != prev_batch_size:
                past = filter_past(past, keep_indices)
                mask = mask[keep_indices, :]

            # define some initials
            cur_batch_size = user_tokens.shape[0]
            flags = np.ones(cur_batch_size)
            generated_tokens = [[] for i in range(cur_batch_size)]

            # feed in user
            mask = torch.cat([mask, user_mask], dim=-1)
            _, past = model_A(user_tokens, position_ids=user_pos, mask=mask, past=past)

            # response generation
            response = []

            # database tokens
            mask = torch.cat([mask, database_mask], dim=-1)
            _, past = model_B(database_tokens, position_ids=database_pos, mask=mask, past=past)

            # response generation
            prev_input = system_tokens[:, :2]
            cur_pos = system_pos[:, :2]
            temp_past = past
            temp_mask = F.pad(mask, pad=(0,2), value=1)

            # feed into B
            logits, temp_past = model_B(prev_input, position_ids=cur_pos, mask=temp_mask, past=temp_past)
            # set current position
            cur_pos = cur_pos[:, -1].unsqueeze(1) + 1


            for i in range(60):
                logits = logits[:, -1, :] / temperature
                # logits = top_filtering(logits, top_p=0.2)
                # probs = F.softmax(logits, -1)
                # prev_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                prev_tokens = torch.argmax(logits, dim=-1)
                np_prev_tokens = prev_tokens.cpu().numpy()
                # nucleus sampling


                # add to generated tokens list
                count = 0
                for idx, value in enumerate(flags):
                    if value != 0:
                        generated_tokens[idx].append(np_prev_tokens[count])
                        count += 1

                # filtering algorithm
                if np.any(np_prev_tokens == 628):
                    # set flags 0
                    count = 0
                    for idx, value in enumerate(flags):
                        if value == 1:
                            if np_prev_tokens[count] == 628:
                                flags[idx] = 0
                            count += 1
                    # compute which one to keep
                    keep_indices = np.argwhere(np_prev_tokens != 628).squeeze(1)
                    # filter
                    prev_tokens = prev_tokens[keep_indices.tolist()]
                    cur_pos = cur_pos[keep_indices.tolist(), :]
                    temp_mask = temp_mask[keep_indices.tolist(), :]
                    temp_past = [item[:, keep_indices.tolist()] for item in temp_past]
                    np_prev_tokens = np_prev_tokens[keep_indices.tolist()]

                if np.all(flags == 0):
                    break

                # prepare for the next token        
                temp_mask = F.pad(temp_mask, pad=(0, 1), value=1)
                logits, temp_past = model_B(prev_tokens.view(-1, 1), 
                                       position_ids=cur_pos, 
                                       mask=temp_mask, 
                                       past=temp_past)
                cur_pos = cur_pos + 1

            # real system_tokens feed in
            mask = torch.cat([mask, system_mask], dim=-1)
            _, past = model_B(system_tokens, position_ids=system_pos, mask=mask, past=past)

            # inject into generated_responses_list
            decoded_responses = [tokenizer.decode(item).replace("\n", "") for item in generated_tokens]
            count = 0
            for idx in batch_keep_indices[turn_num]:
                generated_responses[idx].append(decoded_responses[count])
                count += 1
                
    return generated_responses

In [14]:
all_test_pred = {}

for batch_dialogs, batch_keep_indices, file_names in tqdm.tqdm_notebook(test_dataloader):
    if batch_dialogs is None:
        continue
    
    generated_responses = generate(batch_dialogs, batch_keep_indices)
    for i, pred_dialog in enumerate(generated_responses):
        all_test_pred[file_names[i]] = pred_dialog

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))




## Evaluation 

In [15]:
from latent_dialog.evaluators import MultiWozEvaluator, BLEUScorer
from text_utils import recoverText, normalize
import pickle

In [16]:
# with open("all_test_pred.pkl", "wb") as f:
#     pickle.dump(all_test_pred, f)

In [17]:
# with open("all_test_pred.pkl", "rb") as f:
#     all_test_pred = pickle.load(f)

In [18]:
data_name = "test"
evaluator = MultiWozEvaluator(data_name)

evaluator.initialize()
eval_data = {}

for file_name in all_test_pred:
    eval_data[file_name] = [normalize(item) for item in all_test_pred[file_name]]
    
    for i, sentence in enumerate(eval_data[file_name]):
        evaluator.add_example(sentence, sentence)
        
report, successes, matches, failure_files = evaluator.evaluateModel(eval_data, mode='test')

In [19]:
print(report)

test Corpus Matches : 84.78%
test Corpus Success : 70.87%
Total number of dialogues: 999 


In [20]:
len(failure_files)

291

In [21]:
filename = 'PMUL3672.json'

for i in range(len(test_data[filename]['log'])):
    print("Usr:" + test_data[filename]['log'][i]['user'])
    print("Sys:" + test_data[filename]['log'][i]['resp'])

Usr:i am looking for information in cambridge
Sys:ok sure . what would you like to know ?
Usr:i have an upcoming conference in cambridge and need to figure out transportation . can you tell me about a train to take maybe ?
Sys:absolutely . where are you heading in from ? what day ?
Usr:i 'll be leaving london kings cross and heading to cambridge . i need to be there by 10:30 on tuesday . can you book this for 3 people ? reference please ?
Sys:i have a train arriving at [value_arrive] . would that do ?
Usr:yes . book for 3 people .
Sys:alright got you booked on train [value_id] , the total fee is [value_price] payable at the station . your reference number is : [value_reference] . anything else i can help with today ?
Usr:yes i am looking for someplace to go in the south for entertainment .
Sys:we have [value_choice] options , can i reccomend for you ?
Usr:which ever is nicer . i will need some info on it too .
Sys:i recommend [value_name] it is in [value_address] postcode [value_postco

In [22]:
filename = 'PMUL0079.json'

for i in range(len(test_data[filename]['log'])):
    print("Usr: " + test_data[filename]['log'][i]['user'])
    print("Sys: " + eval_data[file_name][i])

Usr: where is a 4 star hotel located in north cambridge ?
Sys: there are [value_count] attractions in cambridge . do you have a specific area in mind ?
Usr: sure , that could be nice
Sys: [attraction_name] is located in the [value_area] and is [value_pricerange] .
Usr: i actually do n't need reservations i just need the phone number , price range .
Sys: [attraction_name] is [value_pricerange] to enter .
Usr: okay . now could you help me find a restaurant in the expensive price range that is in the same area as the hotel ?
Sys: [train_id] leaves at [value_time] and arrives by [value_time] . would you like me to book it for you ?
Usr: before we do that , what is the name of the guest house ? and also , do they have free parking ?
Sys: your booking was successful . the reference number is [train_reference] .
Usr: could you recommend an expensive restaurant in the same area ?
Sys: you are welcome . have a great day !
Usr: yes , book me a table for 2 people at 12:15 on monday .


IndexError: list index out of range

In [None]:
print(len(test_data[filename]['log']))
len(eval_data[file_name])

In [None]:
filenames, _ = zip(*failure_files)

In [None]:
filenames

In [None]:
eval_data['PMUL2859.json']

In [None]:
with open("data/norm-multi-woz/test_dials.json") as f:
    gt_test_data = json.load(f)

In [None]:
[item['resp'] for item in test_data['PMUL2859.json']['log']]

In [None]:
[item['user'] for item in test_data['PMUL2859.json']['log']]

In [None]:
[item['turn_domain'] for item in test_data['PMUL2859.json']['log']]

In [None]:
test_data['PMUL2859.json']['sys']

In [None]:
evaluator.evaluateModel(eval_data, mode='rollout')

In [None]:
import json



In [None]:
rollout_test_pred = {}

for file_name in test_data:
    rollout_test_pred[file_name] = delex_data[file_name] 

In [None]:
generated_test_pred = {}

for file_name in test_data:
    generated_test_pred[file_name] = [item['text'].strip() 
                                      for i, item in enumerate(delex_data[file_name]['log']) 
                                      if i % 2 == 1]

In [None]:
data_name = "rollout"
evaluator = MultiWozEvaluator(data_name)

evaluator.initialize()
eval_data = {}

evaluator.evaluateModel(rollout_test_pred, mode='rollout')

In [None]:
data_name = "test"
evaluator = MultiWozEvaluator(data_name)

evaluator.initialize()
eval_data = {}

evaluator.evaluateModel(generated_test_pred, mode='test')