diff --git a/gptcache/adapter/adapter.py b/gptcache/adapter/adapter.py
index d71b8a14..531ebb6e 100644
--- a/gptcache/adapter/adapter.py
+++ b/gptcache/adapter/adapter.py
@@ -16,12 +16,13 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
     pre_embedding_data = chat_cache.pre_embedding_func(
         kwargs, extra_param=context.get("pre_embedding_func", None)
     )
-    if cache_enable and not cache_skip:
+    if cache_enable:
         embedding_data = time_cal(
             chat_cache.embedding_func,
             func_name="embedding",
             report_func=chat_cache.report.embedding,
         )(pre_embedding_data, extra_param=context.get("embedding_func", None))
+    if cache_enable and not cache_skip:
         cache_data_list = time_cal(
             chat_cache.data_manager.search,
             func_name="search",
diff --git a/setup.py b/setup.py
index 066c466e..121d6052 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,7 @@ def parse_requirements(file_name: str) -> List[str]:
 setuptools.setup(
     name="gptcache",
     packages=find_packages(),
-    version="0.1.6",
+    version="0.1.7",
     author="SimFG",
     author_email="bang.fu@zilliz.com",
     description="GPT Cache, a powerful caching library that can be used to speed up and lower the cost of chat "
diff --git a/tests/unit_tests/adapter/test_adapter.py b/tests/unit_tests/adapter/test_adapter.py
index 27a795fe..eafa6a42 100644
--- a/tests/unit_tests/adapter/test_adapter.py
+++ b/tests/unit_tests/adapter/test_adapter.py
@@ -14,11 +14,6 @@ def llm_handler(*llm_args, **llm_kwargs):
         time.sleep(1)
         return a + b
 
-    def pre_embedding(data, **kwargs):
-        a = data.get("a", 0)
-        b = data.get("b", 0)
-        return f"{a}+{b}"
-
     def cache_data_convert(cache_data):
         return int(cache_data)
 
@@ -31,6 +26,11 @@ def add_llm(*args, **kwargs):
             llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
         )
 
+    def pre_embedding(data, **kwargs):
+        a = data.get("a", 0)
+        b = data.get("b", 0)
+        return f"{a}+{b}"
+
     if os.path.isfile(data_map_path):
         os.remove(data_map_path)
     map_manager = get_data_manager()
@@ -43,6 +43,8 @@ def add1(**kwargs):
         res = add_llm(a=1, b=2, **kwargs)
         assert res == 3, res
 
+    # pre_embedding -> embedding -> handle
+    # 0 + 0 + 1.0
     time_cal(add1, report_func=report_func)()
 
     # test cache_skip param
@@ -55,6 +57,12 @@ def delay_embedding(data, **kwargs):
         embedding_func=delay_embedding,
         data_manager=map_manager,
     )
+
+    def report_func(delta_time):
+        assert 1.4 < delta_time < 1.6, delta_time
+
+    # pre_embedding -> embedding -> handle
+    # 0 + 0.5 + 1.0
     time_cal(add1, report_func=report_func)(cache_skip=True)
 
     def report_func(delta_time):