-
Notifications
You must be signed in to change notification settings - Fork 723
/
Copy pathtest_round_compressor.py
79 lines (68 loc) · 2.09 KB
/
test_round_compressor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from injector import Injector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory import RoundCompressor
def test_round_compressor():
    """Check that RoundCompressor leaves a short round history uncompressed.

    A RoundCompressor is resolved from an injector configured with
    ``rounds_to_compress = 2`` and ``rounds_to_retain = 2``. The test then
    grows a history one round at a time, up to four rounds (which equals
    ``rounds_to_compress + rounds_to_retain``). After each step it asserts that
    ``compress_rounds`` returns the summary ``"None"`` and gives back as many
    rounds as it was handed.

    NOTE(review): that four rounds still yields no summary depends on the
    threshold logic inside ``RoundCompressor.compress_rounds``, which is not
    visible here. Confirm against the implementation if this test changes.
    """
    # Kept as a function-local import, matching the original test.
    from taskweaver.memory import Post, Round

    injector = Injector(
        [LoggingModule],
    )
    config_source = AppConfigSource(
        config={
            "llm.api_key": "test_key",
            "round_compressor.rounds_to_compress": 2,
            "round_compressor.rounds_to_retain": 2,
        },
    )
    injector.binder.bind(AppConfigSource, to=config_source)
    compressor = injector.get(RoundCompressor)

    # The config values must reach the compressor instance.
    assert compressor.rounds_to_compress == 2
    assert compressor.rounds_to_retain == 2

    # The same two post objects are attached to every round.
    user_post = Post.create(
        message="hello",
        send_from="User",
        send_to="Planner",
        attachment_list=[],
    )
    planner_post = Post.create(
        message="hello",
        send_from="Planner",
        send_to="User",
        attachment_list=[],
    )

    history = []
    for round_no in range(1, 5):
        current = Round.create(user_query="hello", id=f"round-{round_no}")
        current.add_post(user_post)
        current.add_post(planner_post)
        history.append(current)

        # Hand over a fresh list each call rather than the growing `history`
        # object, so the compressor never sees a list mutated after the call.
        summary, retained = compressor.compress_rounds(
            list(history),
            lambda x: x,
        )
        assert summary == "None"
        assert len(retained) == round_no