In [1]:
import json
import random
import sys
import zipfile
import copy
import re
import jieba
from collections import defaultdict
import os
import pickle as pkl

import numpy as np
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from pprint import pprint

In [2]:
from template_generate import TemplateNLG

In [3]:
archive = zipfile.ZipFile('data/test.json.zip', 'r')
test_data = json.load(archive.open('test.json'))

In [4]:
def split_delex_sentence(sen):
    res_sen = ''
    pattern = re.compile(r'(\[[^\[^\]]+\])')
    slots = pattern.findall(sen)
    for slot in slots:
        sen = sen.replace(slot, '[slot]')
    sen = sen.split('[slot]')
    for part in sen:
        part = ' '.join(jieba.lcut(part))
        res_sen += part
        if slots:
            res_sen += ' ' + slots.pop(0) + ' '
    return res_sen

In [5]:
def act2intent(dialog_act: list):
    cur_act = copy.deepcopy(dialog_act)
    # facility = None  # for 酒店设施
    if '酒店设施' in cur_act[2]:
        # facility = cur_act[2].split('-')[1]
        if cur_act[0] == 'Inform':
            cur_act[2] = cur_act[2].split('-')[0] + '+' + cur_act[3]
        elif cur_act[0] == 'Request':
            cur_act[2] = cur_act[2].split('-')[0]
    if cur_act[0] == 'Select':
        cur_act[2] = '源领域+' + cur_act[3]
    intent = '+'.join(cur_act[:-1])
    if '+'.join(cur_act) == 'Inform+景点+门票+免费':
        intent = '+'.join(cur_act)
    if cur_act[0] == 'Inform' and cur_act[-1] == '无':
#         intent = 'Inform+主体+属性+无'
        intent = '+'.join(cur_act)
    return intent

In [6]:
def value_replace(sentences, dialog_act):
    cur_sen = sentences
    dialog_act = copy.deepcopy(dialog_act)
    intent_frequency = defaultdict(int)
    for act in dialog_act:
        intent = act2intent(copy.deepcopy(act))
        intent_frequency[intent] += 1
        if intent_frequency[intent] > 1:  # if multiple same intents...
            intent += str(intent_frequency[intent])

        if '酒店设施' in intent:
            try:
                sentences = sentences.replace('[' + intent + ']', act[2].split('-')[1])
                sentences = sentences.replace('[' + intent + '1]', act[2].split('-')[1])
            except Exception as e:
                print('Act causing problem in replacement:')
                pprint(act)
                raise e
#         if act[0] == 'Inform' and act[3] == "无":
#             sentences = sentences.replace('[主体]', act[1])
#             sentences = sentences.replace('[属性]', act[2])
        else:
            sentences = sentences.replace('[' + intent + ']', act[3])
            sentences = sentences.replace('[' + intent + '1]', act[3])  # if multiple same intents and this is 1st

    if '[' in sentences and ']' in sentences:
        print('\n\nValue replacement not completed!!! Current sentence: %s' % sentences)
        print('current da:')
        print(dialog_act)
        print('original sen:')
        print(cur_sen)
        pattern = re.compile(r'(\[[^\[^\]]+\])')
        slots = pattern.findall(sentences)
        for slot in slots:
            sentences = sentences.replace(slot, ' ')
#         print('current da:')
#         print(dialog_act)
#         print('intent_frequency:')
#         print(intent_frequency)
#         raise Exception('\n\nValue replacement not completed!!! Current sentence: %s' % sentences)
    return sentences


In [7]:
def get_bleu4(dialog_acts, golden_utts, gen_utts, data_key):
    das2utts = {}
    for das, utt, gen in zip(dialog_acts, golden_utts, gen_utts):
        intent_frequency = defaultdict(int)
        for act in das:
            cur_act = copy.copy(act)

            # intent list
            facility = None  # for 酒店设施
            if '酒店设施' in cur_act[2]:
                facility = cur_act[2].split('-')[1]
                if cur_act[0] == 'Inform':
                    cur_act[2] = cur_act[2].split('-')[0] + '+' + cur_act[3]
                elif cur_act[0] == 'Request':
                    cur_act[2] = cur_act[2].split('-')[0]
            if cur_act[0] == 'Select':
                cur_act[2] = '源领域+' + cur_act[3]
            intent = '+'.join(cur_act[:-1])
            if '+'.join(cur_act) == 'Inform+景点+门票+免费':
                intent = '+'.join(cur_act)
            if cur_act[0] == 'Inform' and cur_act[-1] == '无':
#                 intent = 'Inform+主体+属性+无'
                intent = '+'.join(cur_act)

            intent_frequency[intent] += 1

            # utt content replacement
