Commit 4f74e5f
Fix the cache_skip param not taking effect (#168)
Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG committed Apr 10, 2023
1 parent 9a2d02e commit 4f74e5f
Showing 3 changed files with 16 additions and 7 deletions.
gptcache/adapter/adapter.py (2 additions, 1 deletion)

@@ -16,12 +16,13 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs):
     pre_embedding_data = chat_cache.pre_embedding_func(
         kwargs, extra_param=context.get("pre_embedding_func", None)
     )
-    if cache_enable and not cache_skip:
+    if cache_enable:
         embedding_data = time_cal(
             chat_cache.embedding_func,
             func_name="embedding",
             report_func=chat_cache.report.embedding,
         )(pre_embedding_data, extra_param=context.get("embedding_func", None))
+    if cache_enable and not cache_skip:
         cache_data_list = time_cal(
             chat_cache.data_manager.search,
             func_name="search",
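What the patch changes: previously both the embedding step and the cache lookup sat behind a single `if cache_enable and not cache_skip:` guard, so passing cache_skip=True also skipped the embedding computation. After the change the embedding is computed whenever the cache is enabled, and cache_skip only bypasses the cache search, which (judging from the diff) is what lets the parameter behave as intended, e.g. the fresh answer can still be stored together with its embedding. Below is a minimal, runnable sketch of that control flow. It is not the gptcache source: embed(), call_llm(), and the _cache dict are illustrative stand-ins for gptcache's embedding_func, the wrapped LLM handler, and the data manager.

# Minimal sketch of the post-fix control flow (stand-ins, not gptcache code).
_cache = {}  # maps an "embedding" to a stored answer


def embed(prompt):
    # Stand-in for a real embedding function.
    return prompt.strip().lower()


def call_llm(prompt):
    # Stand-in for the real LLM handler.
    return f"answer to: {prompt}"


def cached_call(prompt, cache_enable=True, cache_skip=False):
    embedding = None
    if cache_enable:
        # Computed whenever the cache is enabled, even with cache_skip=True,
        # so the fresh answer can still be stored below.
        embedding = embed(prompt)
    if cache_enable and not cache_skip:
        # cache_skip only bypasses the lookup, not the embedding step.
        if embedding in _cache:
            return _cache[embedding]
    answer = call_llm(prompt)
    if cache_enable and embedding is not None:
        _cache[embedding] = answer
    return answer


print(cached_call("Hi"))                   # miss: calls the LLM, stores result
print(cached_call("hi"))                   # hit: served from the in-memory cache
print(cached_call("hi", cache_skip=True))  # lookup skipped, answer recomputed and re-stored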
setup.py (1 addition, 1 deletion)

@@ -17,7 +17,7 @@ def parse_requirements(file_name: str) -> List[str]:
 setuptools.setup(
     name="gptcache",
     packages=find_packages(),
-    version="0.1.6",
+    version="0.1.7",
     author="SimFG",
     author_email="bang.fu@zilliz.com",
     description="GPT Cache, a powerful caching library that can be used to speed up and lower the cost of chat "
tests/unit_tests/adapter/test_adapter.py (13 additions, 5 deletions)

@@ -14,11 +14,6 @@ def llm_handler(*llm_args, **llm_kwargs):
     time.sleep(1)
     return a + b
 
-def pre_embedding(data, **kwargs):
-    a = data.get("a", 0)
-    b = data.get("b", 0)
-    return f"{a}+{b}"
-
 def cache_data_convert(cache_data):
     return int(cache_data)
 
@@ -31,6 +26,11 @@ def add_llm(*args, **kwargs):
         llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
     )
 
+def pre_embedding(data, **kwargs):
+    a = data.get("a", 0)
+    b = data.get("b", 0)
+    return f"{a}+{b}"
+
 if os.path.isfile(data_map_path):
     os.remove(data_map_path)
 map_manager = get_data_manager()
@@ -43,6 +43,8 @@ def add1(**kwargs):
     res = add_llm(a=1, b=2, **kwargs)
     assert res == 3, res
 
+# pre_embedding -> embedding -> handle
+# 0 + 0 + 1.0
 time_cal(add1, report_func=report_func)()
 
 # test cache_skip param
@@ -55,6 +57,12 @@ def delay_embedding(data, **kwargs):
     embedding_func=delay_embedding,
     data_manager=map_manager,
 )
+
+def report_func(delta_time):
+    assert 1.4 < delta_time < 1.6, delta_time
+
+# pre_embedding -> embedding -> handle
+# 0 + 0.5 + 1.0
 time_cal(add1, report_func=report_func)(cache_skip=True)
 
 def report_func(delta_time):
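The new test assertions rely on timing: llm_handler sleeps 1.0 s, delay_embedding adds roughly 0.5 s (per the test's comments), and pre_embedding returns immediately. With cache_skip=True the lookup is bypassed but the embedding still runs, so a call costs about 0 + 0.5 + 1.0 = 1.5 s, which the `assert 1.4 < delta_time < 1.6` check verifies. Below is a hedged sketch of what a time_cal-style wrapper plausibly does; the real gptcache helper may differ in signature and details. It wraps a callable, measures the wall-clock duration of one invocation, and passes it to report_func.

# Hedged sketch of a time_cal-style timing wrapper (illustrative, not the
# gptcache implementation). func_name is kept only to mirror the call sites
# shown in the diff.
import time


def time_cal(func, func_name=None, report_func=None):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - start
        if report_func is not None:
            report_func(elapsed)  # e.g. assert 1.4 < elapsed < 1.6 in the test
        return result
    return wrapper


def slow_add(a, b):
    time.sleep(1.5)  # mimics embedding (0.5 s) + LLM handler (1.0 s)
    return a + b


timed = time_cal(slow_add, report_func=lambda t: print(f"took {t:.2f}s"))
print(timed(1, 2))  # prints the measured duration (~1.5 s), then 3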
