Skip to content

Commit

Permalink
Curriculum training refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
cnheider committed Mar 31, 2019
1 parent b0e17e7 commit 8db32cb
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 79 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Expand Up @@ -18,8 +18,6 @@ venv/

environments/
*.log
.idea

# Byte-compiled / optimized / DLL files
*$py.class

Expand Down
23 changes: 14 additions & 9 deletions neodroid/messaging/message_client.py
Expand Up @@ -52,7 +52,8 @@ def __init__(self,
on_connected_callback=None,
on_disconnected_callback=None,
single_threaded=False,
verbose=False):
verbose=False,
writer=warnings.warn):

self._verbose = verbose
self._tcp_address = tcp_address
Expand All @@ -65,7 +66,9 @@ def __init__(self,

self._on_timeout_callback = on_timeout_callback
self._on_connected_callback = on_connected_callback
self._on_step_done = on_step_done_callback
self._on_disconnected_callback = on_disconnected_callback
self._writer = writer

if single_threaded:
self.build(single_threaded)
Expand All @@ -81,15 +84,15 @@ def open_connection(self):
raise RuntimeError('Failed to create ZMQ socket!')

if self._verbose:
warnings.warn('Connecting to server')
self._writer('Connecting to server')
if self._use_ipc_medium:
self._request_socket.connect('ipc:///tmp/neodroid/messages')
if self._verbose:
warnings.warn('Using IPC protocol')
self._writer('Using IPC protocol')
else:
self._request_socket.connect(f'tcp://{self._tcp_address}:{self._tcp_port}')
if self._verbose:
warnings.warn('Using TCP protocol')
self._writer('Using TCP protocol')

self._on_connected_callback()

Expand All @@ -101,7 +104,7 @@ def close_connection(self):
self._request_socket.setsockopt(zmq.LINGER, 0)
self._request_socket.close()
self._poller.unregister(self._request_socket)
#self._poller.close()
# self._poller.close()

def teardown(self):
self._context.term()
Expand Down Expand Up @@ -143,7 +146,7 @@ def send_reactions(self, reactions):

states, simulator_configuration = deserialise_states(flat_buffer_states)
# if LAST_RECEIVED_FRAME_NUMBER==states.frame_number:
# warnings.warn(f'Received a duplicate frame on frame number: {states.frame_number}')
# self._writer(f'Received a duplicate frame on frame number: {states.frame_number}')
# LAST_RECEIVED_FRAME_NUMBER=states.frame_number

return states, simulator_configuration
Expand All @@ -156,13 +159,15 @@ def send_reactions(self, reactions):

if retries_left <= 0:
if self._verbose:
warnings.warn('Out of retries, tearing down client')
self._writer('Out of retries, tearing down client')
self.teardown()
if self._on_disconnected_callback:
self._on_disconnected_callback()
raise ConnectionError

else:
warnings.warn(f'\nRetrying, attempt: {retries_left:d}/{REQUEST_RETRIES:d}')
self._writer(f'Retrying to connect, attempt: {retries_left:d}/{REQUEST_RETRIES:d}')
self.open_connection()
self._request_socket.send(serialised_reaction)

if self._on_step_done:
self._on_step_done()
12 changes: 6 additions & 6 deletions neodroid/models/configurable.py
Expand Up @@ -30,16 +30,16 @@ def configurable_space(self):

def to_dict(self):
return {
'_configurable_name': self._configurable_name,
'_configurable_value':self._configurable_value,
'_configurable_space':self._configurable_space
'configurable_name': self.configurable_name,
'configurable_value':self.configurable_value,
'configurable_space':self.configurable_space
}

def __repr__(self):
return (f'<Configurable>\n'
f'<configurable_name>{self._configurable_name}</configurable_name>\n'
f'<configurable_value>{self._configurable_value}</configurable_value>\n'
f'<configurable_space>{self._configurable_space}</configurable_space>\n'
f'<configurable_name>{self.configurable_name}</configurable_name>\n'
f'<configurable_value>{self.configurable_value}</configurable_value>\n'
f'<configurable_space>\n{self.configurable_space}</configurable_space>\n'
f'</Configurable>\n')

def __str__(self):
Expand Down
4 changes: 2 additions & 2 deletions neodroid/models/simulator_configuration.py
Expand Up @@ -8,7 +8,7 @@ class SimulatorConfiguration(object):
def __init__(self,
fbs_configuration,
api_version,
simulator_info='No extra info about simulator available'):
simulator_info='No simulator_info available'):
self._fbs_configuration = fbs_configuration
self._api_version = api_version
self._simulator_info = simulator_info
Expand All @@ -25,4 +25,4 @@ def api_version(self):
def simulator_info(self):
if type(self._simulator_info) is not str:
return self._simulator_info.decode()
return self._simulator_info
return self._simulator_info
9 changes: 7 additions & 2 deletions neodroid/neodroid_environments.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import draugr

__author__ = 'cnheider'