#             if act[0] == 'Inform' and '无' in intent:
#                 utt = utt.replace('+' + act[1] + '+', '+' + '&'.join(act[1]) + '+')
#                 utt = utt.replace('+' + act[2] + '+', '+' + '&'.join(act[2]) + '+')
#                 utt = utt.replace(act[1], '[主体]')
#                 utt = utt.replace(act[2], '[属性]')
#                 utt = utt.replace('&', '')
#                 gen = gen.replace('+' + act[1] + '+', '+' + '&'.join(act[1]) + '+')
#                 gen = gen.replace('+' + act[2] + '+', '+' + '&'.join(act[2]) + '+')
#                 gen = gen.replace(act[1], '[主体]')
#                 gen = gen.replace(act[2], '[属性]')
#                 gen = gen.replace('&', '')
            if (act[0] in ['Inform', 'Recommend'] or '酒店设施' in intent) and '无' not in intent:
                if act[3] in utt or (facility and facility in utt):
                    # value to be replaced
                    if '酒店设施' in intent:
                        value = facility
                    else:
                        value = act[3]

                    # placeholder
                    placeholder = '[' + intent + ']'
                    placeholder_one = '[' + intent + '1]'
                    placeholder_with_number = '[' + intent + str(intent_frequency[intent]) + ']'

                    if intent_frequency[intent] > 1:
                        utt = utt.replace(placeholder, placeholder_one)
                        utt = utt.replace(value, placeholder_with_number)
                        
                        gen = gen.replace(placeholder, placeholder_one)
                        gen = gen.replace(value, placeholder_with_number)
                    else:
                        utt = utt.replace(value, placeholder)
                        gen = gen.replace(value, placeholder)

        hash_key = ''
        for act in sorted(das):
            hash_key += act2intent(act)
        das2utts.setdefault(hash_key, {'refs': [], 'gens': []})
        das2utts[hash_key]['refs'].append(utt)
#         das2utts[hash_key]['gens'].append(gen)  # todo here template only generate one 'gen'
        das2utts[hash_key]['gens'].append({
            'das': das,
            'gen': gen
        })
#         print(utt, gen)
#     pprint(das2utts)
    refs, gens = [], []
    for das in das2utts.keys():
        # print('refs:')
        # print(len(das2utts[das]['refs']))
        # print('gens:')
        # print(len(das2utts[das]['gens']))
        assert len(das2utts[das]['refs']) == (len(das2utts[das]['gens']))
#         for gen in das2utts[das]['gens']:
#             # refs.append([s.split() for s in das2utts[das]['refs']])
#             refs.append([split_delex_sentence(s) for s in das2utts[das]['refs']])
#             gens.append(split_delex_sentence(gen))
        for gen_pair in das2utts[das]['gens']:
                lex_das = gen_pair['das']  # das w/ value
                gen = gen_pair['gen']
                lex_gen = value_replace(gen, lex_das)
#                 gens.append(' '.join(jieba.lcut(lex_gen)))
#                 refs.append([' '.join(jieba.lcut(value_replace(s, lex_das))) for s in das2utts[das]['refs']])
                gens.append([x for x in jieba.lcut(lex_gen) if x.strip()])
                refs.append([[x for x in jieba.lcut(value_replace(s, lex_das)) if x.strip()] for s in das2utts[das]['refs']])
                
    print('Start calculating bleu score...')
    from pprint import pprint
    print('refs:')
    pprint(refs[:10])
    print('gens:')
    pprint(gens[:10])
    with open(os.path.join('', 'generated_sens_%s.json' % data_key), 'w', encoding='utf-8') as f:
        json.dump({'refs': refs, 'gens': gens}, f, indent=4, sort_keys=True, ensure_ascii=False)
    print('generated_sens_%s.txt saved!' % data_key)
    bleu = corpus_bleu(refs, gens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=SmoothingFunction().method1)
    return bleu

In [8]:
for data_key in ['sys', 'usr']:
    print('data_key: ', data_key)
    
    dialog_acts = []
    golden_utts = []
    gen_utts = []
    gen_slots = []
    intent_list = []
    dialog_acts2genutts = defaultdict(list)

    sen_num = 0
    sess_num = 0

    if os.path.isfile('dialog_acts_%s.pkl' % data_key) and os.path.isfile('golden_utts_%s.pkl' % data_key) and os.path.isfile('gen_utts_%s.pkl' % data_key):
        with open('dialog_acts_%s.pkl' % data_key, 'rb') as fda:
            dialog_acts = pkl.load(fda)
        with open('golden_utts_%s.pkl' % data_key, 'rb') as fgold:
            golden_utts = pkl.load(fgold)
        with open('gen_utts_%s.pkl' % data_key, 'rb') as fgen:
            gen_utts = pkl.load(fgen)
        for no, sess in list(test_data.items()):
            sess_num += 1
