Skip to content

Commit

Permalink
Merge pull request #146 from ChihweiLHBird/unit_tests
Browse files Browse the repository at this point in the history
Unit tests
  • Loading branch information
rgerkin committed Jun 4, 2020
2 parents 68e827b + bd827ca commit 0bf6f2e
Show file tree
Hide file tree
Showing 25 changed files with 817 additions and 267 deletions.
13 changes: 4 additions & 9 deletions sciunit/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,12 @@
import sciunit
from pathlib import Path
from typing import Union
try:
import configparser
except ImportError:
import ConfigParser as configparser
import configparser

import codecs
try:
import matplotlib
matplotlib.use('Agg') #: Anticipate possible headless environments
except ImportError:
pass

import matplotlib
matplotlib.use('Agg') #: Anticipate possible headless environments

NB_VERSION = 4

Expand Down
39 changes: 19 additions & 20 deletions sciunit/base.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
"""The base class for many SciUnit objects."""



import os
import sys
import json
import pickle
import hashlib

PLATFORM = sys.platform
PYTHON_MAJOR_VERSION = sys.version_info.major
if PYTHON_MAJOR_VERSION < 3: # Python 2
raise Exception('Only Python 3 is supported')

import json, git, pickle, hashlib

import numpy as np
import pandas as pd
import git

from git.exc import GitCommandError, InvalidGitRepositoryError
from git.cmd import Git
from git.remote import Remote
from git.repo.base import Repo
from typing import Dict, List, Optional, Tuple, Union, Any

PYTHON_MAJOR_VERSION = sys.version_info.major
PLATFORM = sys.platform

if PYTHON_MAJOR_VERSION < 3: # Python 2
raise Exception('Only Python 3 is supported')

from io import StringIO
try:
import tkinter
Expand Down Expand Up @@ -174,13 +173,13 @@ def __getstate__(self) -> dict:
del state[key]
return state