Expand Down Expand Up @@ -93,8 +94,12 @@ def __init__(
if self._verbose:
warnings.warn(f'Server is using different version {server_version}, complications may occur!')

print(f'Server API version: {server_version}, \n'
f'\t{self.simulator_configuration.simulator_info}')
print(f'Server API version: {server_version}')

draugr.sprint(f'\nconfigurable space:\n{self.description.configurables}\n',
color='blue',
bold=True,
highlight=True)

def _configure(self, *args, **kwargs):
return self._reset()
Expand Down
3 changes: 2 additions & 1 deletion neodroid/neodroid_utilities/messaging_utilities/__init__.py
Expand Up @@ -17,7 +17,8 @@ def receive_func(func):
@wraps(func)
def call_func(ctx, *args, **kwargs):
if event is ClientEvents.CONNECTED:
print('Connected to server')
pass
#print('Connected to server')
elif event is ClientEvents.DISCONNECTED:
warn('Disconnected from server', stacklevel=stack_level)
elif event is ClientEvents.TIMEOUT:
Expand Down
27 changes: 14 additions & 13 deletions neodroid/networking_environment.py
Expand Up @@ -4,7 +4,6 @@
__author__ = 'cnheider'

import time
import warnings
from abc import ABC

from tqdm import tqdm
Expand Down Expand Up @@ -51,23 +50,25 @@ def __next__(self):
def _setup_connection(self):
print(f'Connecting to server at {self._ip}:{self._port}')

self._message_server = messaging.MessageClient( self._ip,
self._port,
on_timeout_callback=self.__on_timeout_callback__,
on_connected_callback=self.__on_connected_callback__,
on_disconnected_callback=self.__on_disconnected_callback__,
verbose=self._verbose)

connect_tries = tqdm(range(CONNECT_TRY_TIMES), leave=False)

self._message_server = messaging.MessageClient(self._ip,
self._port,
on_timeout_callback=self.__on_timeout_callback__,
on_connected_callback=self.__on_connected_callback__,
on_disconnected_callback=self.__on_disconnected_callback__,
verbose=self._verbose,writer=connect_tries.write)


self._describe()

while self.description is None:
self._describe()
time.sleep(CONNECT_TRY_INTERVAL)
connect_tries.update()
connect_tries.set_description(f'Connecting, please make sure that the ip {self._ip} '
f'and port {self._port} '
f'are cd correct')
f'and port {self._port} '
f'are cd correct')
if connect_tries.n is CONNECT_TRY_TIMES:
raise ConnectionError

Expand Down Expand Up @@ -112,9 +113,9 @@ def describe(self):

def _describe(
self,
parameters=M.ReactionParameters( terminable=True,
describe=True,
episode_count=False )
parameters=M.ReactionParameters(terminable=True,
describe=True,
episode_count=False)
):
'''
Expand Down
89 changes: 45 additions & 44 deletions neodroid/wrappers/curriculum_wrapper/curriculum_wrapper.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from neodroid.models import ReactionParameters, Reaction
from neodroid.models import Reaction, ReactionParameters

__author__ = 'cnheider'

Expand All @@ -17,18 +17,17 @@ def __init__(self, **kwargs):

def __next__(self):
if not self._is_connected_to_server:
return
raise StopIteration
return self.act()

def act(self, **kwargs):
message = super().react(**kwargs)
def act(self, *args, **kwargs):
message = super().react(*args,**kwargs)
if message:
return (
np.array(flattened_observation(message)),
message.signal,
message.terminated,
message,
)
return (np.array(flattened_observation(message)),
message.signal,
message.terminated,
message,
)
return None, None, None, None

def configure(self, *args, **kwargs):
Expand All @@ -37,32 +36,34 @@ def configure(self, *args, **kwargs):
return np.array(flattened_observation(message)), message
return None, None

def generate_trajectory_from_configuration(
self,
initial_configuration,
motion_horizon=6,
non_terminable_horizon=10,
random_process=None,
):
configure_params = ReactionParameters(reset=True, configure=True
def generate_trajectory_from_configuration(self,
initial_configuration,
motion_horizon=6,
non_terminable_horizon=10,
random_process=None
):
configure_params = ReactionParameters(reset=True,
terminable=False,
configure=True
# ,episode_count=False
)
init = Reaction(
parameters=configure_params, configurations=initial_configuration
)

non_terminable_params = ReactionParameters(
step=True,
)
conf_reaction = Reaction(parameters=configure_params,
configurations=initial_configuration)

initial_states = []
non_terminable_params = ReactionParameters(step=True,
terminable=False
# ,
# episode_count=False
)

initial_states = set()
self.configure()
while len(initial_states) < 1:
state, _ = self.configure(init)
state, _ = self.configure(conf_reaction)
for i in range(non_terminable_horizon):
reaction = Reaction(
motions=self.action_space.sample(), parameters=non_terminable_params
)
state, _, terminated, info = self.act(actions=reaction)
state, _, terminated, info = self.act(self.action_space.sample(),
parameters=non_terminable_params)

for i in range(motion_horizon):
if random_process is not None:
Expand All @@ -71,18 +72,19 @@ def generate_trajectory_from_configuration(
else:
actions = self.action_space.sample()

state, _, terminated, info = self.act(actions=actions)
state, _, terminated, info = self.act(actions)

if not terminated:
initial_states.append(info)
initial_states.add(info)
non_terminable_horizon += 1

return initial_states

def generate_trajectory_from_state(
self, state, motion_horizon=10, random_process=None
):
initial_states = []
def generate_trajectory_from_state(self,
state,
motion_horizon=10,
random_process=None):
initial_states = set()
self.configure()
while len(initial_states) < 1:
s, _ = self.configure(state=state)
Expand All @@ -93,23 +95,22 @@ def generate_trajectory_from_state(
else:
actions = self.action_space.sample()

s, _, terminated, info = self.act(actions=actions)
s, _, terminated, info = self.act(actions)

if not terminated:
initial_states.append(info)
initial_states.add(info)
motion_horizon += 1

return initial_states

def observe(self, *args, **kwargs):
message = super().observe()
if message:
return (
flattened_observation(message),
message.signal,
message.terminated,
message,
)
return (flattened_observation(message),
message.signal,
message.terminated,
message,
)
return None, None, None, None

def quit(self, *args, **kwargs):
Expand Down

0 comments on commit 8db32cb

Please sign in to comment.