Skip to content

Commit

Permalink
Added option to require extra_capability_checks to exist for every ca…
Browse files Browse the repository at this point in the history
…pability
  • Loading branch information
rgerkin committed Aug 13, 2018
1 parent 57118c1 commit cca66d0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 13 deletions.
14 changes: 10 additions & 4 deletions sciunit/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,25 @@ class Capability(SciUnit):
"""Abstract base class for sciunit capabilities."""

@classmethod
def check(cls, model):
"""Checks whether the provided model has this capability.
By default, uses isinstance.
def check(cls, model, require_extra=False):
"""Check whether the provided model has this capability.
By default, uses isinstance. If `require_extra`, also requires that an
instance check be present in `model.extra_capability_checks`.
"""
class_capable = isinstance(model, cls)
f_name = model.extra_capability_checks.get(cls, None)
if f_name:
f = getattr(model, f_name)
instance_capable = f()
else:
elif not require_extra:
instance_capable = True
else:
instance_capable = False
return class_capable and instance_capable

def unimplemented(self):
"""Raise a `NotImplementedError` with details."""
raise NotImplementedError(("The method %s promised by capability %s "
"is not implemented") %
(inspect.stack()[1][3], self.name))
Expand All @@ -43,4 +48,5 @@ class ProducesNumber(Capability):
"""An example capability for producing some generic number."""

def produce_number(self):
"""Produce a number."""
raise NotImplementedError("Must implement produce_number.")
12 changes: 8 additions & 4 deletions sciunit/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def assert_models(self, models):
"a model or iterable."))
return models

def check(self, models, skip_incapable=True, stop_on_error=True):
def check(self, models, skip_incapable=True, require_extra=False,
stop_on_error=True):
"""Like judge, but without actually running the test.
Just returns a ScoreMatrix indicating whether each model can take
Expand All @@ -87,18 +88,21 @@ def check(self, models, skip_incapable=True, stop_on_error=True):
sm = ScoreMatrix(self.tests, models)
for test in self.tests:
for model in models:
sm.loc[model, test] = test.check(model)
sm.loc[model, test] = test.check(model,
require_extra=require_extra)
return sm

def check_capabilities(self, model, skip_incapable=False):
def check_capabilities(self, model, skip_incapable=False,
require_extra=False):
"""Check model capabilities against those required by the suite.
Returns a list of booleans (one for each test in the suite)
corresponding to whether the test's required capabilities are satisfied
by the model.
"""
return [test.check_capabilities(model,
skip_incapable=skip_incapable) for test in self.tests]
skip_incapable=skip_incapable, require_extra=require_extra)
for test in self.tests]

def judge(self, models,
skip_incapable=False, stop_on_error=True, deep_error=False):
Expand Down
12 changes: 7 additions & 5 deletions sciunit/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def validate_observation(self, observation):
else:
schema = {'schema': self.observation_schema, 'type': 'dict'}
schema = {'observation': schema}
from pprint import pprint
v = ObservationValidator(schema, test=self)
if not v.validate({'observation': observation}):
raise ObservationError(v.errors)
Expand All @@ -88,24 +87,27 @@ def validate_observation(self, observation):
"""A sequence of capabilities that a model must have in order for the
test to be run. Defaults to empty."""

def check_capabilities(self, model, skip_incapable=False):
def check_capabilities(self, model, skip_incapable=False,
require_extra=False):
"""Check that test's required capabilities are implemented by `model`.
Raises an Error if model is not a Model.
Raises a CapabilityError if model does not have a capability.
"""
if not isinstance(model, Model):
raise Error("Model %s is not a sciunit.Model." % str(model))
capable = all([self.check_capability(model, c, skip_incapable)
capable = all([self.check_capability(model, c, skip_incapable,
require_extra)
for c in self.required_capabilities])
return capable

def check_capability(self, model, c, skip_incapable=False):
def check_capability(self, model, c, skip_incapable=False,
require_extra=False):
"""Check if `model` has capability `c`.
Optionally (default:True) raise a `CapabilityError` if it does not.
"""
capable = c.check(model)
capable = c.check(model, require_extra=require_extra)
if not capable and not skip_incapable:
raise CapabilityError(model, c)
return capable
Expand Down

0 comments on commit cca66d0

Please sign in to comment.