def _state(self, state: dict=None, keys: dict=None,
def _state(self, state: dict=None, keys: list=None,
exclude: List[str]=None) -> dict:
"""Get the state of the instance.
Args:
state (dict, optional): [description]. Defaults to None.
keys (dict, optional): [description]. Defaults to None.
keys (list, optional): [description]. Defaults to None.
exclude (List[str], optional): [description]. Defaults to None.
Returns:
Expand Down Expand Up @@ -274,7 +273,7 @@ def hash(self) -> str:
return self.dict_hash(self.state)

def json(self, add_props: bool=False, keys: list=None, exclude: list=None, string: bool=True,
indent: None=None) -> str:
indent: None=None) -> Any:
"""[summary]
Args:
Expand All @@ -295,7 +294,7 @@ def json(self, add_props: bool=False, keys: list=None, exclude: list=None, strin
return result

@property
def _id(self) -> str:
def _id(self) -> Any:
return id(self)

@property
Expand Down Expand Up @@ -330,17 +329,17 @@ def __init__(self, *args, **kwargs):
kwargs.pop(key)
super(SciUnitEncoder, self).__init__(*args, **kwargs)

def default(self, obj: Any) -> dict:
"""[summary]
def default(self, obj: Any) -> Union[str, dict, list]:
"""Try to encode the object.
Args:
obj (Any): [description]
obj (Any): Any object to be encoded
Raises:
e: Could not JSON encode the object.
e: Could not JSON serialize the object.
Returns:
dict: [description]
Union[str, dict, list]: Encoded object.
"""
try:
if isinstance(obj, pd.DataFrame):
Expand Down
4 changes: 4 additions & 0 deletions sciunit/models/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def get_memory_cache(self, key: str=None) -> dict:
dict: The memory cache for key 'key' or None if not found.
"""
key = self.model.hash if key is None else key
if not getattr(self, 'memory_cache', False):
self.init_memory_cache()
self._results = self.memory_cache.get(key)
return self._results

Expand Down Expand Up @@ -113,6 +115,8 @@ def set_memory_cache(self, results: Any, key: str=None) -> None:
key (str, optional): [description]. Defaults to None.
"""
key = self.model.hash if key is None else key
if not getattr(self, 'memory_cache', False):
self.init_memory_cache()
self.memory_cache[key] = results

def set_disk_cache(self, results: Any, key: str=None) -> None:
Expand Down
2 changes: 0 additions & 2 deletions sciunit/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ def __init__(self, name=None, **params):
name = self.__class__.__name__
self.name = name
self.params = params
if params is None:
params = {}
super(Model, self).__init__()
self.check_params()

Expand Down
6 changes: 3 additions & 3 deletions sciunit/scores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from sciunit.utils import log, config_get
from sciunit.errors import InvalidScoreError
from typing import Union, Tuple

from quantities import Quantity
class Score(SciUnit):
"""Abstract base class for scores."""

def __init__(self, score: 'Score', related_data: dict=None):
def __init__(self, score: Union['Score', float, int, Quantity], related_data: dict=None):
"""Abstract base class for scores.
Args:
score (int, float, bool): A raw value to wrap in a Score class.
score (Union['Score', float, int, Quantity], bool): A raw value to wrap in a Score class.
related_data (dict, optional): Artifacts to store with the score.
"""
self.check_score(score)
Expand Down
2 changes: 1 addition & 1 deletion sciunit/scores/complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def compute(cls, observation: dict, prediction: dict) -> 'ZScore':
error = ("Observation must have keys 'mean' and 'std' "
"when using ZScore")
return InsufficientDataScore(error)
if not o_std > 0:
if o_std <= 0:
error = 'Observation standard deviation must be > 0'
return InsufficientDataScore(error)
value = (p_value - o_mean)/o_std
Expand Down
4 changes: 2 additions & 2 deletions sciunit/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TestSuite(SciUnit, TestWeighted):

def __init__(self, tests: List[Test], name: str=None, weights=None, include_models: List[Model]=None,
skip_models: List[Model]=None, hooks: dict=None,
optimizer: None=None):
optimizer=None):
"""optimizer: a function to bind to self.optimize (first argument must be a testsuite).
Args:
Expand All @@ -30,7 +30,7 @@ def __init__(self, tests: List[Test], name: str=None, weights=None, include_mode
include_models (List[Model], optional): The list of models. Defaults to None.
skip_models (List[Model], optional): [description]. Defaults to None.
hooks (dict, optional): [description]. Defaults to None.
optimizer (None, optional): [description]. Defaults to None.
optimizer (optional): [description]. Defaults to None.
"""

self.name = name if name else "Suite_%d" % random.randint(0, 1e12)
Expand Down
2 changes: 1 addition & 1 deletion sciunit/unit_test/active.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from .backend_tests import *
from .base_tests import *
from .command_line_tests import *
from .config_tests import *
from .converter_tests import *
Expand All @@ -17,4 +18,3 @@
from .test_tests import *
from .utils_tests import *
from .validator_tests import *
from .base_tests import *
102 changes: 99 additions & 3 deletions sciunit/unit_test/backend_tests.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,110 @@
"""Unit tests for backends."""

from sciunit.models.backends import Backend
from sciunit.utils import NotebookTools
import unittest
from sciunit import Model
import unittest, pathlib


class BackendsTestCase(unittest.TestCase, NotebookTools):
"""Unit tests for the sciunit module"""

path = '.'
path = "."

def test_backends(self):
"""Test backends."""
self.do_notebook('backend_tests')
self.do_notebook("backend_tests")

def test_backends_init_caches(self):
myModel = Model()
backend = Backend()
backend.model = myModel

backend.init_backend(use_disk_cache=True, use_memory_cache=True)
backend.init_backend(use_disk_cache=False, use_memory_cache=True)
backend.init_backend(use_disk_cache=True, use_memory_cache=False)
backend.init_backend(use_disk_cache=False, use_memory_cache=False)
backend.init_cache()

def test_backends_set_caches(self):
myModel = Model()
backend = Backend()
backend.model = myModel
# backend.init_memory_cache()
self.assertIsNone(backend.get_disk_cache("key1"))
self.assertIsNone(backend.get_disk_cache("key2"))
self.assertIsNone(backend.get_memory_cache("key1"))
self.assertIsNone(backend.get_memory_cache("key2"))
backend.set_disk_cache("value1", "key1")
backend.set_memory_cache("value1", "key1")
self.assertEqual(backend.get_memory_cache("key1"), "value1")
self.assertEqual(backend.get_disk_cache("key1"), "value1")
backend.set_disk_cache("value2")
backend.set_memory_cache("value2")
self.assertEqual(backend.get_memory_cache(myModel.hash), "value2")
self.assertEqual(backend.get_disk_cache(myModel.hash), "value2")

backend.load_model()
backend.set_attrs(test_attribute="test attribute")
backend.set_run_params(test_param="test parameter")
backend.init_backend(use_disk_cache=True, use_memory_cache=True)

def test_backend_run(self):
backend = Backend()
self.assertRaises(NotImplementedError, backend._backend_run)

class MyBackend(Backend):
model = Model()

def _backend_run(self) -> str:
return "test result"

backend = MyBackend()
backend.init_backend(use_disk_cache=True, use_memory_cache=True)
backend.backend_run()
backend.set_disk_cache("value1", "key1")
backend.set_memory_cache("value1", "key1")
backend.backend_run()
backend.set_disk_cache("value2")
backend.set_memory_cache("value2")
backend.backend_run()
# backend.save_results(pathlib.Path().absolute())

backend = MyBackend()
backend.init_backend(use_disk_cache=False, use_memory_cache=True)
backend.backend_run()
backend.set_disk_cache("value1", "key1")
backend.set_memory_cache("value1", "key1")
backend.backend_run()
backend.set_disk_cache("value2")
backend.set_memory_cache("value2")
backend.backend_run()
# backend.save_results(pathlib.Path().absolute())

backend = MyBackend()
backend.init_backend(use_disk_cache=True, use_memory_cache=False)
# backend.init_memory_cache()
backend.backend_run()
backend.set_disk_cache("value1", "key1")
backend.set_memory_cache("value1", "key1")
backend.backend_run()
backend.set_disk_cache("value2")
backend.set_memory_cache("value2")
backend.backend_run()
# backend.save_results(pathlib.Path().absolute())

backend = MyBackend()
backend.init_backend(use_disk_cache=False, use_memory_cache=False)
# backend.init_memory_cache()
backend.backend_run()
backend.set_disk_cache("value1", "key1")
backend.set_memory_cache("value1", "key1")
backend.backend_run()
backend.set_disk_cache("value2")
backend.set_memory_cache("value2")
backend.backend_run()
# backend.save_results(pathlib.Path().absolute())


if __name__ == "__main__":
unittest.main()
11 changes: 7 additions & 4 deletions sciunit/unit_test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@
import sys
import unittest

OSX = sys.platform == 'darwin'
if OSX or 'Qt' in mpl.rcParams['backend']:
mpl.use('Agg') # Avoid any problems with Macs or headless displays.
OSX = sys.platform == "darwin"
if OSX or "Qt" in mpl.rcParams["backend"]:
mpl.use("Agg") # Avoid any problems with Macs or headless displays.


class SuiteBase(object):
"""Abstract base class for testing suites and scores"""

def setUp(self):
from sciunit.models.examples import UniformModel
from sciunit.tests import RangeTest

self.M = UniformModel
self.T = RangeTest

def prep_models_and_tests(self):
from sciunit import TestSuite
t1 = self.T([2, 3], name='test1')

t1 = self.T([2, 3], name="test1")
t2 = self.T([5, 6])
m1 = self.M(2, 3)
m2 = self.M(5, 6)
Expand Down

0 comments on commit 0bf6f2e

Please sign in to comment.