Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/algorithms/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
env.step(action=action)
total_dist = env.total_current_distortion()

def train(self, env: Env, **options) -> tuple:
def on_episode(self, env: Env, **options) -> tuple:

# episode score
episode_score = 0
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def train(self):
ignore = self.env.reset()

# train for a number of iterations
episode_score, total_distortion, n_itrs = self.agent.train(self.env)
episode_score, total_distortion, n_itrs = self.agent.on_episode(self.env)

print("{0} Episode score={1}, episode total distortion {2}".format(INFO, episode_score, total_distortion / n_itrs))

Expand Down
17 changes: 17 additions & 0 deletions src/spaces/column_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Simple enumeration of column types.
The member names mirror the attribute types used by the ARX software.
See the ARX documentation at:
https://arx.deidentifier.org/wp-content/uploads/javadoc/current/api/org/deidentifier/arx/AttributeType.html
"""

import enum


@enum.unique
class ColumnType(enum.IntEnum):
    """Integer-valued classification of a data-set column.

    The member names follow the constants of ARX's ``AttributeType``.
    ``@enum.unique`` makes the class definition fail with ``ValueError``
    if two members are ever given the same value, instead of silently
    turning the later one into an alias of the earlier one.
    """

    # Presumably marks a column whose type has not been assigned yet
    # -- TODO confirm against callers.
    INVALID_TYPE = 0
    IDENTIFYING_ATTRIBUTE = 1
    SENSITIVE_ATTRIBUTE = 2
    INSENSITIVE_ATTRIBUTE = 3
    QUASI_IDENTIFYING_ATTRIBUTE = 4
60 changes: 6 additions & 54 deletions src/spaces/discrete_state_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,69 +4,19 @@
"""

import copy
import enum
import numpy as np
from pathlib import Path
import pandas as pd
import torch
from typing import NamedTuple, Generic, Optional, TypeVar, List
from typing import TypeVar, List
import multiprocessing as mp

from src.spaces.actions import ActionBase, ActionType
from src.utils.string_distance_calculator import StringDistanceType, TextDistanceCalculator
from src.utils.numeric_distance_type import NumericDistanceType
from src.utils.numeric_distance_calculator import NumericDistanceCalculator
from src.spaces.time_step import TimeStep, StepType

DataSet = TypeVar("DataSet")
RewardManager = TypeVar("RewardManager")
ActionSpace = TypeVar("ActionSpace")
DistortionCalculator = TypeVar('DistortionCalculator')

_Reward = TypeVar('_Reward')
_Discount = TypeVar('_Discount')
_Observation = TypeVar('_Observation')


class StepType(enum.IntEnum):
    """Position of a `TimeStep` inside a sequence of steps."""

    FIRST = 0  # opening `TimeStep` of a sequence
    MID = 1  # any `TimeStep` that is neither FIRST nor LAST
    LAST = 2  # closing `TimeStep` of a sequence

    def first(self) -> bool:
        """Tell whether this member is ``FIRST``."""
        return StepType.FIRST is self

    def mid(self) -> bool:
        """Tell whether this member is ``MID``."""
        return StepType.MID is self

    def last(self) -> bool:
        """Tell whether this member is ``LAST``."""
        return StepType.LAST is self


class TimeStep(NamedTuple, Generic[_Reward, _Discount, _Observation]):
    """Immutable record describing one step of a sequence.

    NOTE(review): combining ``NamedTuple`` with ``Generic`` is documented
    as supported from Python 3.11 -- confirm the project's minimum
    interpreter version.
    """

    # Where this step sits in its sequence (FIRST / MID / LAST).
    step_type: StepType
    # Free-form dictionary; its contents are not defined in this block.
    info: dict
    # May be None, per the Optional annotation.
    reward: Optional[_Reward]
    # May be None, per the Optional annotation.
    discount: Optional[_Discount]
    observation: _Observation

    def first(self) -> bool:
        """Return True when ``step_type`` equals ``StepType.FIRST``."""
        kind = self.step_type
        return kind == StepType.FIRST

    def mid(self) -> bool:
        """Return True when ``step_type`` equals ``StepType.MID``."""
        kind = self.step_type
        return kind == StepType.MID

    def last(self) -> bool:
        """Return True when ``step_type`` equals ``StepType.LAST``."""
        kind = self.step_type
        return kind == StepType.LAST


class DiscreteEnvConfig(object):
"""
Expand All @@ -79,8 +29,6 @@ def __init__(self) -> None:
self.reward_manager: RewardManager = None
self.average_distortion_constraint: float = 0.0
self.gamma: float = 0.99
# self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID
# self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID
self.n_states: int = 10
self.min_distortion: float = 0.4
self.max_distortion: float = 0.7
Expand Down Expand Up @@ -115,6 +63,10 @@ def __init__(self, env_config: DiscreteEnvConfig) -> None:
self.column_visits = {}
self.create_bins()

@property
def columns_attribute_types(self) -> dict:
    """Return ``columns_attribute_types`` of the data set held in the configuration."""
    data_set = self.config.data_set
    return data_set.columns_attribute_types

@property
def action_space(self):
    """Return the ``action_space`` object stored on the environment configuration."""
    configuration = self.config
    return configuration.action_space
Expand Down
55 changes: 55 additions & 0 deletions src/spaces/time_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Step bookkeeping types: the ``StepType`` enumeration and the ``TimeStep`` tuple.
"""

import enum
from typing import NamedTuple, Generic, Optional, TypeVar

_Reward = TypeVar('_Reward')
_Discount = TypeVar('_Discount')
_Observation = TypeVar('_Observation')


class StepType(enum.IntEnum):
    """Position of a `TimeStep` inside a sequence of steps."""

    FIRST = 0  # opening `TimeStep` of a sequence
    MID = 1  # any `TimeStep` that is neither FIRST nor LAST
    LAST = 2  # closing `TimeStep` of a sequence

    def first(self) -> bool:
        """Tell whether this member is ``FIRST``."""
        return StepType.FIRST is self

    def mid(self) -> bool:
        """Tell whether this member is ``MID``."""
        return StepType.MID is self

    def last(self) -> bool:
        """Tell whether this member is ``LAST``."""
        return StepType.LAST is self


class TimeStep(NamedTuple, Generic[_Reward, _Discount, _Observation]):
    """Immutable record describing one step of a sequence.

    NOTE(review): combining ``NamedTuple`` with ``Generic`` is documented
    as supported from Python 3.11 -- confirm the project's minimum
    interpreter version.
    """

    # Where this step sits in its sequence (FIRST / MID / LAST).
    step_type: StepType
    # Free-form dictionary; its contents are not defined in this block.
    info: dict
    # May be None, per the Optional annotation.
    reward: Optional[_Reward]
    # May be None, per the Optional annotation.
    discount: Optional[_Discount]
    observation: _Observation

    def first(self) -> bool:
        """Return True when ``step_type`` equals ``StepType.FIRST``."""
        kind = self.step_type
        return kind == StepType.FIRST

    def mid(self) -> bool:
        """Return True when ``step_type`` equals ``StepType.MID``."""
        kind = self.step_type
        return kind == StepType.MID

    def last(self) -> bool:
        """Return True when ``step_type`` equals ``StepType.LAST``."""
        kind = self.step_type
        return kind == StepType.LAST

    @property
    def done(self) -> bool:
        """Report whether this is the final step; delegates to ``last()``."""
        finished = self.last()
        return finished