#             print('[%d/%d]' % (sess_num, len(test_data)))
            for turn in sess['messages']:
                if turn['role'] == 'usr' and data_key == 'sys':
                    continue
                elif turn['role'] == 'sys' and data_key == 'usr':
                    continue
                sen_num += 1
        print('sen_num: ', sen_num)
    else:
        for no, sess in list(test_data.items()):
            sess_num += 1
            print('[%d/%d]' % (sess_num, len(test_data)))
            for turn in sess['messages']:
                if turn['role'] == 'usr' and data_key == 'sys':
                    continue
                elif turn['role'] == 'sys' and data_key == 'usr':
                    continue
                sen_num += 1

                template_nlg = TemplateNLG(is_user=data_key == 'usr', mode='auto_manual')

                dialog_acts.append(turn['dialog_act'])
                golden_utts.append(turn['content'])  # slots **values**
                # gen_utts.append(model.generate(turn['dialog_act']))

                gen_utts.append(template_nlg.generate(turn['dialog_act']))  # slots **values**
                
#                 intent_list = []
#                 for act in turn['dialog_act']:
#                     intent_list.append(act2intent(act))
#                 dialog_acts2genutts['*'.join(intent_list)].append(gen_utt)
#                 dialog_acts2goldutts['*'.join(intent_list)].append(turn['content'])

#             if sess_num % 10 == 0:
#                 with open(os.path.join('', 'generated_sens_%s.json' % data_key), 'w', encoding='utf-8') as f:
#                     json.dump(dialog_acts2genutts, f, indent=4, sort_keys=True, ensure_ascii=False)
#                 print('generated_sens_%s.txt saved at iteration #{}!'.format(data_key.format(sess_num)))
#                 with open(os.path.join('', 'golden_sens_%s.json' % data_key), 'w', encoding='utf-8') as f:
#                     json.dump(dialog_acts2genutts, f, indent=4, sort_keys=True, ensure_ascii=False)
#                 print('golden_sens_%s.txt saved at iteration #{}!'.format(data_key.format(sess_num)))
                
        with open('dialog_acts_%s.pkl' % data_key, 'wb') as fda:
            pkl.dump(dialog_acts, fda)
        with open('golden_utts_%s.pkl' % data_key, 'wb') as fgold:
            pkl.dump(golden_utts, fgold)
        with open('gen_utts_%s.pkl' % data_key, 'wb') as fgen:
            pkl.dump(gen_utts, fgen)

    bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts, data_key)
    print("Calculate bleu-4")
    print("BLEU-4: %.4f" % bleu4)
    print("BLEU-4: %.4f" % bleu4)
    print('Model on {} session {} sentences data_key={}'.format(len(test_data), sen_num, data_key))

Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/ht/8jl6l88n1_v610_t1jz69s440000gn/T/jieba.cache


data_key:  sys
sen_num:  4238


Loading model cost 0.696 seconds.
Prefix dict has been built succesfully.




Value replacement not completed!!! Current sentence: 电话是人均消费，人均消费是[Inform+餐馆+人均消费]。
current da:
[['Inform', '餐馆', '人均消费', '74元'], ['Inform', '餐馆', '电话', '人均消费']]
original sen:
电话是[Inform+餐馆+电话]，[Inform+餐馆+电话]是[Inform+餐馆+[Inform+餐馆+电话]]。


Value replacement not completed!!! Current sentence: 电话是010-59683511，人均消费为[Inform+餐馆+人均消费]。
current da:
[['Inform', '餐馆', '人均消费', '74元'], ['Inform', '餐馆', '电话', '人均消费']]
original sen:
电话是010-59683511，[Inform+餐馆+电话]为[Inform+餐馆+[Inform+餐馆+电话]]。
Start calculating bleu score...
refs:
[[['为',
   '您',
   '推荐',
   '鲜鱼',
   '口',
   '老字号',
   '美食街',
   '，',
   '人均',
   '消费',
   '75',
   '元',
   '，',
   '有',
   '您',
   '想',
   '吃',
   '的',
   '美食街',
   '哦',
   '。'],
  ['推荐',
   '您',
   '去',
   '鲜鱼',
   '口',
   '老字号',
   '美食街',
   '，',
   '但是',
   '它家',
   '的',
   '人均',
   '消费',
   '是',
   '75',
   '元',
   '。'],
  ['我',
   '推荐',
   '你',
   '去',
   '鲜鱼',
   '口',
   '老字号',
   '美食街',
   '尝尝',
   '，',
   '他家',
   '做',
   '的',
   '很',
   '好吃',
   '，',
   '不过',
   '