Skip to content

Commit ad99dc5

Browse files
jameszyaoSimsonW
authored andcommitted
test: update test script
1 parent a8bb056 commit ad99dc5

File tree

8 files changed

+301
-309
lines changed

8 files changed

+301
-309
lines changed

test/testcase/test_async/test_async_assistant.py

Lines changed: 125 additions & 118 deletions
Large diffs are not rendered by default.

test/testcase/test_async/test_async_inference.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,22 @@
33
from taskingai.inference import *
44
from test.config import embedding_model_id, chat_completion_model_id
55
from test.common.logger import logger
6+
import re
67

78

89
@pytest.mark.test_async
910
class TestChatCompletion:
10-
1111
@pytest.mark.run(order=4)
1212
@pytest.mark.asyncio
1313
async def test_a_chat_completion(self):
14-
1514
# normal chat completion.
1615

1716
normal_res = await a_chat_completion(
1817
model_id=chat_completion_model_id,
1918
messages=[
2019
SystemMessage("You are a professional assistant."),
2120
UserMessage("Hi"),
22-
]
21+
],
2322
)
2423
pytest.assume(normal_res.finish_reason == "stop")
2524
pytest.assume(normal_res.message.content)
@@ -35,10 +34,7 @@ async def test_a_chat_completion(self):
3534
UserMessage("Hi"),
3635
AssistantMessage("Hello! How can I assist you today?"),
3736
UserMessage("Can you tell me a joke?"),
38-
AssistantMessage(
39-
"Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"),
40-
UserMessage("That's funny. Can you tell me another one?"),
41-
]
37+
],
4238
)
4339

4440
pytest.assume(multi_round_res.finish_reason == "stop")
@@ -55,13 +51,8 @@ async def test_a_chat_completion(self):
5551
UserMessage("Hi"),
5652
AssistantMessage("Hello! How can I assist you today?"),
5753
UserMessage("Can you tell me a joke?"),
58-
AssistantMessage(
59-
"Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"),
60-
UserMessage("That's funny. Can you tell me another one?"),
6154
],
62-
configs={
63-
"max_tokens": 10
64-
}
55+
configs={"max_tokens": 10},
6556
)
6657
pytest.assume(max_tokens_res.finish_reason == "length")
6758
pytest.assume(max_tokens_res.message.content)
@@ -70,33 +61,34 @@ async def test_a_chat_completion(self):
7061

7162
# chat completion with stream.
7263

73-
stream_res = await a_chat_completion(model_id=chat_completion_model_id,
74-
messages=[
75-
SystemMessage("You are a professional assistant."),
76-
UserMessage("count from 1 to 50 and separate numbers by comma."),
77-
],
78-
stream=True
79-
)
80-
except_list = [i + 1 for i in range(50)]
81-
real_list = []
64+
stream_res = await a_chat_completion(
65+
model_id=chat_completion_model_id,
66+
messages=[
67+
SystemMessage("You are a professional assistant."),
68+
UserMessage("count from 1 to 10 and separate numbers by comma."),
69+
],
70+
stream=True,
71+
)
72+
except_list = [i + 1 for i in range(10)]
73+
real_str = ""
8274
async for item in stream_res:
8375
if isinstance(item, ChatCompletionChunk):
8476
logger.info(f"Message: {item.delta}")
85-
if item.delta.isdigit():
86-
real_list.append(int(item.delta))
77+
real_str += item.delta
78+
8779
elif isinstance(item, ChatCompletion):
8880
logger.info(f"Message: {item.finish_reason}")
8981
pytest.assume(item.finish_reason == "stop")
82+
83+
real_list = [int(num) for num in re.findall(r"\b\d+\b", real_str)]
9084
pytest.assume(set(except_list) == set(real_list))
9185

9286

9387
@pytest.mark.test_async
9488
class TestTextEmbedding:
95-
9689
@pytest.mark.run(order=0)
9790
@pytest.mark.asyncio
9891
async def test_a_text_embedding(self):
99-
10092
# Text embedding with str.
10193

10294
input_str = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."

test/testcase/test_async/test_async_retrieval.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ class TestCollection(Base):
2222
"status",
2323
]
2424
collection_keys = set(collection_list)
25-
collection_configs = ["metric", "chunk_size", "chunk_overlap"]
26-
collection_configs_keys = set(collection_configs)
2725

