From 020f07bde565e7030864518afcb6804f7f00ba0d Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 9 Feb 2022 13:23:43 +0000 Subject: [PATCH 1/4] #40 Rename train to on_episode --- src/algorithms/q_learning.py | 2 +- src/algorithms/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/algorithms/q_learning.py b/src/algorithms/q_learning.py index b011e12..7132dfc 100644 --- a/src/algorithms/q_learning.py +++ b/src/algorithms/q_learning.py @@ -89,7 +89,7 @@ def play(self, env: Env, stop_criterion: Criterion) -> None: env.step(action=action) total_dist = env.total_current_distortion() - def train(self, env: Env, **options) -> tuple: + def on_episode(self, env: Env, **options) -> tuple: # episode score episode_score = 0 diff --git a/src/algorithms/trainer.py b/src/algorithms/trainer.py index 549a0d8..cbbb01e 100644 --- a/src/algorithms/trainer.py +++ b/src/algorithms/trainer.py @@ -74,7 +74,7 @@ def train(self): ignore = self.env.reset() # train for a number of iterations - episode_score, total_distortion, n_itrs = self.agent.train(self.env) + episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env) print("{0} Episode score={1}, episode total distortion {2}".format(INFO, episode_score, total_distortion / n_itrs)) From 91f780fd8bea33a0aaffc447209b5d8185127dfa Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 9 Feb 2022 13:24:26 +0000 Subject: [PATCH 2/4] #41 Refactor TimeStep class to its own module --- src/spaces/time_step.py | 55 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 src/spaces/time_step.py diff --git a/src/spaces/time_step.py b/src/spaces/time_step.py new file mode 100644 index 0000000..ebce077 --- /dev/null +++ b/src/spaces/time_step.py @@ -0,0 +1,55 @@ +""" + +""" + +import enum +from typing import NamedTuple, Generic, Optional, TypeVar + +_Reward = TypeVar('_Reward') +_Discount = TypeVar('_Discount') +_Observation = TypeVar('_Observation') + + +class 
StepType(enum.IntEnum): + """ + Defines the status of a `TimeStep` within a sequence. + """ + + # Denotes the first `TimeStep` in a sequence. + FIRST = 0 + + # Denotes any `TimeStep` in a sequence that is not FIRST or LAST. + MID = 1 + + # Denotes the last `TimeStep` in a sequence. + LAST = 2 + + def first(self) -> bool: + return self is StepType.FIRST + + def mid(self) -> bool: + return self is StepType.MID + + def last(self) -> bool: + return self is StepType.LAST + + +class TimeStep(NamedTuple, Generic[_Reward, _Discount, _Observation]): + step_type: StepType + info: dict + reward: Optional[_Reward] + discount: Optional[_Discount] + observation: _Observation + + def first(self) -> bool: + return self.step_type == StepType.FIRST + + def mid(self) -> bool: + return self.step_type == StepType.MID + + def last(self) -> bool: + return self.step_type == StepType.LAST + + @property + def done(self) -> bool: + return self.last() \ No newline at end of file From 97a49d337b582948062fd0952b8805ee73f26dad Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 9 Feb 2022 13:25:06 +0000 Subject: [PATCH 3/4] #38 Add column enumeration --- src/spaces/column_type.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 src/spaces/column_type.py diff --git a/src/spaces/column_type.py b/src/spaces/column_type.py new file mode 100644 index 0000000..d07aaa2 --- /dev/null +++ b/src/spaces/column_type.py @@ -0,0 +1,17 @@ +""" +Simple enumeration of column types. +This is similar to the ARX software. 
+See the ARX documentation at: +https://arx.deidentifier.org/wp-content/uploads/javadoc/current/api/org/deidentifier/arx/AttributeType.html +""" + +import enum + + +class ColumnType(enum.IntEnum): + + INVALID_TYPE = 0 + IDENTIFYING_ATTRIBUTE = 1 + SENSITIVE_ATTRIBUTE = 2 + INSENSITIVE_ATTRIBUTE = 3 + QUASI_IDENTIFYING_ATTRIBUTE = 4 From a318ec39ba3a53d70f973ca4f6b6e0d4d68eb625 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 9 Feb 2022 13:25:46 +0000 Subject: [PATCH 4/4] #41 Refactor TimeStep --- src/spaces/discrete_state_environment.py | 60 +++--------------------- 1 file changed, 6 insertions(+), 54 deletions(-) diff --git a/src/spaces/discrete_state_environment.py b/src/spaces/discrete_state_environment.py index b114002..c2a7a59 100644 --- a/src/spaces/discrete_state_environment.py +++ b/src/spaces/discrete_state_environment.py @@ -4,69 +4,19 @@ """ import copy -import enum import numpy as np from pathlib import Path -import pandas as pd -import torch -from typing import NamedTuple, Generic, Optional, TypeVar, List +from typing import TypeVar, List import multiprocessing as mp from src.spaces.actions import ActionBase, ActionType -from src.utils.string_distance_calculator import StringDistanceType, TextDistanceCalculator -from src.utils.numeric_distance_type import NumericDistanceType -from src.utils.numeric_distance_calculator import NumericDistanceCalculator +from src.spaces.time_step import TimeStep, StepType DataSet = TypeVar("DataSet") RewardManager = TypeVar("RewardManager") ActionSpace = TypeVar("ActionSpace") DistortionCalculator = TypeVar('DistortionCalculator') -_Reward = TypeVar('_Reward') -_Discount = TypeVar('_Discount') -_Observation = TypeVar('_Observation') - - -class StepType(enum.IntEnum): - """ - Defines the status of a `TimeStep` within a sequence. - """ - - # Denotes the first `TimeStep` in a sequence. - FIRST = 0 - - # Denotes any `TimeStep` in a sequence that is not FIRST or LAST. - MID = 1 - - # Denotes the last `TimeStep` in a sequence. 
- LAST = 2 - - def first(self) -> bool: - return self is StepType.FIRST - - def mid(self) -> bool: - return self is StepType.MID - - def last(self) -> bool: - return self is StepType.LAST - - -class TimeStep(NamedTuple, Generic[_Reward, _Discount, _Observation]): - step_type: StepType - info: dict - reward: Optional[_Reward] - discount: Optional[_Discount] - observation: _Observation - - def first(self) -> bool: - return self.step_type == StepType.FIRST - - def mid(self) -> bool: - return self.step_type == StepType.MID - - def last(self) -> bool: - return self.step_type == StepType.LAST - class DiscreteEnvConfig(object): """ @@ -79,8 +29,6 @@ def __init__(self) -> None: self.reward_manager: RewardManager = None self.average_distortion_constraint: float = 0.0 self.gamma: float = 0.99 - # self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID - # self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID self.n_states: int = 10 self.min_distortion: float = 0.4 self.max_distortion: float = 0.7 @@ -115,6 +63,10 @@ def __init__(self, env_config: DiscreteEnvConfig) -> None: self.column_visits = {} self.create_bins() + @property + def columns_attribute_types(self) -> dict: + return self.config.data_set.columns_attribute_types + @property def action_space(self): return self.config.action_space