In [1]:
from gridworld import GridworldMdp
from agents import OptimalAgent, MyopicAgent, UncalibratedAgent
from mdp_interface import Mdp
from agent_runner import get_reward_from_trajectory, run_agent
import numpy as np

In [None]:
class Intervention:
    """Base class for policies that occasionally replace an agent's action with the optimal one.

    Subclasses implement ``will_intervene``; ``get_action`` does the bookkeeping of the
    remaining step budget (``steps_left``) and intervention budget (``interventions_left``).

    Args:
        trial_length: initial value of ``steps_left``.
        num_interventions: maximum number of times the optimal action may be substituted.
        gamma: forwarded to the internal ``OptimalAgent`` (presumably its discount factor).
    """

    def __init__(self, trial_length=10, num_interventions=3, gamma=0.9):
        self.steps_left = trial_length
        self.interventions_left = num_interventions
        self.optimal_agent = OptimalAgent(gamma=gamma)

    def set_mdp(self, mdp):
        """Point the internal optimal agent at ``mdp``."""
        self.optimal_agent.set_mdp(mdp)

    def get_optimal_action(self, state):
        """Return the action the internal ``OptimalAgent`` chooses in ``state``."""
        return self.optimal_agent.get_action(state)

    def will_intervene(self, state, agent):
        """Decide whether to override ``agent`` in ``state``; subclasses must implement this.

        Raises:
            NotImplementedError: always, on the base class.
        """
        # ``NotImplemented`` is a sentinel constant, not an exception class; calling it
        # raised a confusing TypeError instead of the intended NotImplementedError.
        raise NotImplementedError("Cannot call will_intervene for Intervention")

    def get_action(self, state, agent):
        """Return the action to execute in ``state``, updating both budgets.

        The optimal action is used when intervention budget remains and
        ``will_intervene`` returns True; otherwise the agent's own action is used.
        Every call consumes one step.
        """
        # Enforce the budget here so no subclass can overspend it. The short circuit
        # also skips will_intervene once the budget is exhausted. steps_left is still
        # decremented *after* will_intervene, which may read it.
        if self.interventions_left > 0 and self.will_intervene(state, agent):
            self.interventions_left -= 1
            self.steps_left -= 1
            return self.get_optimal_action(state)
        self.steps_left -= 1
        return agent.get_action(state)

In [None]:
class RandomIntervention(Intervention):
    """Intervenes at random steps, spreading the remaining budget over the remaining steps."""

    def will_intervene(self, state, agent):
        """Return True with probability ``interventions_left / steps_left``.

        ``state`` and ``agent`` are ignored; the decision depends only on the budgets
        and on ``np.random``.
        """
        if self.interventions_left <= 0:
            return False
        if self.steps_left <= 0:
            # No steps remain (get_action called more than trial_length times):
            # avoid ZeroDivisionError and do not intervene.
            return False
        prob = self.interventions_left / self.steps_left
        return np.random.rand() < prob

In [None]:
class StrategicIntervention(Intervention):
    """Intervenes when the agent's action looks much worse than the optimal action.

    Args:
        trial_length, num_interventions, gamma: forwarded to ``Intervention``.
        qval_threshold: intervene when the optimal action's Q-value exceeds the agent's
            action's Q-value by strictly more than this amount.
    """

    def __init__(self, trial_length=10, num_interventions=3, gamma=0.9, qval_threshold=2):
        super().__init__(trial_length=trial_length, num_interventions=num_interventions, gamma=gamma)
        self.qval_threshold = qval_threshold

    def will_intervene(self, state, agent):
        """Return True if the Q-value gap between optimal and agent actions exceeds the threshold.

        Both Q-values come from the internal ``OptimalAgent``.
        """
        if self.interventions_left <= 0:
            return False
        # Both calls need the current state; they previously omitted it and raised TypeError.
        # NOTE(review): when we do not intervene, the base get_action queries
        # agent.get_action(state) again. If the agent is stochastic, the executed action
        # may differ from the one evaluated here. Confirm agents are deterministic.
        agent_action = agent.get_action(state)
        optimal_action = self.get_optimal_action(state)
        # NOTE(review): extend_state_to_mu / qvalue semantics come from OptimalAgent; verify there.
        mu = self.optimal_agent.extend_state_to_mu(state)
        agent_qval = self.optimal_agent.qvalue(mu, agent_action)
        optimal_qval = self.optimal_agent.qvalue(mu, optimal_action)
        return optimal_qval - agent_qval > self.qval_threshold

In [None]:
def run_trial(agent, intervention, trial_length, gamma=0.9, mdp=None):
    """Run ``agent`` for ``trial_length`` steps under ``intervention`` and return the reward.

    Args:
        agent: the learner; must provide ``set_mdp``, ``get_action`` and ``inform_minibatch``.
        intervention: an ``Intervention`` that may replace the agent's actions.
            NOTE(review): its step/intervention budgets are not reset here, so use a
            fresh instance per trial.
        trial_length: number of environment steps to run.
        gamma: passed to ``get_reward_from_trajectory`` (presumably the discount factor).
        mdp: MDP to run on. If None, a new one is built with ``gen_random_connected()``.

    Returns:
        The value of ``get_reward_from_trajectory(trajectory, gamma)`` for the recorded
        ``(state, action, next_state, reward)`` tuples.
    """
    if mdp is None:
        # TODO(review): ``gen_random_connected`` is not imported in this notebook; it
        # presumably comes from ``gridworld`` (e.g. a GridworldMdp factory). Confirm.
        mdp = gen_random_connected()
    env = Mdp(mdp)
    agent.set_mdp(mdp)
    intervention.set_mdp(mdp)

    trajectory = []
    for _ in range(trial_length):
        curr_state = env.get_current_state()
        action = intervention.get_action(curr_state, agent)
        next_state, step_reward = env.perform_action(action)
        transition = (curr_state, action, next_state, step_reward)
        # The agent is informed of every transition, including intervened ones.
        agent.inform_minibatch(*transition)
        trajectory.append(transition)

    return get_reward_from_trajectory(trajectory, gamma)