2826
@pytest.mark.run(order=9)
2927
@pytest.mark.asyncio
@@ -43,7 +41,7 @@ async def test_a_create_collection(self):
4341
pytest.assume(res_dict["description"] == description)
4442
pytest.assume(res_dict["embedding_model_id"] == embedding_model_id)
4543
pytest.assume(res_dict["capacity"] == 1000)
46-
pytest.assume(res_dict["status"] == "creating")
44+
pytest.assume((res_dict["status"] == "ready") or (res_dict["status"] == "creating"))
4745

4846
@pytest.mark.run(order=10)
4947
@pytest.mark.asyncio
@@ -76,8 +74,7 @@ async def test_a_get_collection(self, a_collection_id):
7674
res = await a_get_collection(collection_id=self.collection_id)
7775
res_dict = res.to_dict()
7876
pytest.assume(res_dict.keys() == self.collection_keys)
79-
pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys)
80-
pytest.assume(res_dict["status"] == "ready" or "creating")
77+
pytest.assume(res_dict["status"] == "ready" or res_dict["status"] == "creating")
8178

8279
@pytest.mark.run(order=12)
8380
@pytest.mark.asyncio
@@ -89,7 +86,6 @@ async def test_a_update_collection(self):
8986
res = await a_update_collection(collection_id=self.collection_id, name=name, description=description)
9087
res_dict = res.to_dict()
9188
pytest.assume(res_dict.keys() == self.collection_keys)
92-
pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys)
9389
pytest.assume(res_dict["name"] == name)
9490
pytest.assume(res_dict["description"] == description)
9591
pytest.assume(res_dict["status"] == "ready")
@@ -99,17 +95,22 @@ async def test_a_update_collection(self):
9995
async def test_a_delete_collection(self):
10096
# List collections.
10197
old_res = await a_list_collections(order="desc", limit=100, after=None, before=None)
98+
old_nums = len(old_res)
10299

103100
for index, collection in enumerate(old_res):
104101
collection_id = collection.collection_id
105-
# Delete a collection.
102+
# Delete a collection
106103
await a_delete_collection(collection_id=collection_id)
107104

108105
new_collections = await a_list_collections(order="desc", limit=100, after=None, before=None)
109-
# List collections.
106+
107+
# List collections
110108
collection_ids = [c.collection_id for c in new_collections]
111109
pytest.assume(collection_id not in collection_ids)
112110

111+
new_nums = len(new_collections)
112+
pytest.assume(new_nums == old_nums - 1 - index)
113+
113114

114115
@pytest.mark.test_async
115116
class TestRecord(Base):
@@ -125,8 +126,6 @@ class TestRecord(Base):
125126
"status",
126127
]
127128
record_keys = set(record_list)
128-
record_content = ["text"]
129-
record_content_keys = set(record_content)
130129

131130
@pytest.mark.run(order=13)
132131
@pytest.mark.asyncio
@@ -142,9 +141,8 @@ async def test_a_create_record(self):
142141
)
143142
res_dict = res.to_dict()
144143
pytest.assume(res_dict.keys() == self.record_keys)
145-
pytest.assume(res_dict["content"].keys() == self.record_content_keys)
146-
pytest.assume(res_dict["content"]["text"] == text)
147-
pytest.assume(res_dict["status"] == "creating")
144+
pytest.assume(res_dict["content"] == text)
145+
pytest.assume((res_dict["status"] == "creating") or (res_dict["status"] == "ready"))
148146

149147
@pytest.mark.run(order=14)
150148
@pytest.mark.asyncio
@@ -177,12 +175,14 @@ async def test_a_list_records(self, a_record_id):
177175
async def test_a_get_record(self):
178176
# Get a record.
179177

180-
res = await a_get_record(collection_id=self.collection_id, record_id=self.record_id)
181-
logger.info(f"a_get_record:{res}")
182-
res_dict = res.to_dict()
183-
pytest.assume(res_dict.keys() == self.record_keys)
184-
pytest.assume(res_dict["content"].keys() == self.record_content_keys)
185-
pytest.assume(res_dict["status"] == "ready" or "creating")
178+
records = await a_list_records(collection_id=self.collection_id)
179+
for record in records:
180+
record_id = record.record_id
181+
res = await a_get_record(collection_id=self.collection_id, record_id=record_id)
182+
logger.info(f"a_get_record:{res}")
183+
res_dict = res.to_dict()
184+
pytest.assume(res_dict.keys() == self.record_keys)
185+
pytest.assume(res_dict["status"] == "ready" or res_dict["status"] == "creating")
186186

