Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ RUN pip3 install -r requirements.txt
COPY . /app/ocs/

# Install ocs
RUN pip3 install -e .
RUN pip3 install .
4 changes: 1 addition & 3 deletions agents/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from collections import defaultdict
from ocs.ocs_feed import Feed

from ocs.agent.aggregator import Provider

class RegisteredAgent:
"""
Contains data about registered agents.
Expand Down Expand Up @@ -166,7 +164,7 @@ def main(self, session: ocs_agent.OpSession, params):
field = f'{addr}_{op_name}'
field = field.replace('.', '_')
field = field.replace('-', '_')
field = Provider._enforce_field_name_rules(field)
field = Feed.enforce_field_name_rules(field)
try:
Feed.verify_data_field_string(field)
except ValueError as e:
Expand Down
63 changes: 6 additions & 57 deletions ocs/agent/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import txaio
txaio.use_twisted()

from ocs import ocs_feed
from ocs.ocs_feed import Block, Feed

from spt3g import core
import so3g
Expand Down Expand Up @@ -222,7 +222,7 @@ def _verify_provider_data(self, data):
for block_name, block_dict in data.items():
for field_name, field_values in block_dict['data'].items():
try:
ocs_feed.Feed.verify_data_field_string(field_name)
Feed.verify_data_field_string(field_name)
except ValueError:
self.log.error("data field name '{field}' is " +
"invalid, removing invalid characters.",
Expand All @@ -231,63 +231,12 @@ def _verify_provider_data(self, data):

return verified

@staticmethod
def _enforce_field_name_rules(field_name):
"""Enforce naming rules for field names.

A valid name:

* contains only letters (a-z, A-Z; case sensitive), decimal digits (0-9), and the
underscore (_).
* begins with a letter, or with any number of underscores followed by a letter.
* is at least one, but no more than 255, character(s) long.

Args:
field_name (str):
Field name string to check and modify if needed.

Returns:
str: New field name, meeting all above rules. Note this isn't
guarenteed to not collide with other field names passed
through this method, and that should be checked.

"""
# check for empty string
if field_name == "":
new_field_name = "invalid_field"
else:
new_field_name = field_name

# replace spaces with underscores
new_field_name = new_field_name.replace(' ', '_')

# replace invalid characters
new_field_name = re.sub('[^a-zA-Z0-9_]', '', new_field_name)

# grab leading underscores
underscore_search = re.compile('^_*')
underscores = underscore_search.search(new_field_name).group()

# remove leading underscores
new_field_name = re.sub('^_*', '', new_field_name)

# remove leading non-letters
new_field_name = re.sub('^[^a-zA-Z]*', '', new_field_name)

# add underscores back
new_field_name = underscores + new_field_name

# limit to 255 characters
new_field_name = new_field_name[:255]

return new_field_name

@staticmethod
def _check_for_duplicate_names(field_name, name_list):
"""Check name_list for matching field names and modify field_name if
matches are found.

The results of Provider._enforce_field_name_rules() are not guarenteed
The results of ocs_feed.Feed.enforce_field_name_rules() are not guarenteed
to be unique. This method will check field_name against a list of
existing field names and try to append '_N', with N being a zero padded
integer up to 99. Longer integers, though not expected to see use, are
Expand Down Expand Up @@ -341,11 +290,11 @@ def _rebuild_invalid_data(self, data):
new_data[block_name]['data'] = {}
new_field_names = []
for field_name, field_values in block_dict['data'].items():
new_field_name = Provider._enforce_field_name_rules(field_name)
new_field_name = Feed.enforce_field_name_rules(field_name)

# Catch instance where rule enforcement strips all characters
if not new_field_name:
new_field_name = Provider._enforce_field_name_rules("invalid_field_" + field_name)
new_field_name = Feed.enforce_field_name_rules("invalid_field_" + field_name)

new_field_name = Provider._check_for_duplicate_names(new_field_name,
new_field_names)
Expand Down Expand Up @@ -402,7 +351,7 @@ def save_to_block(self, data):
try:
b = self.blocks[key]
except KeyError:
self.blocks[key] = ocs_feed.Block(
self.blocks[key] = Block(
key, block['data'].keys(),
)
b = self.blocks[key]
Expand Down
51 changes: 51 additions & 0 deletions ocs/ocs_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,54 @@ def verify_data_field_string(field):
"exceeds the valid length of 255 characters.")

return True

@staticmethod
def enforce_field_name_rules(field_name):
"""Enforce naming rules for field names.

A valid name:

* contains only letters (a-z, A-Z; case sensitive), decimal digits (0-9), and the
underscore (_).
* begins with a letter, or with any number of underscores followed by a letter.
* is at least one, but no more than 255, character(s) long.

Args:
field_name (str):
Field name string to check and modify if needed.

Returns:
str: New field name, meeting all above rules. Note this isn't
guarenteed to not collide with other field names passed
through this method, and that should be checked.

"""
# check for empty string
if field_name == "":
new_field_name = "invalid_field"
else:
new_field_name = field_name

# replace spaces with underscores
new_field_name = new_field_name.replace(' ', '_')

# replace invalid characters
new_field_name = re.sub('[^a-zA-Z0-9_]', '', new_field_name)

# grab leading underscores
underscore_search = re.compile('^_*')
underscores = underscore_search.search(new_field_name).group()

# remove leading underscores
new_field_name = re.sub('^_*', '', new_field_name)

# remove leading non-letters
new_field_name = re.sub('^[^a-zA-Z]*', '', new_field_name)

# add underscores back
new_field_name = underscores + new_field_name

# limit to 255 characters
new_field_name = new_field_name[:255]

return new_field_name
15 changes: 2 additions & 13 deletions tests/agents/test_registry_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@

from agents.util import create_session, create_agent_fixture

from registry import Registry

try:
# depends on spt3g
from registry import Registry
agent = create_agent_fixture(Registry)

agent = create_agent_fixture(Registry)
except ModuleNotFoundError as e:
print(f"Unable to import: {e}")


@pytest.mark.spt3g
@pytest.mark.dependency(depends=['so3g'], scope='session')
class TestMain:
@pytest_twisted.inlineCallbacks
def test_registry_main(self, agent):
Expand Down Expand Up @@ -79,8 +72,6 @@ def test_registry_main_expire_agent(self, agent):
assert session.data['observatory.test_agent']['op_codes'] == expected_op_codes


@pytest.mark.spt3g
@pytest.mark.dependency(depends=['so3g'], scope='session')
class TestStopMain:
def test_registry_stop_main_while_running(self, agent):
session = create_session('main')
Expand All @@ -97,8 +88,6 @@ def test_registry_stop_main_not_running(self, agent):
assert res[0] is False


@pytest.mark.spt3g
@pytest.mark.dependency(depends=['so3g'], scope='session')
def test_registry_register_agent(agent):
session = create_session('main')
agent_data = {'agent_address': 'observatory.test_agent'}
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_registry_agent_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
client = create_client_fixture('registry')


@pytest.mark.dependency(depends=["so3g"], scope='session')
@pytest.mark.integtest
def test_registry_agent_main(wait_for_crossbar, run_agent, client):
# Startup is always true, so let's check it's running
Expand Down