-
Notifications
You must be signed in to change notification settings - Fork 274
/
semantic_search.py
60 lines (49 loc) · 1.58 KB
/
semantic_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from pgml import Collection, Pipeline
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich.console import Console
import asyncio
def _build_documents(question_records, limit=None):
    """Turn Quora question records into deduplicated pgml documents.

    Questions are deduplicated in first-seen order. The selected subset and
    the ids assigned to it are therefore identical on every run. Iterating a
    ``set`` of strings is not stable in this way, because ``str`` hashing is
    randomized per process. Falsy entries (e.g. empty strings) are dropped.

    Args:
        question_records: Iterable of records. Each record's ``"text"`` value
            is an iterable of question strings.
        limit: Maximum number of documents to return; ``None`` keeps all.

    Returns:
        A list of ``{"id": int, "text": str}`` dicts with ids ``0..len-1``.
    """
    unique_texts = dict.fromkeys(
        text
        for record in question_records
        for text in record["text"]
        if text
    )
    texts = list(unique_texts)[:limit]  # a [:None] slice keeps everything
    return [{"id": i, "text": text} for i, text in enumerate(texts)]


async def main():
    """Index a sample of Quora questions with pgml, run one search, clean up.

    The function:
      1. Calls ``load_dotenv()``. Presumably this supplies the database
         connection settings pgml needs; confirm against the deployment.
      2. Creates the ``quora_collection`` collection.
      3. Adds a ``quorav1`` pipeline that configures a splitter and a
         semantic-search model for the ``text`` field.
      4. Upserts 2000 unique questions from the Hugging Face ``quora`` train
         split.
      5. Prints the top 5 vector-search results for a fixed query, along with
         the query latency.
      6. Archives the collection.
    """
    # perf_counter is monotonic; wall-clock time() can jump while timing.
    from time import perf_counter

    load_dotenv()
    console = Console()

    # Initialize collection
    collection = Collection("quora_collection")

    # Create and add pipeline: split the "text" field, then embed the chunks
    # with the given model for semantic search.
    pipeline = Pipeline(
        "quorav1",
        {
            "text": {
                "splitter": {"model": "recursive_character"},
                "semantic_search": {"model": "Alibaba-NLP/gte-base-en-v1.5"},
            }
        },
    )
    await collection.add_pipeline(pipeline)

    # Archive in `finally` so a failed upsert/query does not leave the demo
    # collection behind.
    try:
        # Prep documents for upserting: unique, non-empty questions.
        dataset = load_dataset("quora", split="train")
        documents = _build_documents(dataset["questions"], limit=2000)
        await collection.upsert_documents(documents)

        # Query
        query = "What is a good mobile os?"
        console.print(f"Querying for {query}...")
        start = perf_counter()
        results = await collection.vector_search(
            {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline
        )
        elapsed = perf_counter() - start
        console.print(f"\n Results for '{query}' ", style="bold")
        console.print(results)
        console.print(f"Query time = {elapsed:0.3f}")
    finally:
        await collection.archive()
if __name__ == "__main__":
    # Script entry point: build the top-level coroutine and hand it to
    # asyncio.run, which creates an event loop, runs it to completion, and
    # closes the loop.
    entry = main()
    asyncio.run(entry)