187187
@pytest.mark.run(order=16)
188188
@pytest.mark.asyncio
@@ -194,7 +194,6 @@ async def test_a_update_record(self):
194194
logger.info(f"a_update_record:{res}")
195195
res_dict = res.to_dict()
196196
pytest.assume(res_dict.keys() == self.record_keys)
197-
pytest.assume(res_dict["content"].keys() == self.record_content_keys)
198197
pytest.assume(res_dict["metadata"] == metadata)
199198

200199
@pytest.mark.run(order=34)
@@ -226,7 +225,7 @@ async def test_a_delete_record(self):
226225

227226
@pytest.mark.test_async
228227
class TestChunk(Base):
229-
chunk_list = ["chunk_id", "collection_id", "record_id", "object", "text", "score"]
228+
chunk_list = ["chunk_id", "collection_id", "record_id", "object", "content", "score"]
230229
chunk_keys = set(chunk_list)
231230

232231
@pytest.mark.run(order=17)
@@ -241,5 +240,5 @@ async def test_a_query_chunks(self):
241240
for chunk in res:
242241
chunk_dict = chunk.to_dict()
243242
pytest.assume(chunk_dict.keys() == self.chunk_keys)
244-
pytest.assume(query_text in chunk_dict["text"])
243+
pytest.assume(query_text in chunk_dict["content"])
245244
pytest.assume(chunk_dict["score"] >= 0)

test/testcase/test_async/test_async_tool.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class TestAction(Base):
2626
action_keys = set(action_list)
2727
action_schema = ["openapi", "info", "servers", "paths", "components", "security"]
2828
action_schema_keys = set(action_schema)
29-
schema = {
29+
openapi_schema = {
3030
"openapi": "3.1.0",
3131
"info": {
3232
"title": "Get weather data",
@@ -72,7 +72,7 @@ class TestAction(Base):
7272
async def test_a_bulk_create_actions(self):
7373
# Create an action.
7474

75-
res = await a_bulk_create_actions(schema=self.schema)
75+
res = await a_bulk_create_actions(openapi_schema=self.openapi_schema)
7676
for action in res:
7777
action_dict = action.to_dict()
7878
logger.info(action_dict)
@@ -84,15 +84,15 @@ async def test_a_bulk_create_actions(self):
8484
if action_dict["openapi_schema"][key]["/location"] == "get":
8585
pytest.assume(
8686
action_dict["openapi_schema"][key]["/location"]["get"]
87-
== self.schema["paths"]["/location"]["get"]
87+
== self.openapi_schema["paths"]["/location"]["get"]
8888
)
8989
elif action_dict["openapi_schema"][key]["/location"] == "post":
9090
pytest.assume(
9191
action_dict["openapi_schema"][key]["/location"]["post"]
92-
== self.schema["paths"]["/location"]["post"]
92+
== self.openapi_schema["paths"]["/location"]["post"]
9393
)
9494
else:
95-
pytest.assume(action_dict["openapi_schema"][key] == self.schema[key])
95+
pytest.assume(action_dict["openapi_schema"][key] == self.openapi_schema[key])
9696

9797
@pytest.mark.run(order=5)
9898
@pytest.mark.asyncio
@@ -101,10 +101,10 @@ async def test_a_run_action(self, a_action_id):
101101

102102
if not Base.action_id:
103103
Base.action_id = await a_action_id
104-
parameters = {"location": "beijing"}
104+
parameters = {"location": "tokyo"}
105105
res = await a_run_action(action_id=self.action_id, parameters=parameters)
106106
logger.info(f"async run action{res}")
107-
pytest.assume(res["status"] == 400)
107+
pytest.assume(res["status"] != 200)
108108
pytest.assume(res["error"])
109109

110110
@pytest.mark.run(order=6)
@@ -176,7 +176,7 @@ async def test_a_update_action(self):
176176
"security": [],
177177
}
178178

179-
res = await a_update_action(action_id=self.action_id, schema=update_schema)
179+
res = await a_update_action(action_id=self.action_id, openapi_schema=update_schema)
180180
res_dict = res.to_dict()
181181
logger.info(res_dict)
182182
pytest.assume(res_dict.keys() == self.action_keys)

0 commit comments

Comments
 (0)