In [1]:
import json
import numpy as np
import torch
import re
from nltk.tokenize import word_tokenize

In [2]:
def clean_replace(s, r, t, forward=True, backward=False):
    """Replace every case-insensitive, word-bounded occurrence of ``r`` in ``s`` with ``t``.

    Args:
        s: Text to rewrite.
        r: Literal substring to look for (matched case-insensitively).
        t: Replacement text.
        forward: If True, a match is extended to the right over any trailing
            letters/digits (``"cheap"`` also consumes ``"cheapest"``). If False,
            a match immediately followed by a letter/digit is skipped.
        backward: If True, a match is extended to the left over any leading
            letters/digits. If False, a match is accepted only at the start of
            the string or right after a space.

    Returns:
        The rewritten string. ``s`` is returned unchanged when ``r`` is empty.
    """
    if not r:
        # An empty pattern matches at every position and would never terminate.
        return s

    def is_word_char(c):
        return c.isalpha() or c.isdigit()

    # Search ``s`` itself (escaped literal, case-insensitive) so match offsets
    # always index ``s``; offsets found in ``s.lower()`` can drift for characters
    # whose lowercase form has a different length.
    pattern = re.compile(re.escape(r), re.IGNORECASE)
    pos = 0
    while True:
        match = pattern.search(s, pos)
        if match is None:
            return s
        start, end = match.span()

        if backward:
            while start > 0 and is_word_char(s[start - 1]):
                start -= 1
        elif start > 0 and s[start - 1] != ' ':
            # Not at a word start: skip this hit but keep scanning for later ones.
            pos = match.start() + 1
            continue

        if forward:
            while end < len(s) and is_word_char(s[end]):
                end += 1
        elif end < len(s) and is_word_char(s[end]):
            # Match runs into a longer word: skip it, keep scanning.
            pos = match.start() + 1
            continue

        s = s[:start] + t + s[end:]
        # Resume right after the inserted text, so a replacement shorter than the
        # match does not skip following text, and the replacement itself is never
        # rescanned.
        pos = start + len(t)

In [3]:
def value_key_map(db_data):
    """Build a reverse lookup from DB value to the slot name it is stored under.

    Only requestable slots are indexed. When the same value appears under more
    than one slot (or in more than one entry), the last occurrence in
    ``db_data`` wins.
    """
    requestable = ('address', 'name', 'phone', 'postcode', 'food', 'area', 'pricerange')
    return {
        value: slot
        for entry in db_data
        for slot, value in entry.items()
        if slot in requestable
    }

def db_search(db, constraints):
    """Return the DB entries that contain every constraint token.

    Matching is a case-insensitive substring test against all of an entry's
    values joined with spaces. A token may therefore match part of a word or
    any field.

    Args:
        db: Iterable of dict entries.
        constraints: Iterable of constraint strings (tokens).

    Returns:
        List of matching entries in DB order. With no constraints, every entry matches.
    """
    # Lower-case once (and materialise, so a one-shot iterable works for every entry).
    needles = [c.lower() for c in constraints]
    matches = []
    for entry in db:
        # str(): values come from parsed JSON and may be numbers/lists; joining
        # them directly would raise TypeError.
        haystack = ' '.join(str(value) for value in entry.values()).lower()
        if all(needle in haystack for needle in needles):
            matches.append(entry)
    return matches

def replace_entity(response, vk_map, constraint):
    """Delexicalise a system response by replacing DB values with ``<slot_SLOT>`` tokens.

    Postcode- and phone-shaped strings are replaced by regex first. Then each
    value in ``vk_map`` (value -> slot name) found in the response is replaced
    with ``clean_replace``.

    A value is skipped when:
      - it does not occur in the response,
      - its first occurrence does not start the response or follow a space, or
      - it appears as a substring of the space-joined ``constraint`` tokens.

    Args:
        response: System utterance text.
        vk_map: Mapping from DB value to slot name (e.g. from ``value_key_map``).
        constraint: List of constraint tokens for the current turn.

    Returns:
        The delexicalised response string.
    """
    # Raw strings: '\d' / '\w' in a plain literal are invalid escape sequences
    # (DeprecationWarning; SyntaxWarning since Python 3.12). Patterns are unchanged.
    response = re.sub(r'[cC][., ]*[bB][., ]*\d[., ]*\d[., ]*\w[., ]*\w', '<postcode_SLOT>', response)
    response = re.sub(r'\d{5}\s?\d{6}', '<phone_SLOT>', response)
    constraint_str = ' '.join(constraint)
    # Longest values first, so a value containing a shorter one is replaced
    # before the shorter one can match inside it.
    for v, k in sorted(vk_map.items(), key=lambda x: -len(x[0])):
        start_idx = response.lower().find(v.lower())
        if start_idx == -1 \
                or (start_idx != 0 and response[start_idx - 1] != ' ') \
                or (v in constraint_str):
            continue
        # Names and addresses must match exactly up to a word boundary; other
        # values may extend forward over trailing letters/digits.
        extend_forward = k not in ('name', 'address')
        response = clean_replace(response, v, '<' + k + '_SLOT>', forward=extend_forward, backward=False)
    return response

