In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import sys
# FOLDER_PATH = '/home/taindp/Jupyter/custom_dqn'
FOLDER_PATH = '/content/drive/MyDrive/Colab Notebooks/go-orient'
# FOLDER_PATH  = '/home/taindp/Jupyter/reference/dqn_ref'
sys.path.append(FOLDER_PATH)

In [3]:
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
print(tf.compat.v1.get_default_graph())

<tensorflow.python.framework.ops.Graph object at 0x7f664a43a470>


In [None]:
# sys.path

In [4]:
%load_ext autoreload
%autoreload 2
from user_simulator import UserSimulator
from error_model_controller import ErrorModelController
from dqn_agent import DQNAgent
# from dqn_agent import get_action
from state_tracker import StateTracker
import pickle, argparse, json, math
from utils import remove_empty_slots
from user import User
import time
import json
from tqdm import tqdm
%cd $FOLDER_PATH
# from tqdm import tqdm

/content/drive/MyDrive/Colab Notebooks/go-orient


In [5]:
def run_round(state, warmup=False):
#     print('State of run_round',state)
#     state : array 12*14
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    agent_action_index, agent_action = dqn_agent.get_action(state, use_rule=warmup)
    # print('agent_action_index',agent_action_index)
    # print('agent_action',agent_action)
    # 2) Update state tracker with the agent's action
    state_tracker.update_state_agent(agent_action)
    # print('agent_action_update',agent_action)
    # 3) User takes action given agent action
    user_action, reward, done, success = user.step(agent_action)
    # print('user_action',user_action)
    # print('reward',reward)
    # print('done',done)
    # print('success',success)
    if not done:
        # 4) Infuse error into semantic frame level of user action
        emc.infuse_error(user_action)
    # 5) Update state tracker with user action
    state_tracker.update_state_user(user_action)
    # print('update_state_user',user_action)
    # 6) Get next state and add experience
    next_state = state_tracker.get_state(done)
    dqn_agent.add_experience(state, agent_action_index, reward, next_state, done)
    return next_state, reward, done, success


def warmup_run():
    """
    Runs the warmup stage of training which is used to fill the agents memory.

    The agent uses it's rule-based policy to make actions. The agent's memory is filled as this runs.
    Loop terminates when the size of the memory is equal to WARMUP_MEM or when the memory buffer is full.

    """

    print('Warmup Started...')
    total_step = 0
    
    with tqdm(total=WARMUP_MEM) as pbar:
#         total_step = 0 
        while total_step != WARMUP_MEM and not dqn_agent.is_memory_full():
#             counter += 1
            # Reset episode
            episode_reset()
            done = False
            # Get initial state from state tracker
            state = state_tracker.get_state()
            # print('State warmup', state)
            # print('Type state',type(state))
            while not done:
                next_state, _, done, _ = run_round(state, warmup=True)
                total_step += 1
                state = next_state
            # print(total_step)
            # time.sleep(1)
#             pbar.update(1)
    print('...Warmup Ended')


def train_run():
    """
    Runs the loop that trains the agent.

    Trains the agent on the goal-oriented chatbot task. Training of the agent's neural network occurs every episode that
    TRAIN_FREQ is a multiple of. Terminates when the episode reaches NUM_EP_TRAIN.

    """

    print('Training Started...')
    episode = 0
    period_reward_total = 0
    period_success_total = 0
    success_rate_best = 0.76
    while episode < NUM_EP_TRAIN:
        episode_reset()
        episode += 1
        print('Current episode:',episode)
        done = False
        state = state_tracker.get_state()
        # print('Current state',state)
        while not done:
            next_state, reward, done, success = run_round(state)
            # print('Next state',next_state)
            # print('Reward', reward)
            # print('Done',done)
            period_reward_total += reward
            state = next_state
        print("success :{0}".format(success))

        period_success_total += success
        print('period_success_total',period_success_total)
        # Train
        if episode % TRAIN_FREQ == 0:
            # Check success rate
            success_rate = period_success_total / TRAIN_FREQ
            avg_reward = period_reward_total / TRAIN_FREQ
            print("episode {0}: success rate: {1}".format(episode, success_rate))

            # Flush
            if success_rate >= success_rate_best and success_rate >= SUCCESS_RATE_THRESHOLD:
                dqn_agent.empty_memory()
            # Update current best success rate
            if success_rate > success_rate_best:
                print('Episode: {} NEW BEST SUCCESS RATE: {} Avg Reward: {}' .format(episode, success_rate, avg_reward))
                success_rate_best = success_rate
                dqn_agent.save_weights()
            period_success_total = 0
            period_reward_total = 0
            # Copy
            dqn_agent.copy()
            # Train
            dqn_agent.train()
    print('...Training Ended')


def episode_reset():
    """
    Resets the episode/conversation in the warmup and training loops.

    Called in warmup and train to reset the state tracker, user and agent. Also get's the initial user action.

    """

    # First reset the state tracker
    state_tracker.reset()
    # Then pick an init user action
    user_action = user.reset()
    # Infuse with error
    emc.infuse_error(user_action)
    # And update state tracker
    state_tracker.update_state_user(user_action)
    # Finally, reset agent
    dqn_agent.reset()

