Skip to content

Commit

Permalink
Full discover vs. candidate processing
Browse files Browse the repository at this point in the history
Resolves #221, #78
  • Loading branch information
stscieisenhamer committed Sep 14, 2016
1 parent 4d5fa7e commit 57d64b9
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 71 deletions.
6 changes: 6 additions & 0 deletions jwst/associations/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,14 @@ class Association(MutableMapping):
schema_file: str
The name of the output schema that an association
must adhere to.
registry: AssociationRegistry
The registry this association came from.
"""

# Assume no registry
registry = None

# Default for whether to force a constraint to use its first value.
DEFAULT_FORCE_UNIQUE = False

Expand Down
71 changes: 53 additions & 18 deletions jwst/associations/lib/rules_level3_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,17 @@
_REGEX_ACID_VALUE = '(o\d{3}|(c|a)\d{4})'


# Key that uniquely identifies members.
KEY = 'expname'


class DMS_Level3_Base(Association):
"""Basic class for DMS Level3 associations."""

def __init__(self, *args, **kwargs):

# Keep track of what candidates have
# gone into making this association.
self.member_observations = set()
self.member_candidates = set()
# Keep the set of members included in this association
self.members = set()

# Initialize discovered association ID
self.discovered_id = Counter(_DISCOVERED_ID_START)
Expand Down Expand Up @@ -184,12 +186,8 @@ def _add(self, member):
exposerr
))

# Document what candidates this member belonged to.
for candidate in Utility.get_candidate_list(member['ASN_CANDIDATE']):
if candidate.type == 'OBSERVATION':
self.member_observations.add(candidate)
else:
self.member_candidates.add(candidate)
# Add entry to the short list
self.members.add(entry[KEY])

def _get_target_id(self):
"""Get string representation of the target
Expand Down Expand Up @@ -295,7 +293,12 @@ class Utility(object):
"""Utility functions that understand DMS Level 3 associations"""

@staticmethod
def filter_discoverd_only(associations):
def filter_discovered_only(
associations,
discover_ruleset,
candidate_ruleset,
keep_candidates=True,
):
"""Return only those associations that have multiple candidates
Parameters
Expand All @@ -304,6 +307,15 @@ def filter_discoverd_only(associations):
The list of associations to check. The list
is that returned by the `generate` function.
discover_ruleset: str
The name of the ruleset that has the discover rules
candidate_ruleset: str
The name of the ruleset that finds just candidates
keep_candidates: bool
Keep explicit candidate associations in the list.
Returns
-------
iterable
Expand All @@ -315,13 +327,36 @@ def filter_discoverd_only(associations):
been constructed. Associations that have been Association.dump
and then Association.load will not return proper results.
"""
result = [
asn
for asn in associations
if len(asn.member_observations) > 1 and
len(asn.member_candidates) != 1
]
return result
# Split the associations along discovered/not discovered lines
asn_by_ruleset = {
candidate_ruleset: [],
discover_ruleset: []
}
for asn in associations:
asn_by_ruleset[asn.registry.name].append(asn)
candidate_list = asn_by_ruleset[candidate_ruleset]
discover_list = asn_by_ruleset[discover_ruleset]

# Filter out the non-unique discovered associations.
for candidate in candidate_list:
if len(discover_list) == 0:
break
unique_list = []
for discover in discover_list:
if discover.data['asn_type'] == candidate.data['asn_type'] and \
discover.members == candidate.members:
# This association is not unique. Ignore
pass
else:
unique_list.append(discover)

# Reset the discovered list to the new unique list
# and try the next candidate.
discover_list = unique_list

if keep_candidates:
discover_list.extend(candidate_list)
return discover_list

@staticmethod
def rename_to_level2b(level1b_name):
Expand Down
19 changes: 14 additions & 5 deletions jwst/associations/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# Configure logging
logger = log_config(name=__package__)

# Ruleset names
DISCOVER_RULESET = 'discover'
CANDIDATE_RULESET = 'candidate'


