In [None]:
%run state_tracker.ipynb
%run elder.ipynb
%run dqn_agent.ipynb
%run reward.ipynb
import csv
import os

import matplotlib.pyplot as plt

In [None]:
# --- Run configuration ------------------------------------------------------
# NOTE(review): USE_USERSIM is not referenced by the visible cells — confirm it is still needed.
USE_USERSIM = True
WARMUP_MEM = 4000             # round budget for warmup_run()

NUM_EP_TRAIN = 5000           # total training episodes (also tried: 10000, 12000)
TRAIN_FREQ = 100              # episodes per stats/training period (also tried: 50)
MAX_ROUND_NUM = 25            # hard cap on rounds per episode (also tried: 20, 15, 10)
# Also tried: 0.3, 0.7, 0.4. NOTE(review): only referenced by commented-out code in this notebook.
SUCCESS_RATE_THRESHOLD = 0.5


# --- Simulation components (kept in this construction order) ---------------
user = Elder()                 # simulated elder (the "user" side of the dialogue)
dqn_agent = DQNAgent()         # the learning helper agent
state_tracker = StateTracker() # records both sides' actions and exposes the current state
reward_func = Reward()         # scores transitions during run_round()

In [None]:
def run_round(state, round_num, c, episode, warmup=False):
    """
    Play a single round of the conversation and record the transition.

    The agent chooses an action from the current state, the simulated elder
    replies, and the state tracker is updated with both actions in turn.

    Args:
        state: ``[user_state, agent_state]`` as built from the state tracker;
            the agent acts on ``state[0]`` and ``state[1][1]``.
        round_num: zero-based index of the round within the episode.
        c: caller-supplied count of rounds in which the expert flag was set,
            used only to limit debug printing.
            NOTE(review): the increment below is local and is never returned,
            so callers always keep their own value. Both callers pass 0, so
            ``c`` is 1 at the check and the ``c <= 6`` limit never trips —
            confirm intent.
        episode: current episode number, used only to limit debug printing.
        warmup: forwarded to ``dqn_agent.get_action``.

    Returns:
        ``(next_state, reward, done, success)``. ``reward`` is 0 except on rounds
        with ``round_num % 3 == 0``. Only on those rounds is the transition
        scored and added to the agent's memory.
    """
    user_state = state[0]
    helper_state = state[1][1]

    # Agent turn: act on the elder's latest action and the helper's own state.
    net_input = dqn_agent.prep_input(user_state, helper_state)
    expert_output, exp_flag = dqn_agent.get_expert_action(user_state, helper_state)
    agent_output = dqn_agent.get_action(net_input, user_state, helper_state, warmup)
    state_tracker.update_state_agent(agent_output)

    # Elder turn: the simulated user answers the agent's action.
    elder_output = user.elder(user.prep_input())
    state_tracker.update_state_user(elder_output)

    next_state = [state_tracker.get_state_user(), state_tracker.get_state_agent()]
    _, _, done, success, forbidden = state_tracker.get_state_agent()

    reward = 0
    scored_round = round_num % 3 == 0
    if scored_round:
        reward = reward_func.Reward(state, next_state, success, done, forbidden, round_num + 1)
        dqn_agent.add_experience(state, agent_output, reward, next_state, done, success)
        dqn_agent.num_samples_epoch += 1

    if not exp_flag:
        return next_state, reward, done, success

    c += 1
    # Debug trace, limited to episodes whose number mod 100 is 0..5.
    if c <= 6 and episode % 100 <= 5:
        print('episode:', episode, 'user_output:', state[0]['action_out'].item(),
              'agent_state:', state[1][1], 'agent_output:', agent_output['action_out'].item(),
              'expert_output:', torch.argmax(expert_output['action_out']).item(),
              )
    return next_state, reward, done, success

def warmup_run():
    """
    Run the warmup stage that pre-fills the agent's memory before training.

    Episodes are played with ``warmup=True`` forwarded to the agent. This
    presumably selects its rule-based policy — confirm in dqn_agent.ipynb.
    Each episode lasts until a round reports ``done`` or MAX_ROUND_NUM rounds
    have been played.

    Stopping rule: before each new episode, the loop checks two conditions.
    It stops if WARMUP_MEM rounds have been played in total, or if
    ``dqn_agent.is_memory_full()`` is true.

    Caveats:
        - This counts rounds, not stored transitions. run_round only stores a
          transition when ``round_num % 3 == 0``.
        - The check happens between episodes, so the round total may overshoot
          WARMUP_MEM by up to one episode.
    """
    print('Warmup Started...')
    episode = 0
    rounds_played = 0
    while rounds_played < WARMUP_MEM and not dqn_agent.is_memory_full():
        episode_reset()
        episode += 1
        state = [state_tracker.get_state_user(), state_tracker.get_state_agent()]
        round_num = 0
        while True:
            # The expert-action counter handed to run_round starts at 0.
            state, _, done, _ = run_round(state, round_num, 0, episode, warmup=True)
            rounds_played += 1
            round_num += 1
            if done or round_num >= MAX_ROUND_NUM:
                break

    print('...Warmup Ended')


