In [None]:
# Shared state tracker instance; Elder.prep_input reads it as a module-level global.
# NOTE(review): StateTracker is not defined or imported in any visible cell, and this
# cell sits before the import cell -- it only works if the kernel already holds the
# class. Confirm where StateTracker comes from so Restart & Run All succeeds.
state_tracker = StateTracker()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class StateEstimator(nn.Module):
    """Feed-forward network with one shared trunk and five output heads.

    A 50-column feature row is projected 50 -> 32 -> ``hidden_size`` by two
    stacked linear layers, passed through dropout, and then fed to five
    independent linear heads whose outputs are clamped with ReLU.

    Args:
        hidden_size: Width of the shared layer that feeds every output head.
        hidden_size1: Stored as ``_hidden_size1``; no layer is built from it.
    """

    def __init__(self, hidden_size, hidden_size1):
        super().__init__()
        self._hidden_size, self._hidden_size1 = hidden_size, hidden_size1
        # Only ``_dummy1`` (width of the first projection) is used below;
        # ``_dummy`` is stored but never read in this class.
        self._dummy, self._dummy1 = 128, 32

        # The input width is spelled 29+13+8 (= 50), which equals the number of
        # columns concatenated by Elder.prep_input in this notebook.
        # NOTE(review): there is no activation between these two linear layers,
        # so together they form a single affine map. Left untouched because
        # changing it would alter the outputs of already-trained checkpoints.
        self._feat_hidden = nn.Linear(29 + 13 + 8, self._dummy1)
        self._hidden = nn.Linear(self._dummy1, hidden_size)

        self._dropout = nn.Dropout(p=0.5)

        # Attribute names double as state_dict keys, and creation order fixes
        # both key order and weight initialisation order -- keep both stable.
        self._ot_out = nn.Linear(hidden_size, 3)
        self._l_out = nn.Linear(hidden_size, 4)
        self._o_out = nn.Linear(hidden_size, 4)
        self._da_out = nn.Linear(hidden_size, 14)
        # 7 outputs (the original line carried a trailing "#9" remark).
        self._action_out = nn.Linear(hidden_size, 7)

    def forward(self, feature_input):
        """Run the trunk once and evaluate every head on its result.

        Args:
            feature_input: Tensor whose last dimension has 50 entries.

        Returns:
            dict: ``ot_out``, ``l_out``, ``o_out``, ``da_out`` and
            ``action_out`` (in that order), each the ReLU of the matching
            linear head applied to the shared hidden representation.
        """
        shared = self._dropout(self._hidden(self._feat_hidden(feature_input)))

        heads = (
            ("ot_out", self._ot_out),
            ("l_out", self._l_out),
            ("o_out", self._o_out),
            ("da_out", self._da_out),
            ("action_out", self._action_out),
        )
        return {key: F.relu(layer(shared)) for key, layer in heads}

