Enable caching of model specific test predictions (#200)
* add id, set_cache(), get_cache() to Test class; add use_cache decorator to store test results

* combine disk and memory cache handling in set_cache and get_cache helper functions in backend

* improve use_backend_cache decorator

* handling exception for wrong cache_key_param

* rm trailing comma in imports

* fix syntax error

* fix syntax error

* fix syntax error

* fix type error

* avoid printing model in decorator

* include overarching unittest of cache decorator

Co-authored-by: morales-gregorio <aitormorales95@gmail.com>
rgutzen and morales-gregorio committed Jan 12, 2022
1 parent ffa380b commit fcc04b6
Showing 4 changed files with 233 additions and 3 deletions.
37 changes: 37 additions & 0 deletions sciunit/models/backends.py
@@ -119,6 +119,24 @@ def get_disk_cache(self, key: str = None) -> Any:
        disk_cache.close()
        return self._results

    def get_cache(self, key: str = None) -> Any:
        """Return the cached result for key `key`, or None if not found.
        If both `use_memory_cache` and `use_disk_cache` are True, the memory
        cache is checked first and the disk cache is only used on a miss.

        Returns:
            Any: The cached result for key `key`, or None if not found.
        """
        if self.use_memory_cache:
            result = self.get_memory_cache(key=key)
            if result is not None:
                return result
        if self.use_disk_cache:
            result = self.get_disk_cache(key=key)
            if result is not None:
                return result
        return None

    def set_memory_cache(self, results: Any, key: str = None) -> None:
        """Store result in memory cache with key matching model state.
@@ -145,6 +163,25 @@ def set_disk_cache(self, results: Any, key: str = None) -> None:
        disk_cache[key] = results
        disk_cache.close()

    def set_cache(self, results: Any, key: str = None) -> bool:
        """Store result in the disk and/or memory cache under key `key`,
        depending on whether `use_disk_cache` and `use_memory_cache` are True.

        Args:
            results (Any): The results to cache.
            key (str, optional): The cache key. Defaults to None, in which
                case the underlying cache methods derive a key from the
                model state.

        Returns:
            bool: True if at least one cache was set, else False.
        """
        if self.use_memory_cache:
            self.set_memory_cache(results, key=key)
        if self.use_disk_cache:
            self.set_disk_cache(results, key=key)
        if self.use_memory_cache or self.use_disk_cache:
            return True
        return False

    def load_model(self) -> None:
        """Load the model into memory."""

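Reviewer note: a minimal sketch of the new combined cache API, modeled on the dummy classes in the unit test below. `ToyBackend`, the model name, and the key are made-up for illustration and are not part of this diff; it assumes the backend's default of memory caching enabled and disk caching disabled.

```python
from sciunit.models import RunnableModel
from sciunit.models.backends import Backend

class ToyBackend(Backend):
    pass  # inherits the memory/disk cache handling unchanged

# With only the memory cache enabled (the default), set_cache() and
# get_cache() read and write the in-memory store.
model = RunnableModel(name='toyA', backend=ToyBackend)
assert model._backend.set_cache({'score': 0.9}, key='abc123')  # -> True
assert model._backend.get_cache(key='abc123') == {'score': 0.9}
assert model._backend.get_cache(key='missing') is None  # miss -> None
```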
45 changes: 44 additions & 1 deletion sciunit/tests.py
@@ -2,6 +2,7 @@

import inspect
import traceback
from uuid import uuid4
from copy import deepcopy
from typing import Any, List, Optional, Tuple, Union

@@ -18,7 +19,7 @@
)
from .models import Model
from .scores import BooleanScore, ErrorScore, NAScore, NoneScore, Score, TBDScore
from .utils import dict_combine
from .utils import dict_combine, use_backend_cache
from .validators import ObservationValidator, ParametersValidator


@@ -43,6 +44,8 @@ def __init__(
        if self.description is None:
            self.description = self.__class__.__doc__

        self.id = uuid4().hex

        # Use a combination of default_params and params, choosing the latter
        # if there is a conflict.
        self.params = dict_combine(self.default_params, params)
@@ -257,6 +260,7 @@ def condition_model(self, model: Model):
            model (Model): A sciunit model instance.
        """

    @use_backend_cache
    def generate_prediction(self, model: Model) -> None:
        """Generate a prediction from a model using the required capabilities.
@@ -586,6 +590,45 @@ def describe(self) -> str:
        result = "\n".join(s)
        return result

    def get_backend_cache(self, model: Model, key: Optional[str] = None) -> Any:
        """Get the cached results from the model's backend under the given
        key (defaults to the id of the test instance).

        Returns:
            Any: The cached result for key `key`, or None if not found.
        """
        if model is None:
            return None
        if key is None:
            if hasattr(self, 'id'):
                key = self.id
            else:
                return None

        if hasattr(model, 'backend') and model.backend is not None:
            return model._backend.get_cache(key=key)
        return None

    def set_backend_cache(self, model: Model, function_output: Any,
                          key: Optional[str] = None) -> bool:
        """Set the cache of the model's backend under the given key (defaults
        to the id of the test instance) to the calculated function output.

        Returns:
            bool: True if the cache was successfully set, else False.
        """
        if model is None:
            return False
        if key is None:
            if hasattr(self, 'id'):
                key = self.id
            else:
                return False

        if hasattr(model, 'backend') and model.backend is not None:
            return model._backend.set_cache(function_output, key=key)
        return False

    @property
    def state(self) -> dict:
        """Get the frozen (pickled) model state.
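Reviewer note: a short sketch of the new Test-level helpers, under the same assumptions as the backend sketch above. `ToyTest`, `ToyBackend`, and the `[]` observation placeholder are illustrative only.

```python
import sciunit
from sciunit.models import RunnableModel
from sciunit.models.backends import Backend

class ToyBackend(Backend):
    pass

class ToyTest(sciunit.Test):
    def generate_prediction(self, model):
        return 42

test = ToyTest([])  # '[]' is a placeholder observation
model = RunnableModel(name='toyB', backend=ToyBackend)

# Each Test instance now carries a unique id, used as the default cache key.
assert test.set_backend_cache(model, function_output=42)  # keyed by test.id
assert test.get_backend_cache(model) == 42

# A missing model is handled gracefully rather than raising.
assert test.get_backend_cache(model=None) is None
assert test.set_backend_cache(None, function_output=42) is False
```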
94 changes: 94 additions & 0 deletions sciunit/unit_test/utils_tests.py
@@ -3,6 +3,100 @@
import tempfile
import unittest

import sciunit
from sciunit.utils import use_backend_cache
import numpy as np


class CacheTestCase(unittest.TestCase):

    def test_basic_cache(self):
        class dummy_test(sciunit.Test):

            # The name of the cache key param determines the cache key location
            @use_backend_cache(cache_key_param='dummy_cache_key')
            def create_random_matrix(self, model):
                # Generate a random matrix that will land in the cache
                return np.random.randint(0, 100, size=(5, 5))

        class dummy_avg_test(dummy_test):

            default_params = {'dummy_cache_key': '1234'}

            @use_backend_cache
            def generate_prediction(self, model):
                return np.mean(self.create_random_matrix(model))

        class dummy_std_test(dummy_test):

            default_params = {'dummy_cache_key': '1234'}

            @use_backend_cache
            def generate_prediction(self, model):
                return np.std(self.create_random_matrix(model))

            @use_backend_cache
            def warning_function(self):
                return 'I am supposed to fail, because I have no model'

        class dummy_backend(sciunit.models.backends.Backend):
            pass

        class dummy_model(sciunit.models.RunnableModel):
            pass

        # Initialize dummy tests and models
        avg_corr_test1 = dummy_avg_test([], dummy_cache_key='1234')
        avg_corr_test2 = dummy_avg_test([], dummy_cache_key='5678')
        std_corr_test = dummy_std_test([])
        modelA = dummy_model(name='modelA', backend=dummy_backend)
        modelB = dummy_model(name='modelB', backend=dummy_backend)

        # Run predictions for the first time
        avg_corrsA1 = avg_corr_test1.generate_prediction(model=modelA)
        avg_corrsA2 = avg_corr_test2.generate_prediction(modelA)
        cached_predictionA1_avg = avg_corr_test1.get_backend_cache(model=modelA)
        cached_predictionA2_avg = avg_corr_test2.get_backend_cache(model=modelA)
        dummy_matrixA1 = avg_corr_test1.get_backend_cache(model=modelA,
                                                          key='1234')
        dummy_matrixA2 = avg_corr_test2.get_backend_cache(model=modelA,
                                                          key='5678')
        # The dummy matrix is already generated and cached
        # specifically for modelA under key '1234'
        std_corrsA = std_corr_test.generate_prediction(modelA)
        cached_predictionA_std = std_corr_test.get_backend_cache(model=modelA)
        dummy_matrixA_std = std_corr_test.get_backend_cache(model=modelA,
                                                            key='1234')

        # Check that cached predictions equal the original computations
        self.assertTrue(avg_corrsA1 == cached_predictionA1_avg)
        self.assertTrue(std_corrsA == cached_predictionA_std)

        # Check that different tests yield different predictions
        # These are floats, unlikely to ever be the same by chance
        self.assertTrue(cached_predictionA1_avg != cached_predictionA2_avg)
        self.assertTrue(cached_predictionA1_avg != cached_predictionA_std)

        # Check that matrices cached under different keys differ,
        # while the same key returns the same matrix
        self.assertTrue(np.any(dummy_matrixA1 != dummy_matrixA2))
        self.assertTrue(np.all(dummy_matrixA1 == dummy_matrixA_std))

        # Check that a different model has a different cache
        avg_corrsB = avg_corr_test1.generate_prediction(modelB)
        cached_predictionB1_avg = avg_corr_test1.get_backend_cache(model=modelB)
        dummy_matrixB1 = avg_corr_test1.get_backend_cache(model=modelB,
                                                          key='1234')
        self.assertTrue(cached_predictionA1_avg != cached_predictionB1_avg)
        self.assertTrue(np.any(dummy_matrixA1 != dummy_matrixB1))

        # Test the failing cases of the decorator
        with self.assertWarns(Warning):
            std_corr_test.warning_function()

        with self.assertWarns(Warning):
            test = dummy_test([])
            test.create_random_matrix(modelA)

class UtilsTestCase(unittest.TestCase):
"""Unit tests for sciunit.utils"""
60 changes: 58 additions & 2 deletions sciunit/utils.py
@@ -1026,11 +1026,67 @@ def decorated(*args, **kwargs):

    return decorated


class_intern = intern.intern

method_memoize = memoize

def use_backend_cache(original_function=None, cache_key_param=None):
    """
    Decorator for test functions (in particular `generate_prediction`) that
    caches the function output on the first execution and returns the output
    from the cache, without recomputing, on any subsequent execution.

    The decorated function needs to take a model as an argument, and the
    caching relies on the model's backend. If the model doesn't have a
    backend, the caching step is skipped.

    By default, a test-instance-specific hash is used to link the model to
    the test's function output. Optionally, a custom hash key name can be
    passed to the decorator to use the hash stored in
    `self.params[<hash key name>]` instead (e.g. for sharing a cache for
    redundant calculations on the same model across tests).
    """

    def _decorate(function):

        @functools.wraps(function)
        def wrapper(self, *args, **kwargs):
            sig = inspect.signature(function)
            if 'model' in kwargs:
                model = kwargs['model']
            elif 'model' in sig.parameters.keys():
                # 'model' was passed positionally; offset the index by 1
                # because `self` is not part of `args`.
                model = args[list(sig.parameters.keys()).index('model') - 1]
            else:
                model = None
                warnings.warn("The decorator `use_backend_cache` can only "
                              "be used for test class functions that take "
                              "'model' as an argument! Caching is skipped.")

            cache_key = None
            if cache_key_param:
                if cache_key_param in self.params:
                    cache_key = self.params[cache_key_param]
                else:
                    model = None
                    warnings.warn("The value of the decorator argument "
                                  "`cache_key_param` can not be found in "
                                  "self.params! Caching is skipped.")

            function_output = self.get_backend_cache(model=model,
                                                     key=cache_key)

            if function_output is None:
                function_output = function(self, *args, **kwargs)
                self.set_backend_cache(model=model,
                                       function_output=function_output,
                                       key=cache_key)

            return function_output

        return wrapper

    if original_function:
        return _decorate(original_function)
    return _decorate

def style():
    """Style a notebook with the current sciunit CSS file"""
@@ -1056,7 +1112,7 @@ def style():
    display(
        HTML(
            """
            <style>
            %s
            </style>
            """
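Reviewer note: a compact usage sketch of the new decorator. `SharedMatrixTest`, `ToyBackend`, and the key names are illustrative; the shared-key pattern mirrors the dummy tests in the new unit test above.

```python
import numpy as np
import sciunit
from sciunit.models import RunnableModel
from sciunit.models.backends import Backend
from sciunit.utils import use_backend_cache

class ToyBackend(Backend):
    pass

class SharedMatrixTest(sciunit.Test):
    default_params = {'matrix_key': 'shared-matrix'}

    @use_backend_cache(cache_key_param='matrix_key')
    def compute_matrix(self, model):
        # Computed once per model; any test sharing 'matrix_key'
        # gets the cached matrix instead of recomputing it.
        return np.random.rand(5, 5)

    @use_backend_cache  # cached per test instance, keyed by self.id
    def generate_prediction(self, model):
        return float(np.mean(self.compute_matrix(model)))

test = SharedMatrixTest([])
model = RunnableModel(name='toyC', backend=ToyBackend)
first = test.generate_prediction(model)
assert test.generate_prediction(model) == first  # served from the cache
```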
