Skip to content

Commit

Permalink
modify default parameters, change nlu nlg reading for win system (rb …
Browse files Browse the repository at this point in the history
…--> r)
  • Loading branch information
intersun committed Sep 4, 2018
1 parent 62c8cb4 commit 15ea1f7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
11 changes: 7 additions & 4 deletions system/src/deep_dialog/nlg/nlg.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,13 @@ def translate_diaact(self, dia_act):


def load_nlg_model(self, model_path):
""" load the trained NLG model """

model_params = pickle.load(open(model_path, 'rb'))

""" load the trained NLG model """
import sys
if 'win' in sys.platform:
model_params = pickle.load(open(model_path, 'r'))
else:
model_params = pickle.load(open(model_path, 'rb'))

hidden_size = model_params['model']['Wd'].shape[0]
output_size = model_params['model']['Wd'].shape[1]

Expand Down
8 changes: 6 additions & 2 deletions system/src/deep_dialog/nlu/nlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ def generate_dia_act(self, annot):

def load_nlu_model(self, model_path):
""" load the trained NLU model """

model_params = pickle.load(open(model_path, 'rb'))
import sys
if 'win' in sys.platform:
model_params = pickle.load(open(model_path, 'r'))
else:
model_params = pickle.load(open(model_path, 'rb'))


hidden_size = model_params['model']['Wd'].shape[0]
output_size = model_params['model']['Wd'].shape[1]
Expand Down
14 changes: 7 additions & 7 deletions system/src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def load_actions(sys_req_slots, sys_inf_slots):
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument('--dict_path', dest='dict_path', type=str, default='./deep_dialog/data_restaurant/dict_restaurant.p', help='path to the .json dictionary file')
parser.add_argument('--kb_path', dest='kb_path', type=str, default='./deep_dialog/data_restaurant/restaurant_kb.1k.p', help='path to the movie kb .json file')
parser.add_argument('--dict_path', dest='dict_path', type=str, default='./deep_dialog/data_restaurant/slot_dict.v2.p', help='path to the .json dictionary file')
parser.add_argument('--kb_path', dest='kb_path', type=str, default='./deep_dialog/data_restaurant/restaurant.kb.1k.v1.p', help='path to the movie kb .json file')
parser.add_argument('--act_set', dest='act_set', type=str, default='./deep_dialog/data_restaurant/dia_acts.txt', help='path to dia act set; none for loading from labeled file')
parser.add_argument('--slot_set', dest='slot_set', type=str, default='./deep_dialog/data_restaurant/slot_set.txt', help='path to slot set; none for loading from labeled file')
parser.add_argument('--goal_file_path', dest='goal_file_path', type=str, default='./deep_dialog/data_restaurant/user_goals_first_turn_template.part.restaurant.v2.p', help='a list of user goals')
parser.add_argument('--diaact_nl_pairs', dest='diaact_nl_pairs', type=str, default='./deep_dialog/data_restaurant/dia_act_nl_pairs.v6.json', help='path to the pre-defined dia_act&NL pairs')
parser.add_argument('--slot_set', dest='slot_set', type=str, default='./deep_dialog/data_restaurant/restaurant_slots.txt', help='path to slot set; none for loading from labeled file')
parser.add_argument('--goal_file_path', dest='goal_file_path', type=str, default='./deep_dialog/data_restaurant/user_goals_first.v1.p', help='a list of user goals')
parser.add_argument('--diaact_nl_pairs', dest='diaact_nl_pairs', type=str, default='./deep_dialog/data_restaurant/sim_dia_act_nl_pairs.v2.json', help='path to the pre-defined dia_act&NL pairs')

parser.add_argument('--max_turn', dest='max_turn', default=20, type=int, help='maximum length of each dialog (default=20, 0=no maximum length)')
parser.add_argument('--episodes', dest='episodes', default=1, type=int, help='Total number of episodes to run (default=1)')
Expand All @@ -84,8 +84,8 @@ def load_actions(sys_req_slots, sys_inf_slots):
parser.add_argument('--epsilon', dest='epsilon', type=float, default=0, help='Epsilon to determine stochasticity of epsilon-greedy agent policies')

# load NLG & NLU model
parser.add_argument('--nlg_model_path', dest='nlg_model_path', type=str, default='./deep_dialog/models/nlg/lstm_tanh_relu_[1468202263.38]_2_0.610.p', help='path to model file')
parser.add_argument('--nlu_model_path', dest='nlu_model_path', type=str, default='./deep_dialog/models/nlu/lstm_[1468447442.91]_39_80_0.921.p', help='path to the NLU model file')
parser.add_argument('--nlg_model_path', dest='nlg_model_path', type=str, default='./deep_dialog/models/nlg/restaurant/lstm_tanh_[1532068150.19]_98_99_294_0.983.p', help='path to model file')
parser.add_argument('--nlu_model_path', dest='nlu_model_path', type=str, default='./deep_dialog/models/nlu/restaurant/lstm_[1532107808.26]_68_74_20_0.997.p', help='path to the NLU model file')

parser.add_argument('--act_level', dest='act_level', type=int, default=0, help='0 for dia_act level; 1 for NL level')
parser.add_argument('--run_mode', dest='run_mode', type=int, default=0, help='run_mode: 0 for default NL; 1 for dia_act; 2 for both')
Expand Down

0 comments on commit 15ea1f7

Please sign in to comment.