-
Notifications
You must be signed in to change notification settings - Fork 723
/
Copy pathtest_experience.py
54 lines (44 loc) · 1.75 KB
/
test_experience.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import pytest
from injector import Injector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory.experience import ExperienceGenerator
# True only when the GITHUB_ACTIONS environment variable is exactly "true"
# (GitHub Actions sets it that way on its runners). Used below to skip
# tests that cannot run in that CI environment.
IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_experience_retrieval():
    """Check that a single stored experience is loaded, embedded and retrieved.

    An ``ExperienceGenerator`` is built through an ``Injector`` whose
    ``AppConfigSource`` reads ``project/taskweaver_config.json`` (two levels
    above this file). The source is given overrides that select the
    ``sentence_transformers`` embedding backend with model
    ``all-mpnet-base-v2``, set ``experience.refresh_experience`` to False and
    ``experience.retrieve_threshold`` to 0.0.

    The generator is pointed at the ``data/experience`` directory next to this
    file. The test expects that directory to yield exactly one experience,
    ``test-exp-1``, which must also be the only result retrieved for the query.
    """
    # Directory containing this test module; config and fixture paths are
    # resolved relative to it.
    test_dir = os.path.dirname(os.path.abspath(__file__))

    injector = Injector([LoggingModule])
    config_source = AppConfigSource(
        config_file_path=os.path.join(test_dir, "..", "..", "project/taskweaver_config.json"),
        config={
            "llm.embedding_api_type": "sentence_transformers",
            "llm.embedding_model": "all-mpnet-base-v2",
            "experience.refresh_experience": False,
            # NOTE(review): presumably 0.0 lets any similarity score pass, so
            # retrieval does not depend on how close the query is — confirm.
            "experience.retrieve_threshold": 0.0,
        },
    )
    injector.binder.bind(AppConfigSource, to=config_source)

    generator = injector.create_object(ExperienceGenerator)
    generator.set_experience_dir(os.path.join(test_dir, "data/experience"))

    # Refresh first, then load, so the loaded list reflects the fixture
    # directory's current contents.
    generator.refresh()
    generator.load_experience()

    loaded = generator.experience_list
    assert len(loaded) == 1
    only_exp = loaded[0]
    assert len(only_exp.experience_text) > 0
    assert only_exp.exp_id == "test-exp-1"
    # 768 is the embedding length this test expects from all-mpnet-base-v2.
    assert len(only_exp.embedding) == 768
    assert only_exp.embedding_model == "all-mpnet-base-v2"

    query = "show top 10 data in ./data.csv"
    hits = generator.retrieve_experience(user_query=query)
    assert len(hits) == 1
    # Each hit is indexable with the experience at position 0
    # (presumably an (experience, score) pair — confirm).
    assert hits[0][0].exp_id == "test-exp-1"