Skip to content

Commit

Permalink
Moved wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
cnheider committed Jul 11, 2019
1 parent 92f4879 commit d1f26e6
Show file tree
Hide file tree
Showing 32 changed files with 356 additions and 119 deletions.
3 changes: 3 additions & 0 deletions neodroid/environments/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def coroutine_generator(self):
'''
return self

def render(self):
pass

@staticmethod
def seed(seed):
'''
Expand Down
6 changes: 4 additions & 2 deletions neodroid/environments/neodroid_environments.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
from typing import Dict

from neodroid.factories.inference import maybe_infer_motion_reaction
from neodroid.interfaces.environment_models import EnvironmentDescription, EnvironmentSnapshot
Expand Down Expand Up @@ -114,7 +115,7 @@ def react(
normalise=False,
on_reaction_sent_callback=None,
on_step_done_callback=None,
**kwargs) -> EnvironmentSnapshot:
**kwargs) -> Dict[str, EnvironmentSnapshot]:
'''
:param input_reactions:
Expand Down Expand Up @@ -145,7 +146,8 @@ def react(
elif not isinstance(input_reactions, M.Reaction):
input_reaction = maybe_infer_motion_reaction(input_reactions=input_reactions,
normalise=normalise,
description=self._description
description=self._description,
action_space=self.action_space
)
input_reactions = [input_reaction]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@


def connect(ip='localhost', port=6969, *args, **kwargs):
return SingleEnvironmentWrapper(ip=ip, port=port, connect_to_running=True, *args, **kwargs)
return SingleEnvironment(ip=ip, port=port, connect_to_running=True, *args, **kwargs)
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from neodroid.wrappers.single_environment_wrapper import SingleEnvironmentWrapper
from neodroid.environments.wrappers import SingleEnvironment

__author__ = 'cnheider'


class NeodroidALEWrapper(SingleEnvironmentWrapper):
class NeodroidALEWrapper(SingleEnvironment):

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import random
from typing import Any

from neodroid.environments.wrappers.single_environment import SingleEnvironment
from neodroid.interfaces.environment_models import Reaction, ReactionParameters
from neodroid.utilities.transformations.encodings import signed_ternary_encoding
from neodroid.wrappers.single_environment_wrapper import SingleEnvironmentWrapper

__author__ = 'cnheider'

import numpy as np


class NeodroidCurriculumWrapper(SingleEnvironmentWrapper):
class NeodroidCurriculumWrapper(SingleEnvironment):

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# -*- coding: utf-8 -*-
from warnings import warn

from neodroid.wrappers.single_environment_wrapper import SingleEnvironmentWrapper
from neodroid.environments.wrappers.single_environment import SingleEnvironment

__author__ = 'cnheider'


class NeodroidFormalWrapper(SingleEnvironmentWrapper):
class NeodroidFormalWrapper(SingleEnvironment):

def __next__(self):
if not self._is_connected_to_server:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def make(environment_name, **kwargs):
return NeodroidVectorGymWrapper(environment_name=environment_name, **kwargs)
return NeodroidVectorGymEnvironment(environment_name=environment_name, **kwargs)


def seed(seed):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import Any

from neodroid.utilities.transformations.encodings import signed_ternary_encoding
from neodroid.wrappers import NeodroidVectorGymWrapper
from neodroid.environments.wrappers import NeodroidVectorGymEnvironment

__author__ = 'cnheider'


class DiscreteActionEncodingWrapper(NeodroidVectorGymWrapper):
class DiscreteActionEncodingWrapper(NeodroidVectorGymEnvironment):

def step(self, action: int = 0, **kwargs) -> Any:
ternary_action = signed_ternary_encoding(size=self.action_space.num_discrete_actions // 3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
# -*- coding: utf-8 -*-
from warnings import warn

from gym import Env

from neodroid.environments.wrappers.single_environment import SingleEnvironment
from neodroid.interfaces.environment_models import EnvironmentSnapshot
from neodroid.interfaces.spaces import ActionSpace, ObservationSpace, Range
from neodroid.wrappers.single_environment_wrapper import SingleEnvironmentWrapper
from trolls import SubProcessEnvironments

__author__ = 'cnheider'

Expand All @@ -15,8 +18,8 @@
# warn(f"This module is deprecated in version {__version__}", DeprecationWarning)


class NeodroidGymWrapper(SingleEnvironmentWrapper,
gym.Env):
class NeodroidGymEnvironment(SingleEnvironment,
gym.Env):

def step(self, action=None, *args, **kwargs):
'''
Expand Down Expand Up @@ -75,8 +78,8 @@ def spec(self):
return None