def _plot_training_curves(epoch_turns, epoch_s_rate, epoch_losses, epoch_reward, train_freq):
    """Draw the two learning-curve figures (turns/success rate and loss/reward) and show them."""
    # One point per period, placed at the episode where that period ended.
    period_x = np.arange(1, len(epoch_turns) + 1) * train_freq

    _, ax_left = plt.subplots()
    ax_right = ax_left.twinx()
    ax_left.plot(period_x, epoch_turns, 'g')
    ax_right.plot(period_x, epoch_s_rate, 'r')
    ax_left.set_xlabel('Episode')
    ax_left.set_ylabel('Average Turns')
    ax_right.set_ylabel('Success Rate')

    _, ax_left1 = plt.subplots()
    ax_right1 = ax_left1.twinx()
    # epoch_losses carries a leading None placeholder, so its x axis starts at 0.
    ax_left1.plot(np.arange(len(epoch_losses)) * train_freq, epoch_losses, 'g')
    ax_right1.plot(period_x, epoch_reward, 'r')
    ax_left1.set_xlabel('Episode')
    ax_left1.set_ylabel('Training Loss')
    ax_right1.set_ylabel('Average Reward')
    plt.show()


def train_run(eps_decay_threshold=0.7, eps_decay_factor=0.95, early_stop_patience=5):
    """
    Train the DQN agent for NUM_EP_TRAIN episodes.

    Statistics are accumulated over periods of TRAIN_FREQ episodes. At the end
    of every period the function does the following, in order:
        1. Reports success rate, average turns and average reward.
        2. Updates the "early stop" counter and may write an extra checkpoint.
        3. Saves weights when the success rate reaches a new best.
        4. Calls ``dqn_agent.copy()`` on a schedule.
        5. Runs one ``dqn_agent.train()`` call.
        6. Every 5*TRAIN_FREQ episodes, plots the learning curves.

    Args:
        eps_decay_threshold: minimum success rate at which a new best also
            multiplies ``dqn_agent.eps`` by ``eps_decay_factor``.
        eps_decay_factor: multiplicative epsilon decay applied on such a new best.
        early_stop_patience: number of periods below the best success rate
            (counted since the last new best) after which an extra
            ``beh_model_early_<n>.pt`` checkpoint is written every period.
            Training is NOT stopped.

    Returns:
        ``(epoch_losses, epoch_reward, epoch_turns, epoch_s_rate)``.
        epoch_reward, epoch_turns and epoch_s_rate have one entry per period.
        epoch_losses starts with a ``None`` placeholder followed by one entry
        per period.
    """
    print('Training Started...')
    period_reward_total = 0
    period_success_total = 0
    period_turns_total = 0
    success_rate_best = 0.0
    early_stop_flag = 0    # periods below the best success rate since the last new best
    flag_turn_count = 0    # number of new-best events; used in early-checkpoint file names

    epoch_losses = [None]  # mean training loss per period, after a leading placeholder
    epoch_reward = []
    epoch_turns = []
    epoch_s_rate = []

    for episode in range(1, NUM_EP_TRAIN + 1):
        episode_reset()
        state = [state_tracker.get_state_user(), state_tracker.get_state_agent()]
        round_num = 0
        while True:
            # The expert-action counter handed to run_round starts at 0.
            state, reward, done, success = run_round(state, round_num, 0, episode, warmup=False)
            period_reward_total += reward
            round_num += 1
            if done or round_num >= MAX_ROUND_NUM:
                break
        period_success_total += success
        period_turns_total += round_num

        if episode % TRAIN_FREQ != 0:
            continue

        # ---- End of a TRAIN_FREQ-episode period -----------------------------
        print('\nepisode=', episode, ' num_samples=', dqn_agent.num_samples_epoch,
              ' mem_ind=', dqn_agent.memory_index)

        success_rate = period_success_total / TRAIN_FREQ
        avg_reward = period_reward_total / TRAIN_FREQ
        avg_turns = period_turns_total / TRAIN_FREQ
        epoch_s_rate.append(success_rate)
        epoch_reward.append(avg_reward)
        epoch_turns.append(avg_turns)
        print('success_rate={} avg_turns={} avg_reward={} prev_best_success_rate={}'.format(
            success_rate, avg_turns, avg_reward, success_rate_best))

        if success_rate < success_rate_best:
            early_stop_flag += 1
        if early_stop_flag >= early_stop_patience:
            # Writes an extra checkpoint every period once patience is exhausted.
            # Training itself continues.
            torch.save(dqn_agent.beh_model.state_dict(),
                       dqn_agent.save_weights_file_path + '/beh_model_early_' + str(flag_turn_count) + '.pt')
        print('early_stop_flag=', early_stop_flag)

        if success_rate > success_rate_best:
            print('Episode: {} NEW BEST SUCCESS RATE: {} Avg Reward: {}'.format(
                episode, success_rate, avg_reward))
            success_rate_best = success_rate
            early_stop_flag = 0
            flag_turn_count += 1

            dqn_agent.save_weights()
            dqn_agent.model_num += 1
            if success_rate >= eps_decay_threshold:
                dqn_agent.eps = eps_decay_factor * dqn_agent.eps

        period_success_total = 0
        period_reward_total = 0
        period_turns_total = 0

        # Periodic dqn_agent.copy() (presumably a target-network sync — confirm):
        # every 2 periods up to episode 1000, every 10 periods afterwards.
        copy_every = 2 * TRAIN_FREQ if episode <= 1000 else 10 * TRAIN_FREQ
        if episode % copy_every == 0:
            dqn_agent.copy()

        print('epsilon: {}'.format(dqn_agent.eps))
        # NOTE(review): the original threshold here was 15000; ">= 0" looks always
        # true, which would make the else branch dead — confirm.
        if dqn_agent.memory_index >= 0:
            batch_losses = dqn_agent.train()
            dqn_agent.save_weights()
            dqn_agent.model_num += 1
            dqn_agent.num_samples_epoch = 0
            if len(batch_losses) > 0:
                epoch_losses.append(np.mean(np.asarray(batch_losses)))
                if episode % (5 * TRAIN_FREQ) == 0:
                    _plot_training_curves(epoch_turns, epoch_s_rate, epoch_losses,
                                          epoch_reward, TRAIN_FREQ)
            else:
                # No batches trained this period: repeat the last loss to keep alignment.
                epoch_losses.append(epoch_losses[-1])
        else:
            epoch_losses.append(epoch_losses[-1])

    print('...Training Ended')
    print('Best Success Rate is: {}'.format(success_rate_best))

    return epoch_losses, epoch_reward, epoch_turns, epoch_s_rate


