In [28]:
from gridworld import GridworldMdp
from agents import OptimalAgent
from mdp_interface import Mdp
from agent_runner import get_reward_from_trajectory
import numpy as np

In [51]:
class Intervention:
    """Base class for an overseer that may override an agent's action.

    Keeps two counters: the number of steps left in the trial and the number
    of interventions still available. Subclasses decide *when* to intervene
    by implementing ``will_intervene``; when they do, the agent's action is
    replaced by the action chosen by an internal ``OptimalAgent``.
    """

    def __init__(self,trial_length=10,num_interventions=3,gamma=0.9):
        """Initialise the step/intervention budgets and the optimal agent.

        trial_length: number of steps the trial is expected to last.
        num_interventions: number of overrides available during the trial.
        gamma: forwarded to ``OptimalAgent`` (presumably the discount factor).
        """
        self.steps_left = trial_length
        self.interventions_left = num_interventions
        self.optimal_agent = OptimalAgent(gamma=gamma)

    def set_mdp(self,mdp):
        """Forward the MDP to the internal optimal agent."""
        self.optimal_agent.set_mdp(mdp)

    def get_optimal_action(self,state):
        """Return the internal optimal agent's action for ``state``."""
        return self.optimal_agent.get_action(state)

    def will_intervene(self,state,agent_action):
        """Decide whether to override ``agent_action`` in ``state``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplemented is a non-callable singleton used for binary-operator
        # dispatch; the exception type for abstract methods is
        # NotImplementedError.
        raise NotImplementedError("Cannot call will_intervene for Intervention")

    def get_action(self,state,agent_action):
        """Return the action to execute, after a possible intervention.

        Consumes one step of the trial on every call, and one intervention
        whenever ``will_intervene`` returns a truthy value.
        """
        intervene = self.will_intervene(state,agent_action)
        # A step is used up whether or not the agent's action is overridden.
        self.steps_left -= 1
        if intervene:
            self.interventions_left -= 1
            return self.get_optimal_action(state)
        return agent_action

In [52]:
class RandomIntervention(Intervention):
    """Intervention that overrides the agent at randomly chosen steps."""

    def will_intervene(self, state, agent_action):
        """Randomly decide whether to intervene on this step.

        The probability is ``interventions_left / steps_left``, so it rises
        as the remaining steps shrink relative to the remaining budget and
        reaches 1 when they are equal. ``state`` and ``agent_action`` are
        not used.
        """
        # If the trial runs past the configured trial length there are no
        # steps left to spread interventions over; avoid dividing by zero.
        if self.steps_left <= 0:
            return False
        prob = self.interventions_left / self.steps_left
        return np.random.rand() < prob

In [49]:
# Gridworld generation: arguments to GridworldMdp.generate_random_connected.
height, width = 8, 8
num_rewards = 4
noise = 0

# Passed to the agents and to get_reward_from_trajectory.
gamma = 0.9

# Experiment size: steps per trial, intervention budget per trial, and
# number of trials.
trial_length = 10
num_interventions = 3
num_trials = 20

In [53]:
def run_trial(agent, intervention):
    """Run one trial on a freshly generated gridworld and return its reward.

    A random connected gridworld is built from the module-level ``height``,
    ``width``, ``num_rewards`` and ``noise`` settings. For ``trial_length``
    steps the agent proposes an action, the intervention may replace it, the
    resulting action is executed in the environment, and the agent is told
    about the transition. The collected transitions are scored with
    ``get_reward_from_trajectory`` using the module-level ``gamma``.
    """
    gridworld = GridworldMdp.generate_random_connected(height, width, num_rewards, noise)
    env = Mdp(gridworld)
    # Both the agent and the intervention are given the same MDP.
    for participant in (agent, intervention):
        participant.set_mdp(gridworld)

    transitions = []
    for _step in range(trial_length):
        state = env.get_current_state()
        proposed = agent.get_action(state)
        # The intervention has the final say on what is actually executed.
        executed = intervention.get_action(state, proposed)
        next_state, step_reward = env.perform_action(executed)
        agent.inform_minibatch(state, executed, next_state, step_reward)
        transitions.append((state, executed, next_state, step_reward))

    return get_reward_from_trajectory(transitions, gamma)

    

In [54]:
# Run num_trials independent trials, each with a fresh agent and a fresh
# intervention, and collect the reward returned by run_trial for each.
results = []
for _ in range(num_trials):
    agent = OptimalAgent(gamma=gamma)
    # Pass trial_length explicitly so the intervention's step budget matches
    # the number of steps run_trial executes, instead of relying on the
    # constructor default (10) happening to equal the configured value.
    intervention = RandomIntervention(
        trial_length=trial_length,
        num_interventions=num_interventions,
        gamma=gamma,
    )
    reward = run_trial(agent,intervention)
    results.append(reward)
print(results)

[0.0, 12.262472396000001, 30.398624792000007, 5.436020897000002, 34.20184039100001, 27.633550391000007, 17.695858192999996, 0.0, 1.9492364960000008, 5.132568086000002, 18.410903593999997, 22.792193594, 6.486148313000001, 14.517993293, 3.836383487000002, 20.166358192999994, 14.467742594, 12.049626995, 13.182546797000002, 5.062472396000001, 32.27350919300001, 18.433862396, 7.213395797000002, 0.0, 18.410903593999997, 30.398624792000007, 26.595409193000002, 30.398624792000007, 22.136903594, 0.0, 0.0, 7.263646496000002, 34.20184039100001, 9.714148891999999, 11.120646797000001, 30.398624792000007, 0.0, 41.49994039100001, 24.559334792, 41.49994039100001, 18.433862396, 23.133550391000007, 27.633550391000007, 47.799940391000014, 16.679573792, 14.467742594, 44.095724792000006, 14.199397694000004, 8.833395797000001, 22.939334792]
