35 changes: 30 additions & 5 deletions src/algorithms/epsilon_greedy_q_estimator.py
@@ -8,6 +8,7 @@

from src.utils.mixins import WithEstimatorMixin
from src.policies.epsilon_greedy_policy import EpsilonGreedyPolicy, EpsilonGreedyConfig
from src.exceptions.exceptions import InvalidParamValue

StateActionVec = TypeVar('StateActionVec')
State = TypeVar('State')
@@ -42,6 +43,18 @@ def __init__(self, config: EpsilonGreedyQEstimatorConfig):
self.gamma: float = config.gamma
self.env: Env = config.env
self.weights: np.array = None
self.initialize()

def initialize(self) -> None:
"""Initialize the underlying weights

Returns
-------

None

"""
self.weights: np.array = np.zeros((self.env.n_states * self.env.n_actions))

def q_hat_value(self, state_action_vec: StateActionVec) -> float:
"""Returns the
@@ -60,6 +73,10 @@ def q_hat_value(self, state_action_vec: StateActionVec) -> float:


"""

if self.weights is None:
raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")

return self.weights.dot(state_action_vec)

def update_weights(self, total_reward: float, state_action: Action,
@@ -81,6 +98,10 @@ def update_weights(self, total_reward: float, state_action: Action,
None

"""

if self.weights is None:
raise InvalidParamValue(param_name="weights", param_value="None. Have you called initialize?")

v1 = self.q_hat_value(state_action_vec=state_action)
v2 = self.q_hat_value(state_action_vec=state_action_)
self.weights += self.alpha / t * (total_reward + self.gamma * v2 - v1) * state_action
@@ -99,14 +120,18 @@ def on_state(self, state: State) -> Action:
An environment specific Action type
"""

- # compute the state values related to
- # the given state
+ # get the approximation of the q-values
+ # given the state

q_values = []

- for action in range(self.env.n_actions):
-     state_action_vector = self.env.get_state_action_tile(action=action, state=state)
-     q_values.append(state_action_vector)
+ for a in range(self.env.n_actions):
+     tiled_vector = self.env.featurize_state_action(action=a, state=state)
+     q_values.append(self.q_hat_value(tiled_vector))

# choose an action at the current state
action = self.eps_policy(q_values, state)

# this is an integer; get the ActionBase instead
action = self.env.get_action(action)
return action
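For reference, the linear Q-function approximation behind `q_hat_value` and the semi-gradient update in `update_weights` can be sketched as follows. This is a minimal sketch under assumed toy sizes; the one-hot `featurize` helper is a hypothetical stand-in for the environment's `featurize_state_action` tiling and is not part of this PR.

```python
import numpy as np

# Minimal sketch of linear Q-value approximation with a one-hot
# (state, action) feature vector; the sizes and the featurize() helper
# are hypothetical stand-ins for the environment's featurize_state_action.
n_states, n_actions = 10, 4
weights = np.zeros(n_states * n_actions)

def featurize(state: int, action: int) -> np.ndarray:
    # one-hot encoding of the (state, action) pair
    x = np.zeros(n_states * n_actions)
    x[state * n_actions + action] = 1.0
    return x

def q_hat(x: np.ndarray) -> float:
    # q_hat(s, a) = w . x(s, a), as in q_hat_value
    return float(weights.dot(x))

# Semi-gradient update mirroring update_weights:
# w <- w + (alpha / t) * (r + gamma * q_hat(s', a') - q_hat(s, a)) * x(s, a)
alpha, gamma, t, reward = 0.1, 0.99, 1, 1.0
x_sa, x_next = featurize(0, 1), featurize(1, 2)
weights += (alpha / t) * (reward + gamma * q_hat(x_next) - q_hat(x_sa)) * x_sa
```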
67 changes: 61 additions & 6 deletions src/algorithms/semi_gradient_sarsa.py
@@ -11,8 +11,10 @@

from src.utils.mixins import WithMaxActionMixin, WithQTableMixinBase, WithEstimatorMixin
from src.utils.episode_info import EpisodeInfo
from src.spaces.time_step import TimeStep
from src.exceptions.exceptions import InvalidParamValue


Policy = TypeVar('Policy')
Env = TypeVar('Env')
State = TypeVar('State')
@@ -38,11 +40,16 @@ class SemiGradSARSA(object):
def __init__(self, config: SemiGradSARSAConfig) -> None:
self.config: SemiGradSARSAConfig = config

@property
def name(self) -> str:
return "Semi-Grad SARSA"

def actions_before_training(self, env: Env, **options) -> None:
"""Specify any actions necessary before training begins

Parameters
----------

env: The environment to train on
options: Any key-value options passed by the client

@@ -60,27 +67,74 @@ def actions_before_training(self, env: Env, **options) -> None:
self.q_table[state, action] = 0.0
"""

- def on_episode(self, env: Env, **options) -> EpisodeInfo:
def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
"""Any actions to perform before the episode begins

Parameters
----------

env: The instance of the training environment
episode_idx: The training episode index
options: Any keyword options passed by the client code

Returns
-------

None

"""

def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> None:
"""Any actions after the training episode ends

Parameters
----------

env: The training environment
episode_idx: The training episode index
options: Any options passed by the client code

Returns
-------

None
"""

def on_episode(self, env: Env, episode_idx: int, **options) -> EpisodeInfo:
"""Train the algorithm on the episode

Parameters
----------

env: The environment to train on
options: Any keyword based options passed by the client code

Returns
-------

An instance of EpisodeInfo
"""

episode_reward = 0.0
episode_n_itrs = 0

# reset the environment
- time_step = env.reset()
+ time_step = env.reset(**{"tiled_state": False})

# select a state
state: State = time_step.observation

# choose an action using the policy
- action: Action = self.config.policy(state)
+ action: Action = self.config.policy.on_state(state)

for itr in range(self.config.n_itrs_per_episode):

# take action and observe reward and next_state
- time_step = env.step(action)
- reward: float = 0.0
+ time_step: TimeStep = env.step(action, **{"tiled_state": False})

+ reward: float = time_step.reward
episode_reward += reward
- next_state: State = None
+ next_state: State = time_step.observation

# if next_state is terminal i.e. the done flag
# is set. then update the weights
@@ -109,6 +163,7 @@ def _weights_update_episode_done(self, state: State, reward: float,

Parameters
----------

state: The current state
reward: The reward to use
action: The action we took at state
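The `on_episode` changes above follow the standard semi-gradient SARSA control loop: choose an action, step the environment, observe the reward and next state, choose the next action, update the weights, and repeat. A minimal sketch of that loop, assuming simplified `env`, `policy`, and `estimator` interfaces rather than the exact signatures in this PR:

```python
# Minimal sketch of a semi-gradient SARSA episode; env, policy and estimator
# are assumed, simplified interfaces, not the exact classes changed in this PR.
def run_episode(env, policy, estimator, n_itrs: int) -> float:
    episode_reward = 0.0
    time_step = env.reset()
    state = time_step.observation
    action = policy.on_state(state)              # choose the first action

    for t in range(1, n_itrs + 1):
        time_step = env.step(action)             # act and observe r, s'
        reward = time_step.reward
        episode_reward += reward
        next_state = time_step.observation

        if time_step.done:
            # terminal update: the target is just the reward
            estimator.update(reward, state, action, None, None, t)
            break

        next_action = policy.on_state(next_state)
        # non-terminal target: r + gamma * q_hat(s', a')
        estimator.update(reward, state, action, next_state, next_action, t)
        state, action = next_state, next_action

    return episode_reward
```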
33 changes: 25 additions & 8 deletions src/algorithms/trainer.py
@@ -60,24 +60,41 @@ def actions_before_training(self) -> None:
self.iterations_per_episode = []
self.agent.actions_before_training(self.env)

- def actions_before_episode_begins(self, **options) -> None:
+ def actions_before_episode_begins(self, env: Env, episode_idx: int, **options) -> None:
"""Perform any actions necessary before the training begins

Parameters
----------
env: The environment to train on
episode_idx: The training episode index
options: Any options passed by the client code

Returns
-------

None

"""
- self.agent.actions_before_episode_begins(**options)
+ self.agent.actions_before_episode_begins(env, episode_idx, **options)

def actions_after_episode_ends(self, env: Env, episode_idx: int, **options) -> None:
"""Any actions after the training episode ends

Parameters
----------

env: The environment to train on
episode_idx: The training episode index
options: Any options passed by the client code

Returns
-------

- def actions_after_episode_ends(self, **options):
-     self.agent.actions_after_episode_ends(**options)
None
"""
+ self.agent.actions_after_episode_ends(env, episode_idx, **options)

if options["episode_idx"] % self.configuration['output_msg_frequency'] == 0:
if episode_idx % self.configuration['output_msg_frequency'] == 0:
if self.env.config.distorted_set_path is not None:
self.env.save_current_dataset(episode_idx)

@@ -93,10 +110,10 @@ def train(self):
# reset the environment
#ignore = self.env.reset()

- self.actions_before_episode_begins(**{"env": self.env})
+ self.actions_before_episode_begins(self.env, episode,)
# train for a number of iterations
#episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)
- episode_info: EpisodeInfo = self.agent.on_episode(self.env)
+ episode_info: EpisodeInfo = self.agent.on_episode(self.env, episode)

print("{0} Episode score={1}, episode total avg distortion {2}".format(INFO, episode_info.episode_score,
episode_info.total_distortion / episode_info.info["n_iterations"]))
Expand All @@ -107,6 +124,6 @@ def train(self):
self.iterations_per_episode.append(episode_info.info["n_iterations"])
self.total_rewards[episode] = episode_info.episode_score
self.total_distortions.append(episode_info.total_distortion)
- self.actions_after_episode_ends(**{"episode_idx": episode})
+ self.actions_after_episode_ends(self.env, episode, **{})

print("{0} Training finished for agent {1}".format(INFO, self.agent.name))
6 changes: 4 additions & 2 deletions src/examples/qlearning_three_columns.py
@@ -88,6 +88,8 @@ def get_ethinicity_hierarchy():
OUTPUT_MSG_FREQUENCY = 100
N_ROUNDS_BELOW_MIN_DISTORTION = 10
SAVE_DISTORTED_SETS_DIR = "/home/alex/qi3/drl_anonymity/src/examples/q_learn_distorted_sets/distorted_set"
REWARD_FACTOR = 0.95
PUNISH_FACTOR = 2.0

# specify the columns to drop
drop_columns = MockSubjectsLoader.FEATURES_DROP_NAMES + ["preventative_treatment", "gender",
@@ -144,8 +146,8 @@ def get_ethinicity_hierarchy():
numeric_column_distortion_metric_type=NumericDistanceType.L2_AVG,
string_column_distortion_metric_type=StringDistanceType.COSINE_NORMALIZE,
dataset_distortion_type=DistortionCalculationType.SUM)
- env_config.reward_factor = 0.95
- env_config.punish_factor = 2.0
+ env_config.reward_factor = REWARD_FACTOR #0.95
+ env_config.punish_factor = PUNISH_FACTOR #2.0

# create the environment
env = DiscreteStateEnvironment(env_config=env_config)