In [None]:
class Elder():
    """The DQN agent that interacts with the user."""

    def __init__(self):
        self.Objects = {'cup':['cup0','cup1','cup2'], 'ball':['ball0','ball1','ball2']}
        self.Locations = ['c1', 'd1', 'd2']


    def elder(self, feature_input):
        """Run the pretrained simulator network on one feature tensor.

        The network is built and its weights are read from
        ``../sim_models/user_sim.pt`` on the first call only; the loaded,
        eval-mode model is then cached on the instance. Previously every call
        re-created the network and re-read the checkpoint from disk.

        Args:
            feature_input: Input tensor for ``StateEstimator``; it is moved to
                ``device`` before the forward pass.

        Returns:
            dict: The output dictionary produced by ``StateEstimator.forward``,
            computed under ``torch.no_grad()``.
        """
        # getattr keeps this working even though __init__ never declares the cache.
        model = getattr(self, "_cached_model", None)
        if model is None:
            model = StateEstimator(128, 64)
            # NOTE(review): torch.load unpickles the file, so only point this at
            # checkpoints from a trusted source.
            model.load_state_dict(torch.load("../sim_models/user_sim.pt", map_location=device))
            model.to(device=device)
            # Inference only: eval() switches the dropout layer off.
            model.eval()
            self._cached_model = model

        feature_input = feature_input.to(device=device)

        with torch.no_grad():
            output = model(feature_input)

        return output

    
    def transform_da(self, user_output, init=True):
        if init:
            user_output['action_out'] = torch.argmax(user_output['action_out'])
            action = user_output['action_out'].item()

            action2da = {}
            action2da[0] = 0
            action2da[1] = 1
            action2da[2] = 1
            action2da[3] = 1
            action2da[4] = 2
            action2da[5] = 6
            action2da[6] = 7

            output = {}
            output['action_out'] = torch.Tensor(11*[0]).to(device=device)
            output['da_out'] = torch.Tensor(14*[0]).to(device=device)

            output['action_out'][action] = 1
            output['da_out'][action2da[action]] = 1
            output['ot_out'] = user_output['ot_out']
            output['l_out'] = user_output['l_out']
            output['o_out'] = user_output['o_out']           

        return output
    
    def reset(self):
        output = {}
        output['action_out'] = torch.Tensor(7*[0]).to(device=device)
        output['action_out'][random.randint(0,6)] = 1
        output['da_out'] = torch.Tensor(14*[0]).to(device=device)
        output['da_out'][random.randint(0,13)] = 1      
        output['ot_out'] = torch.Tensor(3*[0]).to(device=device)
        output['ot_out'][random.randint(0,2)] = 1
        output['l_out'] = torch.Tensor(4*[0]).to(device=device)
        output['l_out'][random.randint(0,3)] = 1
        output['o_out'] = torch.Tensor(4*[0]).to(device=device)
        output['o_out'][random.randint(0,3)] = 1

        return output


    def prep_input(self):
        """Assemble the simulator's input feature row from the tracked state.

        Reads the module-level ``state_tracker``: ``get_state_agent()`` supplies
        the helper's last output (index 0) and a second entry whose first three
        items are bound to ``ot``, ``l`` and ``o``; ``get_state_user()`` supplies
        the elder's previous output.

        The returned row concatenates, in this order (widths in parentheses):
        prev_actor (2), ot_given (1), l_given (1), ot (2), l (3), o (3),
        hel_pt_target (2), hel_pt (3), hel_ho_target (2), hel_ho (3),
        ho_type (5), eld_sees (2), action (8), da (13) -- 50 columns in total.

        Returns:
            A tensor of shape (1, 50) moved to ``device``.
        """
        helper_state = state_tracker.get_state_agent()
        helper_output = helper_state[0]
        elder_output = state_tracker.get_state_user()
        # Only element 0 of ``ot`` and ``l`` is read below; ``o`` is bound but never used.
        # presumably these are object-type / location / object slots -- TODO confirm
        ot = helper_state[1][0]
        l = helper_state[1][1]
        o = helper_state[1][2]

        # Flag column: 1 when ot[0] is truthy, otherwise 0.
        elder_input_ot_given = torch.Tensor([[0]]).to(device=device)
        if ot[0]:
            elder_input_ot_given = torch.Tensor([[1]]).to(device=device)

        # Flag column: 1 when l[0] is truthy, otherwise 0.
        elder_input_l_given = torch.Tensor([[0]]).to(device=device)
        if l[0]:
            elder_input_l_given = torch.Tensor([[1]]).to(device=device)

        # Every remaining group starts as zeros. hel_pt_target, hel_pt and hel_ho
        # are never written in this method, so they always stay all-zero.
        elder_input_prev_actor = torch.Tensor([[0,0]]).to(device=device)
        elder_input_ot = torch.Tensor([[0]*2]).to(device=device)
        elder_input_l = torch.Tensor([[0]*3]).to(device=device)
        elder_input_o = torch.Tensor([[0]*3]).to(device=device)
        elder_input_hel_pt_target = torch.Tensor([[0]*2]).to(device=device)
        elder_input_hel_pt = torch.Tensor([[0]*3]).to(device=device)
        elder_input_hel_ho_target = torch.Tensor([[0]*2]).to(device=device)
        elder_input_hel_ho = torch.Tensor([[0]*3]).to(device=device)
        elder_input_ho_type = torch.Tensor([[0]*5]).to(device=device)
        #elder_input_eld_action = torch.Tensor([[0]*6]).to(device=device)
        elder_input_eld_sees = torch.Tensor([[0]*2]).to(device=device)

        elder_input_action = torch.Tensor([[0]*8]).to(device=device)
        elder_input_da = torch.Tensor([[0]*13]).to(device=device)

        # Fill in the elder's previous output when the tracker returned a truthy one.
        if elder_output:
            # prev_actor is [1,0] when the stored action_out is truthy and [0,1]
            # otherwise (which includes a stored value of 0).
            if elder_output['action_out']:
                elder_input_prev_actor = torch.Tensor([[1,0]]).to(device=device)
            else:
                elder_input_prev_actor = torch.Tensor([[0,1]]).to(device=device)

            # Stored ot/l/o values are used as 1-based indices into their group;
            # a falsy value leaves the group all-zero.
            if elder_output['ot_out']:
                elder_input_ot[0][elder_output['ot_out']-1] = 1
            if elder_output['l_out']:
                elder_input_l[0][elder_output['l_out']-1] = 1
            if elder_output['o_out']:
                elder_input_o[0][elder_output['o_out']-1] = 1
            #if elder_output['action_out']:
                #elder_input_eld_action[0][elder_output['action_out']-1] = 1

            elder_input_eld_sees[0][1] = 1 #how to update this?


        # Helper actions 3/5 and 4 each set one column of hel_ho_target and of ho_type.
        if helper_output and helper_output['action_out']:
            #if helper_output['action_out'] in [3,5]:
                #elder_input_hel_pt_target[0][1] = 1
            #if helper_output['action_out'] in [4]:
                #elder_input_hel_pt_target[0][0] = 1
            if helper_output['action_out'] in [3,5]:
                elder_input_hel_ho_target[0][1] = 1
                elder_input_ho_type[0][4-1] = 1
            if helper_output['action_out'] in [4]:
                elder_input_hel_ho_target[0][0] = 1
                elder_input_ho_type[0][3-1] = 1

        # One-hot the helper action (1-based id -> 0-based column) and set the
        # da column derived from that action id; a falsy action sets neither.
        if helper_output:
            if helper_output["action_out"]:
                elder_input_action[0][helper_output["action_out"].item()-1] = 1
            #if helper_output["da_out"]:
                if helper_output["action_out"].item() in [1,2]: #if helper action is req_ot--> da=3
                    elder_input_da[0][2] = 1
                elif helper_output["action_out"].item() in [3,4,5]: #if helper action is check--> da=4
                    elder_input_da[0][3] = 1
                elif helper_output["action_out"].item() == 6: #if helper action is ack--> da=2
                    elder_input_da[0][1] = 1
                elif helper_output["action_out"].item() == 7: #if helper action is yes--> da=6
                    elder_input_da[0][5] = 1
                elif helper_output["action_out"].item() == 8: #if helper action is no--> da=7
                    elder_input_da[0][6] = 1

        # Join all groups along dim 1 into a single (1, 50) row.
        feature_input = torch.cat((elder_input_prev_actor, elder_input_ot_given, elder_input_l_given,
                                   elder_input_ot, elder_input_l, elder_input_o, elder_input_hel_pt_target,
                                   elder_input_hel_pt, elder_input_hel_ho_target, elder_input_hel_ho,
                                   elder_input_ho_type, elder_input_eld_sees, elder_input_action, elder_input_da),
                                  1)

        feature_input = feature_input.to(device=device)

        return feature_input