Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/spaces/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def act(self, **ops) -> None:
:return:
"""

@abc.abstractmethod
def get_maximum_number_of_transforms(self):
    """
    Report the largest number of transforms this action can apply.

    Abstract hook: the body is intentionally empty and concrete
    actions supply the actual value.
    :return: the maximum number of transforms
    """


def move_next(iterators: List) -> None:
"""
Expand Down Expand Up @@ -90,6 +97,13 @@ def act(self, **ops):
"""
pass

def get_maximum_number_of_transforms(self):
    """
    Report the largest number of transforms this action can apply.

    :return: always 1 for this action
    """
    # this action reports a fixed upper bound of one transform
    single_transform = 1
    return single_transform


class ActionTransform(ActionBase):

Expand All @@ -106,6 +120,13 @@ def act(self, **ops):
"""
pass

def get_maximum_number_of_transforms(self):
    """
    Report the largest number of transforms this action can apply.

    Not supported by this action type.
    :return: never returns normally
    :raises NotImplementedError: on every call
    """
    message = "Method not implemented"
    raise NotImplementedError(message)


class ActionSuppress(ActionBase, _WithTable):

Expand Down Expand Up @@ -138,6 +159,21 @@ def act(self, **ops) -> None:
# update the generalization
move_next(iterators=self.iterators)

def get_maximum_number_of_transforms(self):
    """
    Report the largest number of transforms this action can apply.

    The value is the length of the longest entry stored in
    ``self.table``.
    :return: the largest ``len(self.table[key])`` over all keys,
        or 0 when the table is empty
    """
    # ``default=0`` keeps the empty-table result that the manual
    # scan produced (it started from 0)
    return max((len(self.table[key]) for key in self.table), default=0)


class ActionGeneralize(ActionBase, _WithTable):
"""
Expand Down Expand Up @@ -181,5 +217,21 @@ def act(self, **ops):
def add_generalization(self, key: str, values: HierarchyBase) -> None:
    """
    Store a generalization hierarchy in the action's table.

    :param key: the table key under which ``values`` is stored
    :param values: the hierarchy to associate with ``key``
    :return:
    """
    table = self.table
    table[key] = values

def get_maximum_number_of_transforms(self):
    """
    Report the largest number of transforms this action can apply.

    The value is the length of the longest entry registered in
    ``self.table`` (entries are added via ``add_generalization``).
    :return: the largest ``len(self.table[key])`` over all keys,
        or 0 when the table is empty
    """
    # ``default=0`` keeps the empty-table result that the manual
    # scan produced (it started from 0)
    return max((len(self.table[key]) for key in self.table), default=0)




60 changes: 45 additions & 15 deletions src/spaces/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@

from src.exceptions.exceptions import Error
from src.spaces.actions import ActionBase, ActionType
from src.spaces.state_space import StateSpace, State
from src.utils.string_distance_calculator import DistanceType, TextDistanceCalculator

DataSet = TypeVar("DataSet")
RewardManager = TypeVar("RewardManager")

_Reward = TypeVar('_Reward')
_Discount = TypeVar('_Discount')
Expand Down Expand Up @@ -65,20 +67,37 @@ def last(self) -> bool:
class Environment(object):

def __init__(self, data_set, action_space,
             gamma: float, start_column: str, reward_manager: RewardManager):
    """
    Build the environment around a data set and an action space.

    :param data_set: the data set the environment operates on
    :param action_space: the space actions are sampled from
    :param gamma: the discount reported in the returned time steps
    :param start_column: name of the starting column (stored as given)
    :param reward_manager: object queried for the reward of a state/action pair
    """
    self.data_set = data_set

    # keep an independent copy of the data set as it was handed in
    self.start_ds = copy.deepcopy(data_set)
    self.current_time_step = self.start_ds
    self.action_space = action_space
    self.gamma = gamma
    self.start_column = start_column

    # filled in later; empty until text distances are initialized
    self.column_distances = {}
    self.state_space = StateSpace()
    self.distance_calculator = None
    self.reward_manager: RewardManager = reward_manager

    # initialize the state space; this reads ``feature_names``, which
    # reads ``start_ds``, so it must run after ``start_ds`` is set
    self.state_space.init_from_environment(env=self)

@property
def n_features(self) -> int:
    """
    Number of features (columns) in the data set.

    :return: ``n_columns`` of ``self.start_ds``
    """
    initial_ds = self.start_ds
    return initial_ds.n_columns

@property
def feature_names(self) -> list:
    """
    Names of the features (columns) in the data set.

    :return: whatever ``self.start_ds.get_columns_names()`` yields
    """
    initial_ds = self.start_ds
    names = initial_ds.get_columns_names()
    return names

@property
def n_examples(self) -> int:
    """
    Number of examples (rows) in the data set.

    :return: ``n_rows`` of ``self.start_ds``
    """
    initial_ds = self.start_ds
    return initial_ds.n_rows
Expand All @@ -99,6 +118,24 @@ def initialize_text_distances(self, distance_type: DistanceType) -> None:
def sample_action(self) -> ActionBase:
    """
    Draw an action from the environment's action space.

    :return: the result of ``self.action_space.sample_and_get()``
    """
    sampled = self.action_space.sample_and_get()
    return sampled

def get_column_as_tensor(self, column_name) -> torch.Tensor:
    """
    Return a single data set column as a float64 torch tensor.

    Columns whose registered type is ``str`` are represented by the
    values stored in ``self.column_distances``; every other column is
    read from ``self.data_set``.
    :param column_name: name of the column to convert
    :return: tensor built from a one-column DataFrame of the values
    """
    if self.start_ds.columns[column_name] == str:
        # string columns are replaced by their stored distances
        values = self.column_distances[column_name]
    else:
        values = self.data_set.get_column(col_name=column_name).to_numpy()

    # go through a one-column frame so the result keeps a column axis
    frame = pd.DataFrame({column_name: values})
    return torch.tensor(frame.to_numpy(), dtype=torch.float64)

def get_ds_as_tensor(self) -> torch.Tensor:

"""
Expand All @@ -111,7 +148,6 @@ def get_ds_as_tensor(self) -> torch.Tensor:
for col in col_names:

