/
common.py
103 lines (82 loc) · 3.04 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Common utils for embeddings."""
import json
import re
import uuid
from typing import Dict, List, Tuple
from tqdm import tqdm
from llama_index.bridge.pydantic import BaseModel
from llama_index.llms.utils import LLM
from llama_index.schema import MetadataMode, TextNode
class EmbeddingQAFinetuneDataset(BaseModel):
    """Embedding QA Finetuning Dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> string.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
    """

    queries: Dict[str, str]  # query id -> query text
    corpus: Dict[str, str]  # node id -> node text
    relevant_docs: Dict[str, List[str]]  # query id -> relevant node ids
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Get query, relevant doc ids."""
        pairs: List[Tuple[str, List[str]]] = []
        for query_id, query_text in self.queries.items():
            pairs.append((query_text, self.relevant_docs[query_id]))
        return pairs

    def save_json(self, path: str) -> None:
        """Save json."""
        with open(path, "w") as out_file:
            json.dump(self.dict(), out_file, indent=4)

    @classmethod
    def from_json(cls, path: str) -> "EmbeddingQAFinetuneDataset":
        """Load json."""
        with open(path) as in_file:
            loaded = json.load(in_file)
        return cls(**loaded)
# Default prompt used to ask an LLM for synthetic questions over a text chunk.
# Fixes vs. prior version: "knowledge." -> "knowledge," (was a sentence
# fragment) and removed a stray unmatched '"' after "provided." that leaked
# into every generated prompt.
DEFAULT_QA_GENERATE_PROMPT_TMPL = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge,
generate only questions based on the below query.
You are a Teacher/ Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""
# Convenience function: build an EmbeddingQAFinetuneDataset from nodes + LLM.
def generate_qa_embedding_pairs(
    nodes: List[TextNode],
    llm: LLM,
    qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> EmbeddingQAFinetuneDataset:
    """Generate examples given a set of nodes."""
    # node id -> raw chunk text (metadata excluded).
    node_dict = {
        node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
        for node in nodes
    }

    queries: Dict[str, str] = {}
    relevant_docs: Dict[str, List[str]] = {}
    for node_id, chunk_text in tqdm(node_dict.items()):
        prompt = qa_generate_prompt_tmpl.format(
            context_str=chunk_text,
            num_questions_per_chunk=num_questions_per_chunk,
        )
        response_lines = str(llm.complete(prompt)).strip().split("\n")
        # Strip a leading enumeration marker such as "1." or "2)" from each
        # line, then drop lines that end up empty.
        for line in response_lines:
            question = re.sub(r"^\d+[\).\s]", "", line).strip()
            if not question:
                continue
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [node_id]

    # Every generated question maps back to exactly one source node.
    return EmbeddingQAFinetuneDataset(
        queries=queries, corpus=node_dict, relevant_docs=relevant_docs
    )