From 401593389f239193a4f9c0a993361165fc0419b7 Mon Sep 17 00:00:00 2001 From: Amir Mehr Date: Thu, 7 Mar 2024 18:14:20 -0700 Subject: [PATCH] fix: cache for the baleen test and add notebook cache to everywhere. --- dsp/modules/cache_utils.py | 3 ++- tests/examples/test_baleen.py | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dsp/modules/cache_utils.py b/dsp/modules/cache_utils.py index 78270c879a..06e4a2c6e1 100644 --- a/dsp/modules/cache_utils.py +++ b/dsp/modules/cache_utils.py @@ -27,7 +27,8 @@ def wrapper(*args, **kwargs): cachedir = os.environ.get('DSP_CACHEDIR') or os.path.join(Path.home(), 'cachedir_joblib') CacheMemory = Memory(location=cachedir, verbose=0) -cachedir2 = os.environ.get('DSP_NOTEBOOK_CACHEDIR') +project_home = Path(__file__).resolve().parent.parent.parent +cachedir2 = os.environ.get('DSP_NOTEBOOK_CACHEDIR') or os.path.join(project_home, 'cache') NotebookCacheMemory = dotdict() NotebookCacheMemory.cache = noop_decorator diff --git a/tests/examples/test_baleen.py b/tests/examples/test_baleen.py index ab14458444..c1f8412d80 100644 --- a/tests/examples/test_baleen.py +++ b/tests/examples/test_baleen.py @@ -1,4 +1,6 @@ import pytest +import os +from dsp.modules.cache_utils import * from dsp.utils import deduplicate import dspy.evaluate import dspy @@ -87,8 +89,8 @@ def validate_context_and_answer_and_hops(example, pred, trace=None): if max([len(h) for h in hops]) > 100: return False if any( - dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) - for idx in range(2, len(hops)) + dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8) + for idx in range(2, len(hops)) ): return False @@ -106,7 +108,7 @@ def gold_passages_retrieved(example, pred, trace=None): # @pytest.mark.slow_test # TODO: Find a way to make this test run without the slow hotpotqa dataset -def _test_compiled_baleen(): +def test_compiled_baleen(): trainset, devset = load_hotpotqa() lm = dspy.OpenAI(model="gpt-3.5-turbo") rm = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts") @@ -133,4 +135,4 @@ def _test_compiled_baleen(): compiled_baleen, metric=gold_passages_retrieved ) # assert compiled_baleen_retrieval_score / 100 == 27 / 50 - assert uncompiled_baleen_retrieval_score < compiled_baleen_retrieval_score \ No newline at end of file + assert uncompiled_baleen_retrieval_score < compiled_baleen_retrieval_score