Skip to content

Commit

Permalink
Improve agenda policy (#52)
Browse files Browse the repository at this point in the history
* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda policy (#31); the ordering of NLG output could be made more detailed in TemplateNLG:sorted_dialog_act

* improve goal sampling strategy
  • Loading branch information
zqwerty committed Jul 15, 2020
1 parent dab6a68 commit c6372b1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 13 deletions.
90 changes: 77 additions & 13 deletions convlab2/task/multiwoz/goal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import random
from collections import Counter
from copy import deepcopy

from pprint import pprint
import numpy as np

from convlab2.util.multiwoz.dbquery import Database
from convlab2 import get_root_path

domains = {'attraction', 'hotel', 'restaurant', 'train', 'taxi', 'hospital', 'police'}
days = ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']
Expand Down Expand Up @@ -134,23 +135,28 @@ class GoalGenerator:
"""User goal generator."""

def __init__(self,
goal_model_path=os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
'data/multiwoz/goal/goal_model.pkl'),
goal_model_path=os.path.join(get_root_path(), 'data/multiwoz/goal/new_goal_model.pkl'),
corpus_path=None,
boldify=False):
boldify=False,
sample_info_from_trainset=True,
sample_reqt_from_trainset=False):
"""
Args:
goal_model_path: path to a goal model
corpus_path: path to a dialog corpus to build a goal model
corpus_path: path to a dialog corpus to build a goal model
boldify: highlight some information in the goal message
sample_info_from_trainset: if True, sample info slots combination from train set, else sample each slot independently
sample_reqt_from_trainset: if True, sample reqt slots combination from train set, else sample each slot independently
"""
self.goal_model_path = goal_model_path
self.corpus_path = corpus_path
self.db = Database()
self.boldify = do_boldify if boldify else null_boldify
self.sample_info_from_trainset = sample_info_from_trainset
self.sample_reqt_from_trainset = sample_reqt_from_trainset
self.train_database = self.db.query('train',[])
if os.path.exists(self.goal_model_path):
self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist = pickle.load(
self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist, self.slots_num_dist, self.slots_combination_dist = pickle.load(
open(self.goal_model_path, 'rb'))
print('Loading goal model is done')
else:
Expand All @@ -165,6 +171,12 @@ def __init__(self,
del self.ind_slot_dist['hospital']['reqt']['address']
del self.ind_slot_value_dist['hospital']['reqt']['address']

# print(self.slots_combination_dist['police'])
# print(self.slots_combination_dist['hospital'])
# pprint(self.ind_slot_dist)
# pprint(self.slots_num_dist)
# pprint(self.slots_combination_dist)

def _build_goal_model(self):
dialogs = json.load(open(self.corpus_path))

Expand Down Expand Up @@ -193,12 +205,24 @@ def _get_dialog_domains(dialog):
ind_slot_value_cnt = dict([(domain, {}) for domain in domains])
domain_cnt = Counter()
book_cnt = Counter()
self.slots_combination_dist = {domain: {} for domain in domains}
self.slots_num_dist = {domain: {} for domain in domains}

for d in dialogs:
for domain in domains:
if dialogs[d]['goal'][domain] != {}:
domain_cnt[domain] += 1
if 'info' in dialogs[d]['goal'][domain]:
if 'info' not in self.slots_combination_dist[domain]:
self.slots_combination_dist[domain]['info'] = {}
self.slots_num_dist[domain]['info'] = {}

slots = sorted(list(dialogs[d]['goal'][domain]['info'].keys()))
self.slots_combination_dist[domain]['info'].setdefault(tuple(slots), 0)
self.slots_combination_dist[domain]['info'][tuple(slots)] += 1
self.slots_num_dist[domain]['info'].setdefault(len(slots), 0)
self.slots_num_dist[domain]['info'][len(slots)] += 1

for slot in dialogs[d]['goal'][domain]['info']:
if 'invalid' in slot:
continue
Expand All @@ -210,6 +234,20 @@ def _get_dialog_domains(dialog):
continue
ind_slot_value_cnt[domain]['info'][slot][dialogs[d]['goal'][domain]['info'][slot]] += 1
if 'reqt' in dialogs[d]['goal'][domain]:
if 'reqt' not in self.slots_combination_dist[domain]:
self.slots_combination_dist[domain]['reqt'] = {}
self.slots_num_dist[domain]['reqt'] = {}
slots = sorted(dialogs[d]['goal'][domain]['reqt'])
if domain in ['police', 'hospital'] and 'postcode' in slots:
slots.remove('postcode')
else:
assert len(slots) > 0, print(sorted(dialogs[d]['goal'][domain]['reqt']),[slots])
if len(slots) > 0:
self.slots_combination_dist[domain]['reqt'].setdefault(tuple(slots), 0)
self.slots_combination_dist[domain]['reqt'][tuple(slots)] += 1
self.slots_num_dist[domain]['reqt'].setdefault(len(slots), 0)
self.slots_num_dist[domain]['reqt'][len(slots)] += 1

for slot in dialogs[d]['goal'][domain]['reqt']:
if 'reqt' not in ind_slot_value_cnt[domain]:
ind_slot_value_cnt[domain]['reqt'] = Counter()
Expand All @@ -227,6 +265,10 @@ def _get_dialog_domains(dialog):
continue
ind_slot_value_cnt[domain]['book'][slot][dialogs[d]['goal'][domain]['book'][slot]] += 1

# pprint(self.slots_num_dist)
# pprint(self.slots_combination_dist)
# for domain in domains:
# print(domain, len(self.slots_combination_dist[domain]['info']))
self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt)
self.ind_slot_dist = dict([(domain, {}) for domain in domains])
self.book_dist = {}
Expand Down Expand Up @@ -265,7 +307,8 @@ def _get_dialog_domains(dialog):
val] / slot_total
self.book_dist[domain] = book_cnt[domain] / len(dialogs)

pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist),
pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist,
self.slots_num_dist, self.slots_combination_dist),
open(self.goal_model_path, 'wb'))

def _get_domain_goal(self, domain):
Expand All @@ -279,9 +322,15 @@ def _get_domain_goal(self, domain):
domain_goal = {'info': {}}
# inform
if 'info' in cnt_slot:
for slot in cnt_slot['info']:
if random.random() < cnt_slot['info'][slot] + pro_correction['info']:
if self.sample_info_from_trainset:
slots = random.choices(list(self.slots_combination_dist[domain]['info'].keys()),
list(self.slots_combination_dist[domain]['info'].values()))[0]
for slot in slots:
domain_goal['info'][slot] = nomial_sample(cnt_slot_value['info'][slot])
else:
for slot in cnt_slot['info']:
if random.random() < cnt_slot['info'][slot] + pro_correction['info']:
domain_goal['info'][slot] = nomial_sample(cnt_slot_value['info'][slot])

if domain in ['hotel', 'restaurant', 'attraction'] and 'name' in domain_goal['info'] and len(
domain_goal['info']) > 1:
Expand Down Expand Up @@ -330,9 +379,21 @@ def _get_domain_goal(self, domain):
continue
# request
if 'reqt' in cnt_slot:
reqt = [slot for slot in cnt_slot['reqt']
if random.random() < cnt_slot['reqt'][slot] + pro_correction['reqt'] and slot not in
domain_goal['info']]
if self.sample_reqt_from_trainset:
not_in_info_slots = {}
for slots in self.slots_combination_dist[domain]['reqt']:
for slot in slots:
if slot in domain_goal['info']:
break
else:
not_in_info_slots[slots] = self.slots_combination_dist[domain]['reqt'][slots]
pprint(not_in_info_slots)
reqt = list(random.choices(list(not_in_info_slots.keys()),
list(not_in_info_slots.values()))[0])
else:
reqt = [slot for slot in cnt_slot['reqt']
if random.random() < cnt_slot['reqt'][slot] + pro_correction['reqt'] and slot not in
domain_goal['info']]
if len(reqt) > 0:
domain_goal['reqt'] = reqt

Expand Down Expand Up @@ -670,3 +731,6 @@ def get_same_people_domain(user_goal, domain, slot):

return message, message_by_domain

if __name__ == '__main__':
    # Ad-hoc demo: construct a generator that samples request-slot
    # combinations from the train set, then pretty-print the result of
    # one get_user_goal() call.
    corpus_path = os.path.join(get_root_path(), 'data/multiwoz/train.json')
    goal_generator = GoalGenerator(
        corpus_path=corpus_path,
        sample_reqt_from_trainset=True,
    )
    pprint(goal_generator.get_user_goal())
Binary file added data/multiwoz/goal/new_goal_model.pkl
Binary file not shown.

0 comments on commit c6372b1

Please sign in to comment.