In [6]:
CONSTANTS_FILE_PATH = f'{FOLDER_PATH}/constants.json'
constants_file = CONSTANTS_FILE_PATH

with open(constants_file) as f:
    constants = json.load(f)

In [7]:
constants

{'agent': {'batch_size': 16,
  'dqn_hidden_size': 80,
  'epsilon_init': 0.0,
  'gamma': 0.9,
  'learning_rate': 0.001,
  'load_weights_file_path': '',
  'max_mem_size': 500000,
  'save_weights_file_path': '',
  'vanilla': True},
 'db_file_paths': {'database': 'data/movie_db.pkl',
  'dict': 'data/movie_dict.pkl',
  'user_goals': 'data/movie_user_goals.pkl'},
 'emc': {'intent_error_prob': 0.0,
  'slot_error_mode': 0,
  'slot_error_prob': 0.05},
 'run': {'max_round_num': 20,
  'num_ep_run': 40000,
  'success_rate_threshold': 0.3,
  'train_freq': 100,
  'usersim': True,
  'warmup_mem': 1000}}

In [8]:
# Load file path constants
file_path_dict = constants['db_file_paths']
DATABASE_FILE_PATH = file_path_dict['database']
DICT_FILE_PATH = file_path_dict['dict']
USER_GOALS_FILE_PATH = file_path_dict['user_goals']

In [9]:
run_dict = constants['run']
USE_USERSIM = run_dict['usersim']
WARMUP_MEM = run_dict['warmup_mem']
NUM_EP_TRAIN = run_dict['num_ep_run']
TRAIN_FREQ = run_dict['train_freq']
MAX_ROUND_NUM = run_dict['max_round_num']
SUCCESS_RATE_THRESHOLD = run_dict['success_rate_threshold']

In [10]:
# database= json.load(open(DATABASE_FILE_PATH,encoding='utf-8'))
database = pickle.load(open(DATABASE_FILE_PATH, 'rb'), encoding='latin1')
remove_empty_slots(database)

In [12]:
type(database)

dict

In [None]:
database

In [None]:
# db_dict = json.load(open(DICT_FILE_PATH,encoding='utf-8'))[0]
db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')

In [None]:
db_dict

In [14]:
# user_goals = json.load(open(USER_GOALS_FILE_PATH,encoding='utf-8'))
user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'), encoding='latin1')

In [16]:
user_goals

[{'diaact': 'request',
  'inform_slots': {'city': 'birmingham',
   'date': 'today',
   'moviename': 'zootopia',
   'numberofpeople': '1',
   'starttime': 'around 2pm',
   'state': 'al',
   'theater': 'carmike summit 16'},
  'request_slots': {}},
 {'diaact': 'request',
  'inform_slots': {'city': 'seattle',
   'date': 'tomorrow',
   'moviename': 'deadpool',
   'numberofpeople': '2',
   'starttime': '9:00 pm',
   'theater': 'amc pacific place 11 theater'},
  'request_slots': {}},
 {'diaact': 'request',
  'inform_slots': {'city': 'birmingham',
   'date': 'today',
   'moviename': 'deadpool',
   'numberofpeople': '4',
   'starttime': 'around 6pm',
   'state': 'al',
   'theater': 'carmike summit 16'},
  'request_slots': {}},
 {'diaact': 'request',
  'inform_slots': {'city': 'seattle',
   'date': 'tomorrow',
   'moviename': 'zootopia',
   'numberofpeople': '2',
   'starttime': '9:10 pm',
   'theater': 'regal meridian 16'},
  'request_slots': {}},
 {'diaact': 'request',
  'inform_slots': {'city

In [None]:
if USE_USERSIM:
    user = UserSimulator(user_goals, constants, database)
else:
    user = User(constants)
emc = ErrorModelController(db_dict, constants)
state_tracker = StateTracker(database, constants)
dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)

In [None]:
warmup_run()

  0%|          | 0/1000 [00:00<?, ?it/s]

Warmup Started...


  0%|          | 0/1000 [00:00<?, ?it/s]

...Warmup Ended





In [None]:
train_run()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
success :False
period_success_total 32
Current episode: 38341
success :False
period_success_total 32
Current episode: 38342
success :True
period_success_total 33
Current episode: 38343
success :True
period_success_total 34
Current episode: 38344
success :True
period_success_total 35
Current episode: 38345
success :True
period_success_total 36
Current episode: 38346
success :False
period_success_total 36
Current episode: 38347
success :True
period_success_total 37
Current episode: 38348
success :True
period_success_total 38
Current episode: 38349
success :True
period_success_total 39
Current episode: 38350
success :True
period_success_total 40
Current episode: 38351
success :False
period_success_total 40
Current episode: 38352
success :True
period_success_total 41
Current episode: 38353
success :True
period_success_total 42
Current episode: 38354
success :False
period_success_total 42
Current episode: 38355
success :True
p