In [5]:
# Location of the CamRest676 dialogue corpus and its database.
DATA_DIR = "../CamRest676"

# Read dialogues. Explicit UTF-8 so decoding does not depend on the platform locale.
with open(f"{DATA_DIR}/CamRest676.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Read database.
with open(f"{DATA_DIR}/CamRestDB.json", "r", encoding="utf-8") as f:
    db_data = json.load(f)

In [6]:
# Reverse lookup (DB value -> slot name), consumed by replace_entity for delexicalisation.
vk_map = value_key_map(db_data=db_data)

In [13]:
MAX_DEGREE = 6  # DB-match counts above this are capped at this value


def _turn_record(dial_id, turn):
    """Convert one raw turn into the flat record stored in ``all_data``.

    Values of 'inform' acts (except 'dontcare'/'none') are tokenised into the
    constraint list. Every other act contributes its tokenised slot value to the
    request list. Each list gets a leading '[inform]' / '[request]' marker.
    ``degree`` is the number of DB entries matching the bare constraint tokens,
    capped at MAX_DEGREE.
    """
    turn_num = turn['turn']
    inform_tokens, request_tokens = [], []
    for act in turn['usr']['slu']:
        value = act['slots'][0][1]
        if act['act'] != 'inform':
            request_tokens.extend(word_tokenize(value))
        elif value not in ('dontcare', 'none'):
            inform_tokens.extend(word_tokenize(value))

    # The DB is queried with the tokens only, before the span marker is prepended.
    degree = min(len(db_search(db_data, inform_tokens)), MAX_DEGREE)

    inform_span = ['[inform]'] + inform_tokens
    request_span = ['[request]'] + request_tokens

    user_text = turn['usr']['transcript']
    system_text = turn['sys']['sent']
    return {
        'dial_id': dial_id,
        'turn_num': turn_num,
        'user': user_text,
        'real_response': system_text,
        'replaced_response': replace_entity(system_text, vk_map, inform_span),
        'degree': degree,
        'bspan_inform': inform_span,
        'bspan_request': request_span,
    }


# One list of turn records per dialogue, indexed by the dialogue's position in raw_data.
all_data = [
    [_turn_record(dial_id, turn) for turn in dial['dial']]
    for dial_id, dial in enumerate(raw_data)
]

In [17]:
# Positional train / validation / test split over dialogue indices.
N_TRAIN = 408    # first N_TRAIN dialogues -> train
N_VAL = 136      # next N_VAL dialogues -> validation; all remaining -> test
SHUFFLE = False  # True gives a seeded random split instead of the fixed positional one
SEED = 42

indices = np.arange(len(all_data))
if SHUFFLE:
    np.random.default_rng(SEED).shuffle(indices)

# Fail loudly rather than silently producing short/empty splits on a smaller corpus.
assert len(indices) > N_TRAIN + N_VAL, \
    f"expected more than {N_TRAIN + N_VAL} dialogues, got {len(indices)}"

# These hold index arrays here; the next cell replaces them with the dialogues themselves.
train_data = indices[:N_TRAIN]
val_data = indices[N_TRAIN:N_TRAIN + N_VAL]
test_data = indices[N_TRAIN + N_VAL:]

In [18]:
# Turn each index array into the list of dialogues it selects.
# NOTE(review): this rebinds train_data/val_data/test_data from index arrays to
# dialogue lists, so running this cell twice fails (list indices must be ints);
# re-run the split cell before re-running this one.
train_data, val_data, test_data = [
    [all_data[i] for i in split_indices]
    for split_indices in (train_data, val_data, test_data)
]

In [19]:
# Persist each split (torch.save serialises the Python lists of dicts with pickle).
for out_path, split in (
    ("train_data.pkl", train_data),
    ("val_data.pkl", val_data),
    ("test_data.pkl", test_data),
):
    torch.save(split, out_path)