class NeodroidVectorGymWrapper(SingleEnvironmentWrapper,
gym.Env):
class NeodroidVectorGymEnvironment(SingleEnvironment,
gym.Env):

def step(self, action=None, *args, **kwargs):
'''
Expand Down Expand Up @@ -123,8 +126,8 @@ def spec(self):
return None


class NeodroidWrapper:
def __init__(self, env):
class NeodroidGymWrapper:
def __init__(self, env:Env):
'''
:param env:
Expand All @@ -140,9 +143,12 @@ def observation_space(self):
_input_shape = None

if len(self._env.observation_space.shape) >= 1:
_input_shape = self._env.observation_space
_input_shape = ObservationSpace([Range(decimal_granularity=2) for _ in range(
self._env.observation_space.shape[0])])
else:
_output_shape = ObservationSpace([Range(min_value=0, max_value=self._env.observation_space.n)])
_input_shape = ObservationSpace([Range(min_value=0,
max_value=self._env.observation_space.n,
decimal_granularity=0)])

return _input_shape

Expand All @@ -155,9 +161,12 @@ def action_space(self):
_output_shape = None

if len(self._env.action_space.shape) >= 1:
_output_shape = self._env.action_space
_output_shape = ActionSpace([Range(decimal_granularity=2) for _ in range(
self._env.action_space.shape[0])])
else:
_output_shape = ActionSpace([Range(min_value=0, max_value=self._env.action_space.n)])
_output_shape = ActionSpace([Range(min_value=0,
max_value=self._env.action_space.n,
decimal_granularity=0)])

return _output_shape

Expand Down Expand Up @@ -187,4 +196,6 @@ def __getattr__(self, item):


if __name__ == '__main__':
NeodroidVectorGymWrapper()
env = NeodroidGymWrapper(gym.make('CartPole-v1'))
print(env.observation_space)
print(env.action_space)
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from neodroid.interfaces.neodroid_standard_modules.neodroid_camera_extraction import \
(extract_camera_observation, extract_neodroid_camera,
)
from neodroid.wrappers.single_environment_wrapper import SingleEnvironmentWrapper
from neodroid.environments.wrappers import SingleEnvironment

__author__ = 'cnheider'


class ObservationWrapper(SingleEnvironmentWrapper):
class ObservationWrapper(SingleEnvironment):

def __next__(self):
if not self._is_connected_to_server:
Expand Down Expand Up @@ -47,7 +47,7 @@ def quit(self, *args, **kwargs):
return self.close(*args, **kwargs)


class CameraObservationWrapper(SingleEnvironmentWrapper):
class CameraObservationWrapper(SingleEnvironment):

def __init__(self, auto_reset=True, **kwargs):
super().__init__(**kwargs)
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from neodroid.wrappers.single_environment_wrapper import SingleEnvironmentWrapper
from neodroid.environments.wrappers import SingleEnvironment

__author__ = 'cnheider'


class FlaskWrapper(SingleEnvironmentWrapper):
class FlaskWrapper(SingleEnvironment):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__author__ = 'cnheider'


class SingleEnvironmentWrapper(NeodroidEnvironment):
class SingleEnvironment(NeodroidEnvironment):

def __next__(self):
if not self._is_connected_to_server:
Expand All @@ -26,7 +26,8 @@ def react(self,
if not isinstance(input_reaction, Reaction):
input_reaction = maybe_infer_motion_reaction(input_reactions=input_reaction,
normalise=normalise,
description=self._description
description=self._description,
action_space=self.action_space
)
if parameters is not None:
input_reaction.parameters = parameters
Expand Down Expand Up @@ -93,8 +94,8 @@ def sensor(self, name, *args, **kwargs):
help='Connect to already running environment instead of starting another instance')
proc_args = parser.parse_args()

env = SingleEnvironmentWrapper(environment_name=proc_args.ENVIRONMENT_NAME,
connect_to_running=proc_args.CONNECT_TO_RUNNING)
env = SingleEnvironment(environment_name=proc_args.ENVIRONMENT_NAME,
connect_to_running=proc_args.CONNECT_TO_RUNNING)

observation_session = tqdm(env, leave=False)
for environment_state in observation_session:
Expand Down

0 comments on commit d1f26e6

Please sign in to comment.