def episode_reset():
    """
    Start a fresh episode (conversation) for warmup_run and train_run.

    The steps run in this order:
        1. Clear the state tracker.
        2. Draw an initial action from the simulated elder via ``user.reset()``.
        3. Convert it with ``user.transform_da(..., init=True)``.
        4. Record it as the first user turn in the state tracker.
    """
    state_tracker.reset()
    opening_action = user.transform_da(user.reset(), init=True)
    state_tracker.update_state_user(opening_action)

In [None]:
#warmup_run()

In [None]:
# Display dqn_agent.memory_index before training (presumably how far the replay
# memory has been filled, e.g. by warmup_run() — confirm in dqn_agent.ipynb).
dqn_agent.memory_index

In [None]:
# Train the agent, then save the per-period learning curves to CSV.
epoch_losses, epoch_reward, epoch_turns, epoch_s_rate = train_run()

CURVES_CSV_PATH = '../results/curves_file.csv'
# Create the output folder before opening the file. Otherwise a missing
# directory would raise only after the long training run, losing its results.
os.makedirs(os.path.dirname(CURVES_CSV_PATH), exist_ok=True)
# The csv module requires newline='' to avoid spurious blank rows on Windows.
with open(CURVES_CSV_PATH, mode='w', newline='') as curves_file:
    file_writer = csv.writer(curves_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    # One row per curve: its name, then str() of the whole list in a single cell.
    for curve_name, curve_values in [('epoch_losses', epoch_losses),
                                     ('epoch_reward', epoch_reward),
                                     ('epoch_turns', epoch_turns),
                                     ('epoch_s_rate', epoch_s_rate)]:
        file_writer.writerow([curve_name, curve_values])