class Main(object):
"""Generate Associations from an Association Pool
Expand Down Expand Up @@ -153,23 +157,28 @@ def __init__(self, args=None):
self.rules = AssociationRegistry(
parsed.rules,
include_default=not parsed.ignore_default,
global_constraints=global_constraints
global_constraints=global_constraints,
name=CANDIDATE_RULESET
)

if parsed.discover:
self.rules.update(
AssociationRegistry(
parsed.rules,
include_default=not parsed.ignore_default
include_default=not parsed.ignore_default,
name=DISCOVER_RULESET
)
)

logger.info('Generating associations.')
self.associations, self.orphaned = generate(self.pool, self.rules)

if parsed.discover and not parsed.all_candidates:
self.associations = self.rules.Utility.filter_discoverd_only(
self.associations
if parsed.discover:
self.associations = self.rules.Utility.filter_discovered_only(
self.associations,
DISCOVER_RULESET,
CANDIDATE_RULESET,
keep_candidates=parsed.all_candidates,
)

logger.info(self.__str__())
Expand Down
21 changes: 13 additions & 8 deletions jwst/associations/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
expandvars,
)
import sys
from uuid import uuid4

from . import libpath
from .exceptions import (
Expand Down Expand Up @@ -50,17 +49,21 @@ class AssociationRegistry(dict):
global_constraints: dict
Constraints to be added to each rule.
name: str
An identifying string, used to prefix rule names.
"""

def __init__(self,
definition_files=None,
include_default=True,
global_constraints=None):
global_constraints=None,
name=None):
super(AssociationRegistry, self).__init__()

# Generate a UUID for this instance. Used to modify rule
# names.
self.uuid = uuid4()
self.name = name

# Precache the set of rules
self._rule_set = set()
Expand Down Expand Up @@ -90,12 +93,14 @@ def __init__(self,
for class_name, class_object in get_classes(module):
logger.debug('class_name="{}"'.format(class_name))
if class_name.startswith(USER_ASN):
rule = type(class_name, (class_object,), {})
try:
rule_name = '_'.join([self.name, class_name])
except TypeError:
rule_name = class_name
rule = type(rule_name, (class_object,), {})
rule.GLOBAL_CONSTRAINTS = global_constraints
self.__setitem__(
'_'.join([class_name, str(self.uuid)]),
rule
)
rule.registry = self
self.__setitem__(rule_name, rule)
self._rule_set.add(rule)
if class_name == 'Utility':
Utility = type('Utility', (class_object, Utility), {})
Expand Down
5 changes: 5 additions & 0 deletions jwst/associations/tests/test_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def test_read_assoc_defs_fromdefault(self):
for rule in valid_rules:
yield helpers.check_in_list, rule, rule_names

# Check that every rule held by a default AssociationRegistry refers back
# to that same registry through its `registry` attribute.
def test_registry_backref(self):
rules = AssociationRegistry()
for name, rule in rules.items():
# Generator-style test: each yielded tuple is (check function, actual, expected).
yield helpers.check_equal, rule.registry, rules

def test_nodefs(self):
with pytest.raises(AssociationError):
rules = AssociationRegistry(include_default=False)
Expand Down
17 changes: 0 additions & 17 deletions jwst/associations/tests/test_level3_utilities.py

This file was deleted.

31 changes: 8 additions & 23 deletions jwst/associations/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def test_script(self, full_pool_rules):

generated = Main([pool_fname, '--dry-run'])
asns = generated.associations
assert len(asns) == 44
assert len(asns) == 37
found_rules = set(
asn['asn_rule']
for asn in asns
)
assert 'Asn_Image' in found_rules
assert 'Asn_WFSCMB' in found_rules
assert 'candidate_Asn_Image' in found_rules
assert 'candidate_Asn_WFSCMB' in found_rules

def test_asn_candidates(self, full_pool_rules):
pool, rules, pool_fname = full_pool_rules
Expand Down Expand Up @@ -63,25 +63,10 @@ def test_toomanyoptions(self, full_pool_rules):
'-i', 'o001',
])

def test_all_candidates(self, full_pool_rules):
def test_discovered(self, full_pool_rules):
pool, rules, pool_fname = full_pool_rules

generated = Main([pool_fname, '--dry-run', '--all-candidates'])
assert len(generated.associations) == 2

@pytest.mark.xfail()
def test_discover(self, full_pool_rules):
pool, rules, pool_fname = full_pool_rules

generated = Main([pool_fname, '--dry-run', '--discover'])
assert len(generated.associations) == 2


@pytest.mark.xfail()
def test_cross_candidate(self, full_pool_rules):
pool, rules, pool_fname = full_pool_rules

generated = Main([pool_fname, '--dry-run'])
assert len(generated.associations) == 44
generated = Main([pool_fname, '--dry-run', '--cross-candidate-only'])
assert len(generated.associations) == 5
full = Main([pool_fname, '--dry-run'])
candidates = Main([pool_fname, '--dry-run', '--all-candidates'])
discovered = Main([pool_fname, '--dry-run', '--discover'])
assert len(full.associations) == len(candidates.associations) + len(discovered.associations)

0 comments on commit 57d64b9

Please sign in to comment.