diff --git a/src/algorithms/epsilon_greedy_q_estimator.py b/src/algorithms/epsilon_greedy_q_estimator.py
index 3a15bb9..5442f68 100644
--- a/src/algorithms/epsilon_greedy_q_estimator.py
+++ b/src/algorithms/epsilon_greedy_q_estimator.py
@@ -8,6 +8,7 @@ from src.utils.mixins import WithEstimatorMixin
 from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonGreedyConfig
+from src.exceptions.exceptions import InvalidParamValue
 
 StateActionVec = TypeVar('StateActionVec')
 State = TypeVar('State')
@@ -42,6 +43,18 @@ def __init__(self, config: EpsilonGreedyQEstimatorConfig):
         self.gamma: float = config.gamma
         self.env: Env = config.env
         self.weights: np.array = None
+        self.initialize()
+
+    def initialize(self) -> None:
+        """Initialize the underlying weights
+
+        Returns
+        -------
+
+        None
+
+        """
+        self.weights: np.array = np.zeros((self.env.n_states * self.env.n_actions))
 
     def q_hat_value(self, state_action_vec: StateActionVec) -> float:
         """Returns the
@@ -60,6 +73,10 @@ def q_hat_value(self, state_action_vec: StateActionVec) -> float:
 
         """
+
+        if self.weights is None:
+            raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")
+
         return self.weights.dot(state_action_vec)
 
     def update_weights(self, total_reward: float, state_action: Action,
@@ -81,6 +98,10 @@ def update_weights(self, total_reward: float, state_action: Action,
 
         None
         """
+
+        if self.weights is None:
+            raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")
+
         v1 = self.q_hat_value(state_action_vec=state_action)
         v2 = self.q_hat_value(state_action_vec=state_action_)
         self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action
@@ -99,14 +120,18 @@ def on_state(self, state: State) -> Action:
         An environment specific Action type
 
         """
-        # compute the state values related to
-        # the given state
+        # get the approximation of the q-values
+        # given the state
+
         q_values = []
-        for action in range(self.env.n_actions):
-            state_action_vector = self.env.get_state_action_tile(action=action, state=state)
-            q_values.append(state_action_vector)
+        for a in range(self.env.n_actions):
+            tiled_vector = self.env.featurize_state_action(action=a, state=state)
+            q_values.append(self.q_hat_value(tiled_vector))
 
         # choose an action at the current state
         action = self.eps_policy(q_values, state)
+
+        # the policy returns an integer index; get the corresponding ActionBase
+        action = self.env.get_action(action)
         return action
diff --git a/src/algorithms/semi_gradient_sarsa.py b/src/algorithms/semi_gradient_sarsa.py
index 416fbd2..00b6e0b 100644
--- a/src/algorithms/semi_gradient_sarsa.py
+++ b/src/algorithms/semi_gradient_sarsa.py
@@ -11,8 +11,10 @@ from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase, WithEstimatorMixin
 from src.utils.episode_info import EpisodeInfo
+from src.spaces.time_step import TimeStep
 from src.exceptions.exceptions import InvalidParamValue
 
+
 Policy = TypeVar('Policy')
 Env = TypeVar('Env')
 State = TypeVar('State')
@@ -38,11 +40,16 @@ class SemiGradSARSA(object):
     def __init__(self, config: SemiGradSARSAConfig) -> None:
         self.config: SemiGradSARSAConfig = config
 
+    @property
+    def name(self) -> str:
+        return "Semi-Grad SARSA"
+
     def actions_before_training(self, env: Env, **options) -> None:
         """Specify any actions necessary before training begins
 
         Parameters
         ----------
+
         env: The environment to train on
         options: Any key-value options passed by the client
 
@@ -60,27 +67,74 @@ def actions_before_training(self, env: Env, **options) -> None:
                 self.q_table[state, action] = 0.0
 
         """
 
-    def on_episode(self, env: Env, **options) -> EpisodeInfo:
+    def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions to perform before the episode begins
+
+        Parameters
+        ----------
+
+        env: The instance of the training environment
+        episode_idx: The training episode index
+        options: Any keyword options passed by the client code
+
+        Returns
+        -------
+
+        None
+
+        """
+
+    def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions after the training episode ends
+
+        Parameters
+        ----------
+
+        env: The training environment
+        episode_idx: The training episode index
+        options: Any options passed by the client code
+
+        Returns
+        -------
+
+        None
+        """
+
+    def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
+        """Train the algorithm on the episode
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The training episode index
+        options: Any keyword based options passed by the client code
+
+        Returns
+        -------
+
+        An instance of EpisodeInfo
+        """
         episode_reward = 0.0
         episode_n_itrs = 0
 
         # reset the environment
-        time_step = env.reset()
+        time_step = env.reset(**{"tiled_state": False})
 
         # select a state
         state: State = time_step.observation
 
         #choose an action using the policy
-        action: Action = self.config.policy(state)
+        action: Action = self.config.policy.on_state(state)
 
         for itr in range(self.config.n_itrs_per_episode):
 
             # take action and observe reward and next_state
-            time_step = env.step(action)
-            reward: float = 0.0
+            time_step: TimeStep = env.step(action, **{"tiled_state": False})
+
+            reward: float = time_step.reward
             episode_reward += reward
 
-            next_state: State = None
+            next_state: State = time_step.observation
 
             # if next_state is terminal i.e. the done flag
            # is set. then update the weights
@@ -109,6 +163,7 @@ def _weights_update_episode_done(self, state: State, reward: float,
 
         Parameters
         ----------
+
         state: The current state
         reward: The reward to use
         action: The action we took at state
diff --git a/src/algorithms/trainer.py b/src/algorithms/trainer.py
index 6dc9b7f..f6f152e 100644
--- a/src/algorithms/trainer.py
+++ b/src/algorithms/trainer.py
@@ -60,24 +60,41 @@ def actions_before_training(self) -> None:
         self.iterations_per_episode = []
         self.agent.actions_before_training(self.env)
 
-    def actions_before_episode_begins(self, **options) -> None:
+    def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
         """Perform any actions necessary before the training begins
 
         Parameters
         ----------
+        env: The environment to train on
+        episode_idx: The training episode index
         options: Any options passed by the client code
 
         Returns
         -------
 
         None
+
         """
-        self.agent.actions_before_episode_begins(**options)
+        self.agent.actions_before_episode_begins(env, episode_idx, **options)
+
+    def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> None:
+        """Any actions after the training episode ends
+
+        Parameters
+        ----------
+
+        env: The environment to train on
+        episode_idx: The training episode index
+        options: Any options passed by the client code
+
+        Returns
+        -------
 
-    def actions_after_episode_ends(self, **options):
-        self.agent.actions_after_episode_ends(**options)
+        None
+        """
+        self.agent.actions_after_episode_ends(env, episode_idx, **options)
 
-        if options["episode_idx"] % self.configuration['output_msg_frequency'] == 0:
+        if episode_idx % self.configuration['output_msg_frequency'] == 0:
             if self.env.config.distorted_set_path is not None:
-                self.env.save_current_dataset(options["episode_idx"])
+                self.env.save_current_dataset(episode_idx)
@@ -93,10 +110,10 @@ def train(self):
 
             # reset the environment
             #ignore = self.env.reset()
-            self.actions_before_episode_begins(**{"env": self.env})
+            self.actions_before_episode_begins(self.env, episode)
 
             # train for a number of iterations
             #episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)
-            episode_info: EpisodeInfo = self.agent.on_episode(self.env)
+            episode_info: EpisodeInfo = self.agent.on_episode(self.env, episode)
 
             print("{0} Episode score={1}, episode total avg distortion {2}".format(INFO, episode_info.episode_score,
                                                                                    episode_info.total_distortion / episode_info.info["n_iterations"]))
@@ -107,6 +124,6 @@ def train(self):
             self.iterations_per_episode.append(episode_info.info["n_iterations"])
             self.total_rewards[episode] = episode_info.episode_score
             self.total_distortions.append(episode_info.total_distortion)
-            self.actions_after_episode_ends(**{"episode_idx": episode})
+            self.actions_after_episode_ends(self.env, episode)
 
         print("{0} Training finished for agent {1}".format(INFO, self.agent.name))
diff --git a/src/examples/qlearning_three_columns.py b/src/examples/qlearning_three_columns.py
index 98e80cd..b4e3e25 100644
--- a/src/examples/qlearning_three_columns.py
+++ b/src/examples/qlearning_three_columns.py
@@ -88,6 +88,8 @@ def get_ethinicity_hierarchy():
 OUTPUT_MSG_FREQUENCY = 100
 N_ROUNDS_BELOW_MIN_DISTORTION = 10
 SAVE_DISTORTED_SETS_DIR = "/home/alex/qi3/drl_anonymity/src/examples/q_learn_distorted_sets/distorted_set"
+REWARD_FACTOR = 0.95
+PUNISH_FACTOR = 2.0
 
 # specify the columns to drop
 drop_columns = MockSubjectsLoader.FEATURES_DROP_NAMES + ["preventative_treatment", "gender",
@@ -144,8 +146,8 @@ def get_ethinicity_hierarchy():
                                                   numeric_column_distortion_metric_type=NumericDistanceType.L2_AVG,
                                                   string_column_distortion_metric_type=StringDistanceType.COSINE_NORMALIZE,
                                                   dataset_distortion_type=DistortionCalculationType.SUM)
-env_config.reward_factor = 0.95
-env_config.punish_factor = 2.0
+env_config.reward_factor = REWARD_FACTOR  # 0.95
+env_config.punish_factor = PUNISH_FACTOR  # 2.0
 
 # create the environment
 env = DiscreteStateEnvironment(env_config=env_config)
diff --git a/src/examples/semi_gradient_sarsa.py b/src/examples/semi_gradient_sarsa.py
new file mode 100644
index 0000000..2818bb0
--- /dev/null
+++ b/src/examples/semi_gradient_sarsa.py
@@ -0,0 +1,167 @@
+import random
+from pathlib import Path
+import numpy as np
+
+from src.algorithms.semi_gradient_sarsa import SemiGradSARSAConfig, SemiGradSARSA
+from src.utils.serial_hierarchy import SerialHierarchy
+from src.spaces.tiled_environment import TiledEnv, TiledEnvConfig, Layer
+from src.spaces.discrete_state_environment import DiscreteStateEnvironment
+from src.datasets.datasets_loaders import MockSubjectsLoader, MockSubjectsData
+from src.spaces.action_space import ActionSpace
+from src.spaces.actions import ActionIdentity, ActionStringGeneralize, ActionNumericBinGeneralize
+from src.algorithms.trainer import Trainer
+from src.policies.epsilon_greedy_policy import EpsilonDecayOption
+from src.algorithms.epsilon_greedy_q_estimator import EpsilonGreedyQEstimatorConfig, EpsilonGreedyQEstimator
+from src.utils.distortion_calculator import DistortionCalculationType, DistortionCalculator
+from src.utils.numeric_distance_type import NumericDistanceType
+from src.utils.string_distance_calculator import StringDistanceType
+from src.utils.reward_manager import RewardManager
+
+
+N_LAYERS = 5
+N_BINS = 10
+N_EPISODES = 1000
+OUTPUT_MSG_FREQUENCY = 100
+GAMMA = 0.99
+ALPHA = 0.1
+N_ITRS_PER_EPISODE = 30
+EPS = 1.0
+EPSILON_DECAY_OPTION = EpsilonDecayOption.CONSTANT_RATE  # .INVERSE_STEP
+EPSILON_DECAY_FACTOR = 0.01
+MAX_DISTORTION = 0.7
+MIN_DISTORTION = 0.3
+OUT_OF_MAX_BOUND_REWARD = -1.0
+OUT_OF_MIN_BOUND_REWARD = -1.0
+IN_BOUNDS_REWARD = 5.0
+N_ROUNDS_BELOW_MIN_DISTORTION = 10
+SAVE_DISTORTED_SETS_DIR = "/home/alex/qi3/drl_anonymity/src/examples/semi_grad_sarsa/distorted_set"
+REWARD_FACTOR = 0.95
+PUNISH_FACTOR = 2.0
+
+
+def get_ethinicity_hierarchy():
+    ethnicity_hierarchy = SerialHierarchy(values={})
+
+    ethnicity_hierarchy["Mixed White/Asian"] = "White/Asian"
+    ethnicity_hierarchy["White/Asian"] = "Mixed"
+
+    ethnicity_hierarchy["Chinese"] = "Asian"
+    ethnicity_hierarchy["Indian"] = "Asian"
+    ethnicity_hierarchy["Mixed White/Black African"] = "White/Black"
+    ethnicity_hierarchy["White/Black"] = "Mixed"
+
+    ethnicity_hierarchy["Black African"] = "African"
+    ethnicity_hierarchy["African"] = "Black"
+    ethnicity_hierarchy["Asian other"] = "Asian"
+    ethnicity_hierarchy["Black other"] = "Black"
+    ethnicity_hierarchy["Mixed White/Black Caribbean"] = "White/Black"
+    ethnicity_hierarchy["White/Black"] = "Mixed"
+
+    ethnicity_hierarchy["Mixed other"] = "Mixed"
+    ethnicity_hierarchy["Arab"] = "Asian"
+    ethnicity_hierarchy["White Irish"] = "Irish"
+    ethnicity_hierarchy["Irish"] = "European"
+    ethnicity_hierarchy["Not stated"] = "Not stated"
+    ethnicity_hierarchy["White Gypsy/Traveller"] = "White"
+    ethnicity_hierarchy["White British"] = "British"
+    ethnicity_hierarchy["British"] = "European"
+    ethnicity_hierarchy["Bangladeshi"] = "Asian"
+    ethnicity_hierarchy["White other"] = "White"
+    ethnicity_hierarchy["Black Caribbean"] = "Caribbean"
+    ethnicity_hierarchy["Caribbean"] = "Black"
+    ethnicity_hierarchy["Pakistani"] = "Asian"
+
+    ethnicity_hierarchy["European"] = "European"
+    ethnicity_hierarchy["Mixed"] = "Mixed"
+    ethnicity_hierarchy["Asian"] = "Asian"
+    ethnicity_hierarchy["Black"] = "Black"
+    ethnicity_hierarchy["White"] = "White"
+    return ethnicity_hierarchy
+
+
+def load_mock_subjects() -> MockSubjectsLoader:
+
+    mock_data = MockSubjectsData(FILENAME=Path("../../data/mocksubjects.csv"),
+                                 COLUMNS_TYPES={"ethnicity": str, "salary": float, "diagnosis": int},
+                                 FEATURES_DROP_NAMES=["NHSno", "given_name",
+                                                      "surname", "dob"] + ["preventative_treatment",
+                                                                           "gender", "education", "mutation_status"],
+                                 NORMALIZED_COLUMNS=["salary"])
+
+    ds = MockSubjectsLoader(mock_data)
+
+    assert ds.n_columns == 3, "Invalid number of columns {0} not equal to 3".format(ds.n_columns)
+
+    return ds
+
+
+def load_discrete_env() -> DiscreteStateEnvironment:
+
+    mock_ds = load_mock_subjects()
+
+    # create bins for the salary generalization
+    unique_salary = mock_ds.get_column_unique_values(col_name="salary")
+    unique_salary.sort()
+
+    # slightly increase the max value because otherwise
+    # we get out of bounds for the maximum salary
+    bins = np.linspace(unique_salary[0], unique_salary[-1] + 1, N_BINS)
+
+    action_space = ActionSpace(n=5)
+    action_space.add_many(ActionIdentity(column_name="ethnicity"),
+                          ActionStringGeneralize(column_name="ethnicity",
+                                                 generalization_table=get_ethinicity_hierarchy()),
+                          ActionIdentity(column_name="salary"),
+                          ActionNumericBinGeneralize(column_name="salary", generalization_table=bins),
+                          ActionIdentity(column_name="diagnosis"))
+
+    action_space.shuffle()
+
+    env = DiscreteStateEnvironment.from_options(data_set=mock_ds,
+                                                action_space=action_space,
+                                                distortion_calculator=DistortionCalculator(
+                                                    numeric_column_distortion_metric_type=NumericDistanceType.L2_AVG,
+                                                    string_column_distortion_metric_type=StringDistanceType.COSINE_NORMALIZE,
+                                                    dataset_distortion_type=DistortionCalculationType.SUM),
+                                                reward_manager=RewardManager(bounds=(MIN_DISTORTION, MAX_DISTORTION),
+                                                                             out_of_max_bound_reward=OUT_OF_MAX_BOUND_REWARD,
+                                                                             out_of_min_bound_reward=OUT_OF_MIN_BOUND_REWARD,
+                                                                             in_bounds_reward=IN_BOUNDS_REWARD),
+                                                gamma=GAMMA,
+                                                reward_factor=REWARD_FACTOR,
+                                                punish_factor=PUNISH_FACTOR,
+                                                min_distortion=MIN_DISTORTION, max_distortion=MAX_DISTORTION,
+                                                n_rounds_below_min_distortion=N_ROUNDS_BELOW_MIN_DISTORTION,
+                                                distorted_set_path=Path(SAVE_DISTORTED_SETS_DIR),
+                                                n_states=N_LAYERS * Layer.n_tiles_per_action(N_BINS,
+                                                                                             mock_ds.n_columns))
+
+    return env
+
+
+if __name__ == '__main__':
+
+    # set the seed for random engine
+    random.seed(42)
+
+    discrete_env = load_discrete_env()
+    tiled_env_config = TiledEnvConfig(n_layers=N_LAYERS, n_bins=N_BINS,
+                                      env=discrete_env,
+                                      column_ranges={"ethnicity": [0.0, 1.0],
+                                                     "salary": [0.0, 1.0],
+                                                     "diagnosis": [0.0, 1.0]})
+    tiled_env = TiledEnv(tiled_env_config)
+    tiled_env.create_tiles()
+
+    configuration = {"n_episodes": N_EPISODES, "output_msg_frequency": OUTPUT_MSG_FREQUENCY}
+
+    agent_config = SemiGradSARSAConfig(gamma=GAMMA, alpha=ALPHA, n_itrs_per_episode=N_ITRS_PER_EPISODE,
+                                       policy=EpsilonGreedyQEstimator(EpsilonGreedyQEstimatorConfig(eps=EPS, n_actions=tiled_env.n_actions,
+                                                                                                    decay_op=EPSILON_DECAY_OPTION,
+                                                                                                    epsilon_decay_factor=EPSILON_DECAY_FACTOR,
+                                                                                                    env=tiled_env, gamma=GAMMA, alpha=ALPHA)))
+    agent = SemiGradSARSA(agent_config)
+
+    # create a trainer to train the semi-gradient SARSA agent
+    trainer = Trainer(env=tiled_env, agent=agent, configuration=configuration)
+    trainer.train()
diff --git a/src/policies/epsilon_greedy_policy.py b/src/policies/epsilon_greedy_policy.py
index 8725c84..7766cfa 100644
--- a/src/policies/epsilon_greedy_policy.py
+++ b/src/policies/epsilon_greedy_policy.py
@@ -66,7 +66,7 @@ def __init__(self, eps: float, n_actions: int,
         self.user_defined_decrease_method: UserDefinedDecreaseMethod = user_defined_decrease_method
 
     def __str__(self) -> str:
-        return self.__name__
+        return "EpsilonGreedyPolicy"
 
     def __call__(self, q_table: QTable, state: Any) -> int:
         """
diff --git a/src/spaces/discrete_state_environment.py b/src/spaces/discrete_state_environment.py
index 0eca1f1..fb612bf 100644
--- a/src/spaces/discrete_state_environment.py
+++ b/src/spaces/discrete_state_environment.py
@@ -66,7 +66,7 @@ def from_options(cls, *, data_set: DataSet, action_space: ActionSpace,
         return cls(env_config=config)
 
     @classmethod
-    def from_dataset(cls, data_set: DataSet, *, action_space: ActionSpace=None,
+    def from_dataset(cls, data_set: DataSet, *, action_space: ActionSpace = None,
                      reward_manager: RewardManager = None,
                      distortion_calculator: DistortionCalculator = None):
         config = DiscreteEnvConfig(data_set=data_set, action_space=action_space, reward_manager=reward_manager,
@@ -115,6 +115,19 @@ def column_distortions(self) -> dict:
         return self.column_distances
 
     def get_action(self, aidx: int) -> ActionBase:
+        """Returns the action that corresponds to the given global index aidx
+
+        Parameters
+        ----------
+
+        aidx: The index of the action to return
+
+        Returns
+        -------
+
+        An instance of ActionBase
+
+        """
         return self.config.action_space[aidx]
 
     def save_current_dataset(self, episode_index: int, save_index: bool = False) -> None:
@@ -257,6 +270,7 @@ def step(self, action: ActionBase) -> TimeStep:
         """
         # apply the action and update distoration
         # and column count
+
         self.apply_action(action=action)
 
         # calculate the distortion of the dataset
diff --git a/src/spaces/tiled_environment.py b/src/spaces/tiled_environment.py
index d12af9b..5f1a6db 100644
--- a/src/spaces/tiled_environment.py
+++ b/src/spaces/tiled_environment.py
@@ -107,6 +107,11 @@ class Layer(object):
     """Helper class to represent a layer of tiling
 
     """
+
+    @staticmethod
+    def n_tiles_per_action(n_bins: int, n_columns: int) -> int:
+        return n_bins ** n_columns
+
     def __init__(self, column_bins, n_bins: int, n_actions: int,
                  start_index: int, end_index: int):
         self.column_bins = column_bins
@@ -321,7 +326,6 @@ def __init__(self, config: TiledEnvConfig) -> None:
 
         # This assigns a unique index to each tile up to max_size tiles.
         self._validate()
         self._create_column_scales()
-        #self.column_bins = {}
 
     @property
     def action_space(self):
@@ -333,13 +337,21 @@ def n_actions(self) -> int:
 
     @property
     def n_states(self) -> int:
-        return self.env.n_states
+        """Returns the total number of states in the environment
+
+        Returns
+        -------
+
+        The total number of states in the environment
+
+        """
+        return self.n_layers * Layer.n_tiles_per_action(self.n_bins, len(self.column_ranges))
 
     @property
     def config(self) -> Config:
         return self.env.config
 
-    def step(self, action: ActionBase) -> TimeStep:
+    def step(self, action: ActionBase, **options) -> TimeStep:
         """Execute the action in the environment and return
         a new state for observation
 
@@ -366,6 +378,13 @@ def step(self, action: ActionBase) -> TimeStep:
         state.column_distortions = self.env.column_distortions
 
         time_step = copy_time_step(time_step=raw_time_step, **{"observation": state})
+
+        if "tiled_state" in options and options['tiled_state'] is True:
+
+            # we want to put the state into the tiles
+            tiled_state = self.featurize_raw_state(state)
+            time_step = copy_time_step(time_step=time_step, **{"observation": tiled_state})
+
         return time_step
 
     def reset(self, **options) -> TimeStep:
@@ -399,16 +418,37 @@ def reset(self, **options) -> TimeStep:
 
         time_step = copy_time_step(time_step=raw_time_step, **{"observation": state})
 
-        # we want to put the state into the tiles
-        tiled_state = self._featurize_raw_state(state)
-        time_step = copy_time_step(time_step=time_step, **{"observation": tiled_state})
+        if "tiled_state" in options and options['tiled_state'] is True:
+
+            # we want to put the state into the tiles
+            tiled_state = self.featurize_raw_state(state)
+            time_step = copy_time_step(time_step=time_step, **{"observation": tiled_state})
+
         return time_step
 
+    def get_state_action_tile_matrix(self, state: TiledState) -> np.array:
+        """Transform the TiledState vector into a 3D numpy array of shape
+        (n_layers, n_actions, n_tiles_per_action)
+
+        Parameters
+        ----------
+
+        state: The tiled state-action vector
+
+        Returns
+        -------
+
+        A 3D numpy array
+        """
+
+        return state.reshape(self.n_layers, self.n_actions, Layer.n_tiles_per_action(n_bins=self.n_bins,
+                                                                                     n_columns=len(self.column_ranges)))
+
     def get_action(self, aidx: int) -> ActionBase:
         """Returns the action that corresponds to the given index
 
         Parameters
         ----------
+
         aidx: The index of the action to return
 
         Returns
@@ -510,6 +550,7 @@ def apply_action(self, action: ActionBase) -> None:
 
         Parameters
         ----------
+
         action: The action to apply
 
         Returns
@@ -530,55 +571,58 @@ def total_current_distortion(self) -> float:
 
         """
         return self.env.total_current_distortion()
 
-    def get_scaled_state(self, state: State) -> list:
-        """Scales the state components and returns the
-        scaled state
+    def featurize_state_action(self, state: RawState, action: ActionBase) -> TiledState:
+        """Returns the tiled feature vector for the given state and action
 
         Parameters
         ----------
 
-        state: The state to scale
+        state: The environment state observed
+        action: The action
 
         Returns
         -------
 
-        A list of scaled state values
+        The tiled state-action vector with the active tiles set to 1.0
 
         """
-        scaled_state_vals = []
-        for name in state:
-            scaled_state_vals.append(state[name] * self.column_scales[name])
-        return scaled_state_vals
+        tiled_state = np.zeros(self.n_layers * self.n_actions * self.n_bins ** (len(self.column_ranges)))
 
-    def featurize_state_action(self, state: State, action: ActionBase) -> List[Tile]:
-        """Get a list of Tiles for the given state and action
+        for layer in range(self.n_layers):
+            global_idx = self.tiles[layer].get_global_tile_index(raw_state=state, action=action)
+            if global_idx != INVALID_ID:
+                tiled_state[global_idx] = 1.0
+
+        return tiled_state
+
+    def featurize_raw_state(self, state: RawState) -> TiledState:
+        """Returns the tiled state vector given the vector
+        of column distortions
 
         Parameters
         ----------
 
-        state: The environment state observed
-        action: The action
+        state: The raw state i.e. the vector of column distortions
 
         Returns
         -------
 
-        A list of tiles
-        """
-        scaled_state = self.get_scaled_state(state)
-        featurized = tiles(self.iht, self.n_layers, scaled_state, [action])
-        return featurized
-
-    def _featurize_raw_state(self, state: RawState) -> TiledState:
-        tiled_state = np.zeros(self.n_layers * self.n_actions * self.n_bins ** (len(self.column_ranges)))
+        The tiled state vector
+        """
+        tiled_state = np.zeros(self.n_layers * self.n_actions * self.n_bins ** (len(self.column_ranges)))
         for layer in range(self.n_layers):
             for action in range(self.n_actions):
                 global_idx = self.tiles[layer].get_global_tile_index(raw_state=state, action=action)
                 if global_idx != INVALID_ID:
                     tiled_state[global_idx] = 1.0
-                    break
 
         return tiled_state
diff --git a/tests/test_epsilon_greedy_q_estimator.py b/tests/test_epsilon_greedy_q_estimator.py
index 8200010..3412d07 100644
--- a/tests/test_epsilon_greedy_q_estimator.py
+++ b/tests/test_epsilon_greedy_q_estimator.py
@@ -1,6 +1,8 @@
 import unittest
+import pytest
 
 from src.algorithms.epsilon_greedy_q_estimator import EpsilonGreedyQEstimator, EpsilonGreedyQEstimatorConfig
+from src.exceptions.exceptions import InvalidParamValue
 
 
 class TestEpsilonGreedyQEstimator(unittest.TestCase):
@@ -15,11 +17,12 @@ def test_on_state(self):
         eps_q_estimator_config = EpsilonGreedyQEstimatorConfig()
         eps_q_estimator = EpsilonGreedyQEstimator(eps_q_estimator_config)
 
+    def test_q_hat_value_raise_InvalidParamValue(self):
+        eps_q_estimator_config = EpsilonGreedyQEstimatorConfig()
+        eps_q_estimator = EpsilonGreedyQEstimator(eps_q_estimator_config)
 
-
-
-
-
+        with pytest.raises(InvalidParamValue) as e:
+            eps_q_estimator.q_hat_value(None)
 
 if __name__ == '__main__':
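
The update_weights change in epsilon_greedy_q_estimator.py above applies the standard linear semi-gradient SARSA rule: with a binary tile-coded feature vector x(s, a), q_hat(s, a) = w . x(s, a) and w <- w + (alpha / t) * (r + gamma * q_hat(s', a') - q_hat(s, a)) * x(s, a). The sketch below replays that update outside the project's classes; the feature size, the active tile indices and the alpha, gamma and t values are illustrative assumptions, not values taken from the repository.

import numpy as np

# Minimal sketch of the linear semi-gradient SARSA update used by update_weights().
# n_features plays the role of n_states * n_actions in EpsilonGreedyQEstimator.initialize().
n_features = 12                 # assumption for the sketch only
alpha, gamma, t = 0.1, 0.99, 1  # illustrative values

weights = np.zeros(n_features)

# x(s, a) and x(s', a'): binary tile-coded feature vectors (one active tile each here)
x_sa = np.zeros(n_features)
x_sa[3] = 1.0
x_next = np.zeros(n_features)
x_next[7] = 1.0

reward = 5.0

# q_hat_value: dot product of the weights with the state-action feature vector
q_sa = weights.dot(x_sa)
q_next = weights.dot(x_next)

# update_weights: w += alpha / t * (r + gamma * q(s', a') - q(s, a)) * x(s, a)
weights += alpha / t * (reward + gamma * q_next - q_sa) * x_sa

print(weights.dot(x_sa))        # the estimate for (s, a) has moved toward the target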