Skip to content

Commit

Permalink
modify build message function for goal generation
Browse files Browse the repository at this point in the history
  • Loading branch information
zqwerty committed Jun 23, 2020
1 parent 6e23852 commit bdc9dba
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion convlab2/human_eval/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __init__(self, opt, agent,
except Exception as e:
print(e)
num_goal_trials += 1
self.goal_message = goal_generator.build_message(self.goal)
self.goal_message, _ = goal_generator.build_message(self.goal)
self.goal_text = '<ul>'
for m in self.goal_message:
self.goal_text += '<li>' + m + '</li>'
Expand Down
45 changes: 45 additions & 0 deletions convlab2/task/multiwoz/generate_goals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
generate user goal for collecting new multiwoz data
"""

from convlab2.task.multiwoz.goal_generator import GoalGenerator
import random
import numpy as np
import json
import datetime
from pprint import pprint


def generate(total_num=1000, seed=42, output_file='goal.json'):
random.seed(seed)
np.random.seed(seed)
goal_generator = GoalGenerator()
goals = []
avg_domains = []
while len(goals) < total_num:
goal = goal_generator.get_user_goal()
# pprint(goal)
if 'police' in goal['domain_ordering']:
no_police = list(goal['domain_ordering'])
no_police.remove('police')
goal['domain_ordering'] = tuple(no_police)
del goal['police']
try:
message = goal_generator.build_message(goal)[1]
except:
continue
# print(message)
avg_domains.append(len(goal['domain_ordering']))
goals.append({
"goals": [],
"ori_goals": goal,
"description": message,
"timestamp": str(datetime.datetime.now()),
"ID": len(goals)
})
print('avg domains:', np.mean(avg_domains)) # avg domains: 1.827
json.dump(goals, open(output_file, 'w'), indent=4)


if __name__ == '__main__':
generate(output_file='goal20200623.json')
8 changes: 7 additions & 1 deletion convlab2/task/multiwoz/goal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ def _adjust_info(self, domain, info):

def build_message(self, user_goal, boldify=null_boldify):
message = []
message_by_domain = []
mess_ptr4domain = 0
state = deepcopy(user_goal)

for dom in user_goal['domain_ordering']:
Expand Down Expand Up @@ -641,11 +643,15 @@ def get_same_people_domain(user_goal, domain, slot):
message.append(templates[dom]['fail_book ' + adjusted_slot].format(
self.boldify(user_goal[dom]['book'][adjusted_slot])))

dm = message[mess_ptr4domain:]
mess_ptr4domain = len(message)
message_by_domain.append(' '.join(dm))

if boldify == do_boldify:
for i, m in enumerate(message):
message[i] = message[i].replace('wifi', "<b>wifi</b>")
message[i] = message[i].replace('internet', "<b>internet</b>")
message[i] = message[i].replace('parking', "<b>parking</b>")

return message
return message, message_by_domain

0 comments on commit bdc9dba

Please sign in to comment.