if self.start_ds.columns[col] == str:
#print("col: {0} type {1}".format(col, self.start_ds.get_column_type(col_name=col)))
numpy_vals = self.column_distances[col]
data[col] = numpy_vals
else:
Expand Down Expand Up @@ -195,28 +231,22 @@ def step(self, action: ActionBase) -> TimeStep:
`action` will be ignored.
"""

# apply the action
self.apply_action(action=action)

# if the action is identity don't bother
# doing anything
#if action.action_type == ActionType.IDENTITY:
# return TimeStep(step_type=StepType.MID, reward=0.0,
# observation=self.get_ds_as_tensor().float(), discount=self.gamma)

# apply the transform of the data set
#self.data_set.apply_column_transform(transform=action)
# update the state space
self.state_space.update_state(state_name=action.column_name, status=action.action_type)

# perform the action on the data set
self.prepare_column_states()

# calculate the information leakage and establish the reward
# to return to the agent
reward = self.reward_manager.get_state_reward(self.state_space, action)

return TimeStep(step_type=StepType.MID, reward=0.0,
observation=self.get_ds_as_tensor().float(), discount=self.gamma)



return TimeStep(step_type=StepType.MID, reward=reward,
observation=self.get_column_as_tensor(column_name=action.column_name).float(),
discount=self.gamma)


class MultiprocessEnv(object):
Expand Down
64 changes: 64 additions & 0 deletions src/spaces/state_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Discretized state space
"""

from typing import TypeVar, List
from gym.spaces.discrete import Discrete

from src.exceptions.exceptions import Error

# Type variables used only in annotations in this module.
# NOTE(review): presumably stand-ins for the action status and
# environment types defined in other modules -- confirm.
ActionStatus = TypeVar("ActionStatus")
Env = TypeVar("Env")


class State(object):
    """
    A single environment state, tied to one data set column.

    Holds the column name, a numeric id and the list of statuses
    recorded for the column so far.
    """

    def __init__(self, column_name: str, state_id: int):
        # the history starts empty; entries are appended by the caller
        self.history: List[ActionStatus] = []
        self.state_id: int = state_id
        self.column_name: str = column_name

    @property
    def key(self) -> tuple:
        """
        Identifier of the state.

        :return: the pair ``(column_name, state_id)``
        """
        return (self.column_name, self.state_id)


class StateSpace(Discrete):
    """
    The state space accumulates the discrete states.

    States are kept in ``self.states`` keyed by column name, and
    ``self.n`` is kept equal to the number of registered states.
    """

    def __init__(self):
        # NOTE(review): the space starts empty with n=0; some gym
        # releases may reject a non-positive n -- confirm against the
        # gym version this project pins
        super(StateSpace, self).__init__(n=0)
        self.states = {}

    def init_from_environment(self, env: Env):
        """
        Initialize from environment: one state per feature name.

        :param env: object exposing ``feature_names``
        :return:
        :raises ValueError: if a feature name is already registered
        """
        for col_name in env.feature_names:
            # ids are assigned in registration order
            self.add_state(State(column_name=col_name, state_id=len(self.states)))

    def add_state(self, state: State):
        """
        Register a state under its column name.

        :param state: the state to add
        :return:
        :raises ValueError: if the column name is already registered
        """
        if state.column_name in self.states:
            raise ValueError("Column {0} already exists".format(state.column_name))

        self.states[state.column_name] = state

        # keep the number of discrete states in sync with the registry
        self.n = len(self.states)

    def update_state(self, state_name, status: ActionStatus):
        """
        Append a status to the history of the named state.

        :param state_name: column name of the state to update
        :param status: the status to record
        :return:
        :raises KeyError: if no state is registered under ``state_name``
        """
        self.states[state_name].history.append(status)

    def __len__(self):
        """
        :return: the number of registered states
        """
        return len(self.states)
Loading