# In[1]:
import random, copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import time
import json

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")

class DQNAgent:
    """The DQN agent that interacts with the user."""

    def __init__(self):
        """
        Build a DQN agent: hyperparameters, replay memory, rule table and networks.

        Side effects, in order: reads the JSON rule table from '../hel_rule_data',
        creates the behavior and target networks via ``_build_model``, creates the
        loss modules and Adam optimizers, and finally calls ``_load_weights``.

        Raises:
            ValueError: if ``max_memory_size`` is smaller than ``batch_size``.
        """
        # Replay memory and its bookkeeping.
        self.memory, self.balanced_memory = [], []
        self.num_samples_epoch = 0
        self.memory_index = 0
        self.max_memory_size = 20000  # other sizes tried: 200000, 125000, 15000, 10000, 5000

        # Learning hyperparameters (alternatives that were tried are noted per line).
        self.eps = 0.2  # exploration probability used by get_action
        self.lr = 1e-6  # also tried 1e-8 and 1e-3
        self.gamma = 0.9  # discount factor; also tried 0.5
        self.batch_size = 32  # also tried 16
        self.hidden_size_1 = 32
        self.hidden_size_2 = 8
        self.model_num = 0  # number embedded in the file name written by save_weights

        # Rule table consulted by _rule_action / get_expert_action.
        with open('../hel_rule_data', 'r') as rule_file:
            self.rule_data = json.load(rule_file)

        # Directories used when restoring / storing checkpoints.
        self.load_weights_file_path = "../dqn_models"
        self.save_weights_file_path = "../dqn_models"

        if self.batch_size > self.max_memory_size:
            raise ValueError('Max memory size must be at least as great as batch size!')

        # Network input width: actor (1) + elder action (6) + elder da (13)
        # + helper ot (2) + helper l (2) + helper o (2) = 26.
        self.state_size = 26
        self.num_action = 9  # helper actions are numbered 0-8
        self.num_da = 14  # helper das are numbered 0-13

        # Behavior network first, target network second.
        self.beh_model = self._build_model()
        self.tar_model = self._build_model()

        # Loss modules kept on the agent.
        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()
        self.huber = nn.HuberLoss(reduction='mean', delta=1.0)
        self.mae = nn.L1Loss()

        # One Adam optimizer per network, sharing the same learning rate.
        self.beh_optimizer = torch.optim.Adam(self.beh_model.parameters(), lr=self.lr)
        self.tar_optimizer = torch.optim.Adam(self.tar_model.parameters(), lr=self.lr)

        self._load_weights()



    def _build_model(self):
        """Create a new ``helper_model`` network and move it to the module-level ``device``."""
        network = helper_model(
            self.hidden_size_1,
            self.hidden_size_2,
            self.state_size,
            self.num_action,
        )
        return network.to(device)


    def get_action(self, agent_input, elder_output, helper_ot_l_o, warmup):
        """
        Returns the action of the agent given a state.

        """
        if self.eps > random.random():
            agent_output = {}
            #agent_output['da_out'] = torch.Tensor([0]*self.num_da).to(device=device)
            #index_da = random.randint(0, self.num_da - 1)
            #agent_output['da_out'][index_da] = 1
            
            agent_output['action_out'] = torch.Tensor([0]*self.num_action).to(device=device)
            index_action = random.randint(0, self.num_action - 1)
            agent_output['action_out'][index_action] = 1

            return agent_output
        else:
            if warmup:
                return self._rule_action(elder_output, helper_ot_l_o)
            else:
                return self._dqn_action(agent_input)

    def prep_input(self, elder_output, helper_ot_l_o):
        helper_input_ot = torch.Tensor(2*[0]).to(device=device)
        helper_input_l = torch.Tensor(2*[0]).to(device=device)        
        helper_input_o = torch.Tensor(2*[0]).to(device=device)
        
        helper_input_actor = torch.Tensor(1*[0]).to(device=device)
        helper_input_eld_action = torch.Tensor(6*[0]).to(device=device)
        helper_input_eld_da = torch.Tensor(13*[0]).to(device=device)

        if helper_ot_l_o[0][1] == 1:
            helper_input_ot[0] = 1
        elif helper_ot_l_o[0][1] == 2:
            helper_input_ot[1] = 1
            
        if helper_ot_l_o[1][1] == 1:
            helper_input_l[0] = 1
        elif helper_ot_l_o[1][1] == 2:
            helper_input_l[1] = 1

        if helper_ot_l_o[2][1] == 1:
            helper_input_o[0] = 1
        elif helper_ot_l_o[2][1] == 2:
            helper_input_o[1] = 1
                
        if elder_output['action_out']:
            helper_input_eld_action[elder_output['action_out']-1] = 1
        else:
            helper_input_actor = torch.Tensor(1*[1]).to(device=device)
            
        if elder_output['da_out']:
            helper_input_eld_da[elder_output['da_out']-1] = 1

        
        helper_input = torch.cat((helper_input_ot, helper_input_l, helper_input_o,
                                  helper_input_actor, helper_input_eld_action, helper_input_eld_da), 0)
        
        return helper_input

    def _rule_action(self, elder_output, helper_ot_l_o):
        #print(elder_output, helper_ot_l_o)
        hel_action_out = {}
        hel_action_out['action_out'] = torch.Tensor(9*[0]).to(device=device)
        indx__ = [0]
        for datapoint in self.rule_data:
            #print(elder_output['action_out'].item(), datapoint['eld_action'])
            if (elder_output['da_out'].item() == datapoint['eld_da'] and
                elder_output['action_out'].item() == datapoint['eld_action'] and
                helper_ot_l_o[0][1] == datapoint['ot'] and helper_ot_l_o[1][1] == datapoint['l'] 
                and helper_ot_l_o[2][1] == datapoint['o']):
                
                indx__.append(datapoint['hel_action'])
                
        #print(indx__)
        exp_flag = False
        if indx__ != [0]:
            exp_flag = True
        indx_ = indx__[random.randint(0,len(indx__)-1)]        
        hel_action_out['action_out'][indx_] = 1
        #print(hel_action_out)
        return hel_action_out
    
    def get_expert_action(self, elder_output, helper_ot_l_o):
        #print(elder_output, helper_ot_l_o)
        hel_action_out = {}
        hel_action_out['action_out'] = torch.Tensor(9*[0]).to(device=device)
        indx__ = [0]
        for datapoint in self.rule_data:
            #print(elder_output['action_out'].item(), datapoint['eld_action'])
            if (elder_output['da_out'].item() == datapoint['eld_da'] and
                elder_output['action_out'].item() == datapoint['eld_action'] and
                helper_ot_l_o[0][1] == datapoint['ot'] and helper_ot_l_o[1][1] == datapoint['l'] 
                and helper_ot_l_o[2][1] == datapoint['o']):
                
                indx__.append(datapoint['hel_action'])
                
        #print(indx__)
        exp_flag = False
        if indx__ != [0]:
            exp_flag = True
        indx_ = indx__[random.randint(0,len(indx__)-1)]        
        hel_action_out['action_out'][indx_] = 1
        #print(hel_action_out)
        return hel_action_out, exp_flag
                
      

    def _dqn_action(self, state):
        """
        Returns a behavior model output given a state.

        """
        agent_output = self._dqn_predict_one(state)
        return agent_output


    def _dqn_predict_one(self, state, target=False):
        agent_output = self._dqn_predict(state, target)
        return agent_output

    def _dqn_predict(self, states, target=False):
        """
        Returns a model prediction given an array of states.

        """
        if target:
            return self.tar_model(states)
        else:
            return self.beh_model(states)

    def add_experience(self, state, action, reward, next_state, done, success):
        """
        Adds an experience made of the parameters to the memory.

        """

        if len(self.memory) < self.max_memory_size:
            self.memory.append(None)
        self.memory[self.memory_index] = (state, action, reward, next_state, done, success)
        self.memory_index = (self.memory_index + 1) % self.max_memory_size

        
        
    def empty_memory(self):
        """Empties the memory and resets the memory index."""

        self.memory = []
        self.balanced_memory = []
        self.memory_index = 0

    def is_memory_full(self):
        """Returns true if the memory is full."""

        return len(self.memory) == self.max_memory_size

    def train(self):
        """
        Trains the agent by improving the behavior model given the memory tuples.

        Takes batches of memories from the memory pool and processing them. The processing takes the tuples and stacks
        them in the correct format for the neural network and calculates the Bellman equation for Q-Learning.

        """
        # Calc. num of batches to run
        num_batches = len(self.memory) // self.batch_size
        print("num_batches=", num_batches)
        batchlosses_of_oneepoch = [] # each element is the loss of one batch

        for b in range(num_batches):
            batch = random.sample(self.memory, self.batch_size)
            #for sample in batch:
              #print(sample)

            states = [sample[0] for sample in batch]
            next_states = [sample[3] for sample in batch]

            beh_state_preds = []
            for state in states:
                beh_state_preds.append(self._dqn_predict_one(self.prep_input(state[0], state[1][1])))
            

            #if not self.vanilla:
             #   beh_next_states_preds = []  # For indexing for DDQN
              #  for state in next_states:
               #     beh_next_states_preds.append(self._dqn_predict(self.prep_input(state[0], state[1])))
            

            tar_next_state_preds = []
            for state in next_states:
                tar_next_state_preds.append(self._dqn_predict(self.prep_input(state[0], state[1][1]), target=True))  # For target value for DQN (& DDQN)
            

            inputs = []
            targets = []
            

            for i, (s, a, r, s_, d, success) in enumerate(batch):
                t = beh_state_preds[i]                                    
                #t['da_out'] = r + self.gamma * tar_next_state_preds[i]['da_out']
                t['action_out'] = r + self.gamma * (tar_next_state_preds[i]['action_out'])*(not success)
                
                inputs.append(self.prep_input(s[0],s[1][1]))
                targets.append(t)

            #das_t = [targets[j]["da_out"] for j in range(len(targets))]
            actions_t = [targets[j]["action_out"] for j in range(len(targets))]
            total_train_loss = 0
            self.beh_model.train()            
            for i in range(len(inputs)):
                to_input = inputs[i]
                #da_t = das_t[i]
                action_t = actions_t[i]
                self.beh_model.zero_grad()
                self.beh_optimizer.zero_grad()
                to_input = to_input.to(device=device)
                output = self.beh_model(to_input)
                loss = torch.zeros(1).to(device=device)
                loss += self.mae(output["action_out"], action_t.to(device=device))
                #loss += self.mae(output["da_out"], da_t.to(device=device))
                total_train_loss += loss.item()
                loss.backward()
                self.beh_optimizer.step()

            batchlosses_of_oneepoch.append( total_train_loss)

        print("Training loss:", np.mean(np.asarray(batchlosses_of_oneepoch)))
        return batchlosses_of_oneepoch



    def copy(self):
        """Copies the behavior model's weights into the target model's weights."""
        self.tar_model.load_state_dict(self.beh_model.state_dict())
        #self.tar_model.eval() #??????????????????????????????????????????
        torch.save(self.tar_model.state_dict(), '../dqn_models/tar_model.pt')
        self.tar_model.to(device=device)
        #self.tar_model.eval() #??????????????????????????????????????????


    def save_weights(self):
        """Saves the weights of both models in two files."""
        if not self.save_weights_file_path:
            return
        beh_save_file_path = self.save_weights_file_path
        #self.beh_model.eval() #??????????????????????????????????????????
        torch.save(self.beh_model.state_dict(), beh_save_file_path + '/beh_model_'+ str(self.model_num) +'.pt')
        self.load_weights_file_path = self.save_weights_file_path


    def _load_weights(self):
        """Loads the weights of both models from two h5 files."""

        if not self.load_weights_file_path:
            return
        beh_load_file_path = self.load_weights_file_path
        self.beh_model.load_state_dict(torch.load(beh_load_file_path + '/beh_model.pt'))
        self.beh_model.to(device=device)
        #self.beh_model.eval()

        tar_load_file_path = self.load_weights_file_path        
        self.tar_model.load_state_dict(torch.load(tar_load_file_path + '/tar_model.pt'))
        self.tar_model.to(device=device)
        #self.tar_model.eval()
        

class helper_model(nn.Module):
    """
    Feed-forward network that scores the helper's actions.

    Pipeline: Linear(num_in, size1) -> Linear(size1, size2) -> Dropout(p=0.1)
    -> Linear(size2, num_action_out) -> ReLU.

    NOTE(review): there is no activation between the first two linear layers,
    and the final ReLU keeps every output at or above zero -- confirm both are
    intended.
    """

    def __init__(self, size1, size2, num_in, num_action_out):
        """
        Args:
            size1: width of the first hidden layer.
            size2: width of the second hidden layer.
            num_in: length of the input feature vector.
            num_action_out: number of action outputs.
        """
        super().__init__()
        self._size1, self._size2 = size1, size2
        self._layer1 = nn.Linear(num_in, size1)
        self._layer2 = nn.Linear(size1, size2)
        self._dropout = nn.Dropout(p=0.1)
        self._action_out = nn.Linear(size2, num_action_out)

    def forward(self, feature_input):
        """Return ``{"action_out": tensor}`` for the given feature tensor."""
        hidden = self._dropout(self._layer2(self._layer1(feature_input)))
        return {"action_out": F.relu(self._action_out(hidden))}