diff --git a/Dockerfile b/Dockerfile index d5c870ea..49d1f0df 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,4 +34,4 @@ RUN pip3 install -r requirements.txt COPY . /app/ocs/ # Install ocs -RUN pip3 install -e . +RUN pip3 install . diff --git a/agents/registry/registry.py b/agents/registry/registry.py index 00da33c8..48f7b97b 100644 --- a/agents/registry/registry.py +++ b/agents/registry/registry.py @@ -6,8 +6,6 @@ from collections import defaultdict from ocs.ocs_feed import Feed -from ocs.agent.aggregator import Provider - class RegisteredAgent: """ Contains data about registered agents. @@ -166,7 +164,7 @@ def main(self, session: ocs_agent.OpSession, params): field = f'{addr}_{op_name}' field = field.replace('.', '_') field = field.replace('-', '_') - field = Provider._enforce_field_name_rules(field) + field = Feed.enforce_field_name_rules(field) try: Feed.verify_data_field_string(field) except ValueError as e: diff --git a/ocs/agent/aggregator.py b/ocs/agent/aggregator.py index 3b0c3159..4ac5184c 100644 --- a/ocs/agent/aggregator.py +++ b/ocs/agent/aggregator.py @@ -8,7 +8,7 @@ import txaio txaio.use_twisted() -from ocs import ocs_feed +from ocs.ocs_feed import Block, Feed from spt3g import core import so3g @@ -222,7 +222,7 @@ def _verify_provider_data(self, data): for block_name, block_dict in data.items(): for field_name, field_values in block_dict['data'].items(): try: - ocs_feed.Feed.verify_data_field_string(field_name) + Feed.verify_data_field_string(field_name) except ValueError: self.log.error("data field name '{field}' is " + "invalid, removing invalid characters.", @@ -231,63 +231,12 @@ def _verify_provider_data(self, data): return verified - @staticmethod - def _enforce_field_name_rules(field_name): - """Enforce naming rules for field names. - - A valid name: - - * contains only letters (a-z, A-Z; case sensitive), decimal digits (0-9), and the - underscore (_). - * begins with a letter, or with any number of underscores followed by a letter. - * is at least one, but no more than 255, character(s) long. - - Args: - field_name (str): - Field name string to check and modify if needed. - - Returns: - str: New field name, meeting all above rules. Note this isn't - guarenteed to not collide with other field names passed - through this method, and that should be checked. - - """ - # check for empty string - if field_name == "": - new_field_name = "invalid_field" - else: - new_field_name = field_name - - # replace spaces with underscores - new_field_name = new_field_name.replace(' ', '_') - - # replace invalid characters - new_field_name = re.sub('[^a-zA-Z0-9_]', '', new_field_name) - - # grab leading underscores - underscore_search = re.compile('^_*') - underscores = underscore_search.search(new_field_name).group() - - # remove leading underscores - new_field_name = re.sub('^_*', '', new_field_name) - - # remove leading non-letters - new_field_name = re.sub('^[^a-zA-Z]*', '', new_field_name) - - # add underscores back - new_field_name = underscores + new_field_name - - # limit to 255 characters - new_field_name = new_field_name[:255] - - return new_field_name - @staticmethod def _check_for_duplicate_names(field_name, name_list): """Check name_list for matching field names and modify field_name if matches are found. - The results of Provider._enforce_field_name_rules() are not guarenteed + The results of ocs_feed.Feed.enforce_field_name_rules() are not guarenteed to be unique. This method will check field_name against a list of existing field names and try to append '_N', with N being a zero padded integer up to 99. Longer integers, though not expected to see use, are @@ -341,11 +290,11 @@ def _rebuild_invalid_data(self, data): new_data[block_name]['data'] = {} new_field_names = [] for field_name, field_values in block_dict['data'].items(): - new_field_name = Provider._enforce_field_name_rules(field_name) + new_field_name = Feed.enforce_field_name_rules(field_name) # Catch instance where rule enforcement strips all characters if not new_field_name: - new_field_name = Provider._enforce_field_name_rules("invalid_field_" + field_name) + new_field_name = Feed.enforce_field_name_rules("invalid_field_" + field_name) new_field_name = Provider._check_for_duplicate_names(new_field_name, new_field_names) @@ -402,7 +351,7 @@ def save_to_block(self, data): try: b = self.blocks[key] except KeyError: - self.blocks[key] = ocs_feed.Block( + self.blocks[key] = Block( key, block['data'].keys(), ) b = self.blocks[key] diff --git a/ocs/ocs_feed.py b/ocs/ocs_feed.py index 7fa345a0..e6e39fac 100644 --- a/ocs/ocs_feed.py +++ b/ocs/ocs_feed.py @@ -325,3 +325,54 @@ def verify_data_field_string(field): "exceeds the valid length of 255 characters.") return True + + @staticmethod + def enforce_field_name_rules(field_name): + """Enforce naming rules for field names. + + A valid name: + + * contains only letters (a-z, A-Z; case sensitive), decimal digits (0-9), and the + underscore (_). + * begins with a letter, or with any number of underscores followed by a letter. + * is at least one, but no more than 255, character(s) long. + + Args: + field_name (str): + Field name string to check and modify if needed. + + Returns: + str: New field name, meeting all above rules. Note this isn't + guarenteed to not collide with other field names passed + through this method, and that should be checked. + + """ + # check for empty string + if field_name == "": + new_field_name = "invalid_field" + else: + new_field_name = field_name + + # replace spaces with underscores + new_field_name = new_field_name.replace(' ', '_') + + # replace invalid characters + new_field_name = re.sub('[^a-zA-Z0-9_]', '', new_field_name) + + # grab leading underscores + underscore_search = re.compile('^_*') + underscores = underscore_search.search(new_field_name).group() + + # remove leading underscores + new_field_name = re.sub('^_*', '', new_field_name) + + # remove leading non-letters + new_field_name = re.sub('^[^a-zA-Z]*', '', new_field_name) + + # add underscores back + new_field_name = underscores + new_field_name + + # limit to 255 characters + new_field_name = new_field_name[:255] + + return new_field_name diff --git a/tests/agents/test_registry_agent.py b/tests/agents/test_registry_agent.py index e174e247..13afa964 100644 --- a/tests/agents/test_registry_agent.py +++ b/tests/agents/test_registry_agent.py @@ -7,18 +7,11 @@ from agents.util import create_session, create_agent_fixture +from registry import Registry -try: - # depends on spt3g - from registry import Registry +agent = create_agent_fixture(Registry) - agent = create_agent_fixture(Registry) -except ModuleNotFoundError as e: - print(f"Unable to import: {e}") - -@pytest.mark.spt3g -@pytest.mark.dependency(depends=['so3g'], scope='session') class TestMain: @pytest_twisted.inlineCallbacks def test_registry_main(self, agent): @@ -79,8 +72,6 @@ def test_registry_main_expire_agent(self, agent): assert session.data['observatory.test_agent']['op_codes'] == expected_op_codes -@pytest.mark.spt3g -@pytest.mark.dependency(depends=['so3g'], scope='session') class TestStopMain: def test_registry_stop_main_while_running(self, agent): session = create_session('main') @@ -97,8 +88,6 @@ def test_registry_stop_main_not_running(self, agent): assert res[0] is False -@pytest.mark.spt3g -@pytest.mark.dependency(depends=['so3g'], scope='session') def test_registry_register_agent(agent): session = create_session('main') agent_data = {'agent_address': 'observatory.test_agent'} diff --git a/tests/integration/test_registry_agent_integration.py b/tests/integration/test_registry_agent_integration.py index 4273b5ae..d4638ed7 100644 --- a/tests/integration/test_registry_agent_integration.py +++ b/tests/integration/test_registry_agent_integration.py @@ -20,7 +20,6 @@ client = create_client_fixture('registry') -@pytest.mark.dependency(depends=["so3g"], scope='session') @pytest.mark.integtest def test_registry_agent_main(wait_for_crossbar, run_agent, client): # Startup is always true, so let's check it's running