diff --git a/tests/test_defaults.py b/tests/test_defaults.py index 96a14f8b..dc87a4ae 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -2,7 +2,6 @@ import os import queue import random -import tempfile import threading import time @@ -81,46 +80,42 @@ def global_test_2(): assert global_test_2.cache_dpath() is not None -def test_cache_dir_default_param(): - cachier.set_default_params(cache_dir="/path_1") +def test_cache_dir_default_param(tmpdir): + cachier.set_default_params(cache_dir=tmpdir / "1") @cachier.cachier() def global_test_1(): return None - @cachier.cachier(cache_dir="/path_2") + @cachier.cachier(cache_dir=tmpdir / "2") def global_test_2(): return None - assert global_test_1.cache_dpath() == "/path_1" - assert global_test_2.cache_dpath() == "/path_2" + assert global_test_1.cache_dpath() == str(tmpdir / "1") + assert global_test_2.cache_dpath() == str(tmpdir / "2") -def test_separate_files_default_param(): +def test_separate_files_default_param(tmpdir): cachier.set_default_params(separate_files=True) - @cachier.cachier(cache_dir=tempfile.mkdtemp()) + @cachier.cachier(cache_dir=tmpdir / "1") def global_test_1(arg_1, arg_2): return arg_1 + arg_2 - @cachier.cachier(cache_dir=tempfile.mkdtemp(), separate_files=False) + @cachier.cachier(cache_dir=tmpdir / "2", separate_files=False) def global_test_2(arg_1, arg_2): return arg_1 + arg_2 - global_test_1.clear_cache() global_test_1(1, 2) global_test_1(3, 4) - global_test_2.clear_cache() global_test_2(1, 2) global_test_2(3, 4) - cache_dir_1 = global_test_1.cache_dpath() - cache_dir_2 = global_test_2.cache_dpath() - assert len(os.listdir(cache_dir_1)) == 2 - assert len(os.listdir(cache_dir_2)) == 1 + assert len(os.listdir(global_test_1.cache_dpath())) == 2 + assert len(os.listdir(global_test_2.cache_dpath())) == 1 -def test_allow_none_default_param(): +def test_allow_none_default_param(tmpdir): cachier.set_default_params( allow_none=True, separate_files=True, @@ -128,13 +123,13 @@ def test_allow_none_default_param(): ) allow_count = disallow_count = 0 - @cachier.cachier(cache_dir=tempfile.mkdtemp()) + @cachier.cachier(cache_dir=tmpdir) def allow_none(): nonlocal allow_count allow_count += 1 return None - @cachier.cachier(cache_dir=tempfile.mkdtemp(), allow_none=False) + @cachier.cachier(cache_dir=tmpdir, allow_none=False) def disallow_none(): nonlocal disallow_count disallow_count += 1 @@ -155,14 +150,16 @@ def disallow_none(): assert disallow_count == 2 -parametrize_keys = "backend,mongetter" -parametrize_values = [ - pytest.param("pickle", None, marks=pytest.mark.pickle), - pytest.param("mongo", _test_mongetter, marks=pytest.mark.mongo), -] +PARAMETRIZE_TEST = ( + "backend,mongetter", + [ + pytest.param("pickle", None, marks=pytest.mark.pickle), + pytest.param("mongo", _test_mongetter, marks=pytest.mark.mongo), + ], +) -@pytest.mark.parametrize(parametrize_keys, parametrize_values) +@pytest.mark.parametrize(*PARAMETRIZE_TEST) def test_stale_after_applies_dynamically(backend, mongetter): @cachier.cachier(backend=backend, mongetter=mongetter) def _stale_after_test(arg_1, arg_2): @@ -180,7 +177,7 @@ def _stale_after_test(arg_1, arg_2): assert val3 != val1 -@pytest.mark.parametrize(parametrize_keys, parametrize_values) +@pytest.mark.parametrize(*PARAMETRIZE_TEST) def test_next_time_applies_dynamically(backend, mongetter): NEXT_AFTER_DELTA = datetime.timedelta(seconds=3) @@ -206,8 +203,10 @@ def _stale_after_next_time(arg_1, arg_2): _stale_after_next_time.clear_cache() -@pytest.mark.parametrize(parametrize_keys, parametrize_values) +@pytest.mark.parametrize(*PARAMETRIZE_TEST) def test_wait_for_calc_applies_dynamically(backend, mongetter): + """Testing for calls timing out to be performed twice when needed.""" + @cachier.cachier(backend=backend, mongetter=mongetter) def _wait_for_calc_timeout_slow(arg_1, arg_2): time.sleep(3) @@ -218,7 +217,6 @@ def _calls_wait_for_calc_timeout_slow(res_queue): res_queue.put(res) cachier.set_default_params(wait_for_calc_timeout=2) - """Testing for calls timing out to be performed twice when needed.""" _wait_for_calc_timeout_slow.clear_cache() res_queue = queue.Queue() thread1 = threading.Thread(