Skip to content

Commit

Permalink
Updated for new API
Browse files Browse the repository at this point in the history
  • Loading branch information
rgerkin committed Jun 21, 2021
1 parent 1396f6c commit f70f2ba
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 22 deletions.
3 changes: 2 additions & 1 deletion sciunit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def __init__(self):
#: A class attribute containing a list of other attributes to be hidden
# from state calculations
state_hide = ['hash', 'pickling', 'temp_dir']


def __getstate__(self) -> dict:
"""Copy the object's state from self.__dict__.
Expand Down Expand Up @@ -354,7 +355,7 @@ def property_names(self) -> list:
# return self._state()

def json(
self, add_props: bool = None, string: bool = None, unpicklable: bool = None, make_refs: bool = None) -> str:
self, add_props: bool = True, string: bool = True, unpicklable: bool = False, make_refs: bool = False) -> str:
"""Generate a Json format encoded sciunit instance.
Args:
Expand Down
1 change: 0 additions & 1 deletion sciunit/models/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def init_backend(self, *args, **kwargs) -> None:
if self.use_disk_cache:
self.init_disk_cache()
self.load_model()
self.model.unpicklable += ["_backend"]

#: Name of the backend
name = None
Expand Down
11 changes: 6 additions & 5 deletions sciunit/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def capabilities(self) -> list:
def failed_extra_capabilities(self) -> list:
"""Check to see if instance passes its `extra_capability_checks`."""
failed = []
for capability, f_name in self.extra_capability_checks.items():
f = getattr(self, f_name)
instance_capable = f()
if isinstance(self, capability) and not instance_capable:
failed.append(capability)
if self.extra_capability_checks is not None:
for capability, f_name in self.extra_capability_checks.items():
f = getattr(self, f_name)
instance_capable = f()
if isinstance(self, capability) and not instance_capable:
failed.append(capability)
return failed

def describe(self) -> str:
Expand Down
1 change: 0 additions & 1 deletion sciunit/models/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(
self.run_params = {} # Should be reset between tests
self.print_run_params = False # Print the run parameters with each run
self.default_run_params = {} # Should be applied to all tests
self.unpicklable = [] # Model attributes which cannot be pickled
if attrs and not isinstance(attrs, dict):
raise TypeError("Model 'attrs' must be a dictionary.")
self.attrs = attrs if attrs else {}
Expand Down
38 changes: 28 additions & 10 deletions sciunit/scores/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@ class ScoreArray(pd.Series, SciUnit, TestWeighted):
(score_1, ..., score_n)
"""

def __init__(self, tests_or_models, scores=None, weights=None):
def __init__(self, tests_or_models, scores=None, weights=None, name=None):
if scores is None:
scores = [NoneScore for tom in tests_or_models]
tests_or_models = self.check_tests_and_models(tests_or_models)
self.weights_ = [] if not weights else list(weights)
super(ScoreArray, self).__init__(data=scores, index=tests_or_models)
name = (name or self.__class__.__name__)
self._name = name # Necessary for some reason even though
# it is also passed to pd.Series constructor
super(ScoreArray, self).__init__(data=scores, index=tests_or_models,
name=name)
self.index_type = "tests" if isinstance(tests_or_models[0], Test) else "models"
setattr(self, self.index_type, tests_or_models)

Expand Down Expand Up @@ -81,13 +85,16 @@ def get_by_name(self, name: str) -> Union[Model, Test]:
if item is None:
raise KeyError("No model or test with name '%s'" % name)
return item

def __setattr__(self, attr, value):
self.__dict__[attr] = value

def __getattr__(self, name):
if name in self.direct_attrs:
attr = self.apply(lambda x: getattr(x, name))
else:
attr = super(ScoreArray, self).__getattribute__(name)
return attr
#def __getattr__(self, name):
# if name in self.direct_attrs:
# attr = self.apply(lambda x: getattr(x, name))
# else:
# attr = super(ScoreArray, self).__getattribute__(name)
# return attr

@property
def related_data(self) -> pd.Series:
Expand All @@ -101,6 +108,8 @@ def scores_flat(self) -> list:
def scores(self) -> pd.Series:
return self.map(lambda x: x.score)

score = scores # Backwards compatibility

@property
def norm_scores(self) -> pd.Series:
"""Return the `norm_score` for each test.
Expand Down Expand Up @@ -234,6 +243,7 @@ def get_test(self, test: Test) -> ScoreArray:
self.models,
scores=super(ScoreMatrix, self).__getitem__(test),
weights=self.weights,
name=test.name,
)

def get_model(self, model: Model) -> ScoreArray:
Expand All @@ -245,7 +255,10 @@ def get_model(self, model: Model) -> ScoreArray:
Returns:
ScoreArray: The generated ScoreArray instance.
"""
return ScoreArray(self.tests, scores=self.loc[model, :], weights=self.weights)
return ScoreArray(self.tests,
scores=self.loc[model, :],
weights=self.weights,
name=model.name)

def get_group(self, x: tuple) -> Union[Model, Test, Score]:
"""[summary]
Expand Down Expand Up @@ -289,12 +302,15 @@ def get_by_name(self, name: str) -> Union[Model, Test]:
if test.name == name:
return self.__getitem__(test)
raise KeyError("No model or test with name '%s'" % name)

def __setattr__(self, attr, value):
self.__dict__[attr] = value

#def __getattr__(self, name):
# if name in self.direct_attrs:
# attr = self.applymap(lambda x: getattr(x, name))
# else:
# attr = super(ScoreMatrix, self).__getattribute__(name)
# attr = super(ScoreMatrix, self).__getattribute__(name)
# return attr

@property
Expand All @@ -308,6 +324,8 @@ def scores_flat(self) -> list:
@property
def scores(self) -> pd.DataFrame:
return self.applymap(lambda x: x.score)

score = scores # Backwards compatibility

@property
def norm_scores(self) -> pd.DataFrame:
Expand Down
4 changes: 1 addition & 3 deletions sciunit/unit_test/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@ def test_SciUnit(self):
from sciunit.base import SciUnit

sciunitObj = SciUnit()
self.assertIsInstance(sciunitObj.properties, dict)
self.assertIsInstance(sciunitObj.properties(), dict)
self.assertIsInstance(sciunitObj.__getstate__(), dict)
self.assertIsInstance(sciunitObj.json(), str)
self.assertIsInstance(sciunitObj._id, int)
self.assertIsInstance(sciunitObj.id, str)
sciunitObj.json(string=False)
self.assertIsInstance(sciunitObj._class, dict)
sciunitObj.testState = "testState"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ install_requires =
nbconvert
nbformat
pandas>=0.18
quantities>=0.12.1
quantities>=0.12.4.1 @ https://github.com/scidash/python-quantities/archive/master.tar.gz


[options.entry_points]
Expand Down

0 comments on commit f70f2ba

Please sign in to comment.