-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
query.py
672 lines (570 loc) · 23.9 KB
/
query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
"""Query Pipeline."""
import json
import uuid
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
get_args,
)
import networkx
from llama_index.legacy.async_utils import run_jobs
from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.core.query_pipeline.query_component import (
QUERY_COMPONENT_TYPE,
ChainableMixin,
InputKeys,
Link,
OutputKeys,
QueryComponent,
)
from llama_index.legacy.utils import print_text
def get_output(
    src_key: Optional[str],
    output_dict: Dict[str, Any],
) -> Any:
    """Extract the relevant value from a module's output dict.

    If ``src_key`` is None, ``output_dict`` must contain exactly one entry and
    its value is returned; otherwise the value stored under ``src_key`` is
    returned.

    Raises:
        ValueError: If ``src_key`` is None and the dict is ambiguous
            (not exactly one entry).
    """
    if src_key is not None:
        return output_dict[src_key]
    # no key specified: the output dict must be unambiguous
    if len(output_dict) != 1:
        raise ValueError("Output dict must have exactly one key.")
    return next(iter(output_dict.values()))
def add_output_to_module_inputs(
    dest_key: Optional[str],
    output: Any,
    module: "QueryComponent",
    module_inputs: Dict[str, Any],
) -> None:
    """Attach an upstream output to a downstream module's input dict.

    Args:
        dest_key: Input key of ``module`` to assign ``output`` to. If None,
            the module must have exactly one remaining free required input key
            (given partials), and that key is used.
        output: The value produced by the upstream module.
        module: The downstream component (consulted only when dest_key is None).
        module_inputs: Input dict for ``module``; mutated in place.

    Raises:
        ValueError: If dest_key is None and the module does not have exactly
            one free required input key.
    """
    # FIX: `dest_key` was annotated `str`, but the code explicitly handles
    # None (and callers pass `attr.get("dest_key")`); annotation corrected
    # to Optional[str] to match actual behavior.
    if dest_key is None:
        free_keys = module.free_req_input_keys
        # ensure that there is only one remaining key given partials
        if len(free_keys) != 1:
            raise ValueError(
                "Module input keys must have exactly one key if "
                "dest_key is not specified. Remaining keys: "
                f"in module: {free_keys}"
            )
        module_inputs[next(iter(free_keys))] = output
    else:
        module_inputs[dest_key] = output
def print_debug_input(
    module_key: str,
    input: Dict[str, Any],
    val_str_len: int = 200,
) -> None:
    """Print a module's input kwargs (values truncated) for verbose runs."""
    lines = [f"> Running module {module_key} with input: "]
    for key, value in input.items():
        # stringify each value and truncate so the log stays readable
        text = str(value)
        if len(text) > val_str_len:
            text = text[:val_str_len] + "..."
        lines.append(f"{key}: {text}")
    print_text("\n".join(lines) + "\n\n", color="llama_lavender")
def print_debug_input_multi(
    module_keys: List[str],
    module_inputs: List[Dict[str, Any]],
    val_str_len: int = 200,
) -> None:
    """Print debug info for modules that are about to run in parallel.

    Args:
        module_keys: Keys of the modules being run.
        module_inputs: Input kwargs for each module, aligned with module_keys.
        val_str_len: Max characters of each value to print before truncating.
    """
    # FIX: was an f-string with no placeholders (F541); plain literal is
    # equivalent and intent-revealing.
    output = "> Running modules and inputs in parallel: \n"
    for module_key, input in zip(module_keys, module_inputs):
        cur_output = f"Module key: {module_key}. Input: \n"
        for key, value in input.items():
            # stringify and truncate long values so the log stays readable
            val_str = (
                str(value)[:val_str_len] + "..."
                if len(str(value)) > val_str_len
                else str(value)
            )
            cur_output += f"{key}: {val_str}\n"
        output += cur_output + "\n"
    print_text(output + "\n", color="llama_lavender")
# Strip non-serializable attributes from a copy of the graph.
# https://stackoverflow.com/questions/23268421/networkx-how-to-access-attributes-of-objects-as-nodes
def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
    """Return a copy of ``graph`` with callable node/edge attributes removed.

    Useful before serializing the DAG: functions (e.g. condition_fn/input_fn
    stored on edges) are deleted from the copy's attribute dicts while the
    original graph is left untouched. NOTE(review): ``graph.copy()`` copies
    attribute dicts but shares attribute *values* with the original — fine
    here since we only delete keys.
    """
    graph_copy = graph.copy()
    # drop callable node attributes (snapshot keys so we can delete mid-scan)
    for _, attrs in graph_copy.nodes(data=True):
        for key in [k for k, v in attrs.items() if callable(v)]:
            del attrs[key]
    # same cleanup for edge attributes
    for _, _, attrs in graph_copy.edges(data=True):
        for key in [k for k, v in attrs.items() if callable(v)]:
            del attrs[key]
    return graph_copy
CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str]
class QueryPipeline(QueryComponent):
    """A query pipeline that can allow arbitrary chaining of different modules.

    A pipeline itself is a query component, and can be used as a module in
    another pipeline.
    """

    callback_manager: CallbackManager = Field(
        default_factory=lambda: CallbackManager([]), exclude=True
    )
    module_dict: Dict[str, QueryComponent] = Field(
        default_factory=dict, description="The modules in the pipeline."
    )
    dag: networkx.MultiDiGraph = Field(
        default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline."
    )
    verbose: bool = Field(
        default=False, description="Whether to print intermediate steps."
    )
    show_progress: bool = Field(
        default=False,
        description="Whether to show progress bar (currently async only).",
    )
    num_workers: int = Field(
        default=4, description="Number of workers to use (currently async only)."
    )

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        callback_manager: Optional[CallbackManager] = None,
        chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None,
        modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None,
        links: Optional[List[Link]] = None,
        **kwargs: Any,
    ):
        """Initialize from either a linear ``chain`` or ``modules`` + ``links``."""
        super().__init__(
            callback_manager=callback_manager or CallbackManager([]),
            **kwargs,
        )
        self._init_graph(chain=chain, modules=modules, links=links)

    def _init_graph(
        self,
        chain: Optional[Sequence[CHAIN_COMPONENT_TYPE]] = None,
        modules: Optional[Dict[str, QUERY_COMPONENT_TYPE]] = None,
        links: Optional[List[Link]] = None,
    ) -> None:
        """Initialize graph.

        Raises:
            ValueError: If both ``chain`` and ``modules``/``links`` are given.
        """
        if chain is not None:
            if modules is not None or links is not None:
                raise ValueError("Cannot specify both chain and modules/links in init.")
            self.add_chain(chain)
        elif modules is not None:
            self.add_modules(modules)
            if links is not None:
                for link in links:
                    self.add_link(**link.dict())

    def add_chain(self, chain: Sequence[CHAIN_COMPONENT_TYPE]) -> None:
        """Add a chain of modules to the pipeline.

        This is a special form of pipeline that is purely sequential/linear.
        This allows a more concise way of specifying a pipeline.
        """
        # first add all modules (components get auto-generated keys;
        # plain strings reference already-added modules)
        module_keys: List[str] = []
        for module in chain:
            if isinstance(module, get_args(QUERY_COMPONENT_TYPE)):
                module_key = str(uuid.uuid4())
                self.add(module_key, cast(QUERY_COMPONENT_TYPE, module))
                module_keys.append(module_key)
            elif isinstance(module, str):
                module_keys.append(module)
            else:
                raise ValueError("Chain must be a sequence of modules or module keys.")
        # then link each module to the next, in order
        for i in range(len(chain) - 1):
            self.add_link(src=module_keys[i], dest=module_keys[i + 1])

    def add_links(
        self,
        links: List[Link],
    ) -> None:
        """Add links to the pipeline.

        Raises:
            ValueError: If an element is not a ``Link``.
        """
        for link in links:
            if isinstance(link, Link):
                self.add_link(**link.dict())
            else:
                raise ValueError("Link must be of type `Link` or `ConditionalLinks`.")

    def add_modules(self, module_dict: Dict[str, QUERY_COMPONENT_TYPE]) -> None:
        """Add modules to the pipeline."""
        for module_key, module in module_dict.items():
            self.add(module_key, module)

    def add(self, module_key: str, module: QUERY_COMPONENT_TYPE) -> None:
        """Add a module to the pipeline.

        Raises:
            ValueError: If ``module_key`` already exists in the pipeline.
        """
        if module_key in self.module_dict:
            raise ValueError(f"Module {module_key} already exists in pipeline.")
        # normalize chainable objects into query components
        if isinstance(module, ChainableMixin):
            module = module.as_query_component()
        self.module_dict[module_key] = cast(QueryComponent, module)
        self.dag.add_node(module_key)

    def add_link(
        self,
        src: str,
        dest: str,
        src_key: Optional[str] = None,
        dest_key: Optional[str] = None,
        condition_fn: Optional[Callable] = None,
        input_fn: Optional[Callable] = None,
    ) -> None:
        """Add a link between two modules.

        Args:
            src: Key of the upstream module.
            dest: Key of the downstream module.
            src_key: Output key of ``src`` to forward (None = sole output).
            dest_key: Input key of ``dest`` to receive it (None = sole free key).
            condition_fn: Optional predicate on the output; if falsy, the
                edge (and ``dest``) is skipped at run time.
            input_fn: Optional transform applied to the output before it is
                passed to ``dest``.

        Raises:
            ValueError: If ``src`` or ``dest`` is not a known module.
        """
        if src not in self.module_dict:
            raise ValueError(f"Module {src} does not exist in pipeline.")
        # BUG FIX: previously only `src` was validated, so a typo'd `dest`
        # silently added a dangling node to the DAG; validate it the same way.
        if dest not in self.module_dict:
            raise ValueError(f"Module {dest} does not exist in pipeline.")
        self.dag.add_edge(
            src,
            dest,
            src_key=src_key,
            dest_key=dest_key,
            condition_fn=condition_fn,
            input_fn=input_fn,
        )

    def get_root_keys(self) -> List[str]:
        """Get root keys."""
        return self._get_root_keys()

    def get_leaf_keys(self) -> List[str]:
        """Get leaf keys."""
        return self._get_leaf_keys()

    def _get_root_keys(self) -> List[str]:
        """Get keys of modules with no upstream dependencies (in-degree 0)."""
        return [v for v, d in self.dag.in_degree() if d == 0]

    def _get_leaf_keys(self) -> List[str]:
        """Get keys of modules with no downstream dependencies (out-degree 0)."""
        return [v for v, d in self.dag.out_degree() if d == 0]

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager on the pipeline and every module in it."""
        self.callback_manager = callback_manager
        for module in self.module_dict.values():
            module.set_callback_manager(callback_manager)

    def run(
        self,
        *args: Any,
        return_values_direct: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the pipeline (single root, single output)."""
        # first set callback manager
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            # try to serialize kwargs for the trace payload; fall back to str()
            # for non-JSON-serializable inputs
            try:
                query_payload = json.dumps(kwargs)
            except TypeError:
                query_payload = json.dumps(str(kwargs))
            with self.callback_manager.event(
                CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
            ):
                return self._run(
                    *args, return_values_direct=return_values_direct, **kwargs
                )

    def run_multi(
        self,
        module_input_dict: Dict[str, Any],
        callback_manager: Optional[CallbackManager] = None,
    ) -> Dict[str, Any]:
        """Run the pipeline for multiple roots."""
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            with self.callback_manager.event(
                CBEventType.QUERY,
                payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
            ):
                return self._run_multi(module_input_dict)

    async def arun(
        self,
        *args: Any,
        return_values_direct: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> Any:
        """Async version of ``run``."""
        # first set callback manager
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            # see `run` for the serialization fallback rationale
            try:
                query_payload = json.dumps(kwargs)
            except TypeError:
                query_payload = json.dumps(str(kwargs))
            with self.callback_manager.event(
                CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_payload}
            ):
                return await self._arun(
                    *args, return_values_direct=return_values_direct, **kwargs
                )

    async def arun_multi(
        self,
        module_input_dict: Dict[str, Any],
        callback_manager: Optional[CallbackManager] = None,
    ) -> Dict[str, Any]:
        """Async version of ``run_multi``."""
        callback_manager = callback_manager or self.callback_manager
        self.set_callback_manager(callback_manager)
        with self.callback_manager.as_trace("query"):
            with self.callback_manager.event(
                CBEventType.QUERY,
                payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
            ):
                return await self._arun_multi(module_input_dict)

    def _get_root_key_and_kwargs(
        self, *args: Any, **kwargs: Any
    ) -> Tuple[str, Dict[str, Any]]:
        """Get root key and kwargs.

        This is for `_run`.

        Raises:
            ValueError: If there is not exactly one root, or args/kwargs are
                ambiguous.
        """
        ## run pipeline
        ## assume there is only one root - for multiple roots, need to specify `run_multi`
        root_keys = self._get_root_keys()
        if len(root_keys) != 1:
            raise ValueError("Only one root is supported.")
        root_key = root_keys[0]
        root_module = self.module_dict[root_key]
        if len(args) > 0:
            # if args is specified, validate. only one arg is allowed, and there
            # can only be one free input key in the module
            if len(args) > 1:
                raise ValueError("Only one arg is allowed.")
            if len(kwargs) > 0:
                raise ValueError("No kwargs allowed if args is specified.")
            if len(root_module.free_req_input_keys) != 1:
                raise ValueError("Only one free input key is allowed.")
            # set kwargs
            kwargs[next(iter(root_module.free_req_input_keys))] = args[0]
        return root_key, kwargs

    def _get_single_result_output(
        self,
        result_outputs: Dict[str, Any],
        return_values_direct: bool,
    ) -> Any:
        """Get result output from a single module.

        If output dict is a single key, return the value directly
        if return_values_direct is True.

        Raises:
            ValueError: If there is not exactly one output.
        """
        if len(result_outputs) != 1:
            raise ValueError("Only one output is supported.")
        result_output = next(iter(result_outputs.values()))
        # return_values_direct: if True and the output is a single-entry dict,
        # unwrap and return the value without the key
        if (
            isinstance(result_output, dict)
            and len(result_output) == 1
            and return_values_direct
        ):
            return next(iter(result_output.values()))
        return result_output

    def _run(self, *args: Any, return_values_direct: bool = True, **kwargs: Any) -> Any:
        """Run the pipeline.

        Assume that there is a single root module and a single output module.
        For multi-input and multi-outputs, please see `run_multi`.
        """
        root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
        # delegate to the multi-root runner with a single root
        result_outputs = self._run_multi({root_key: kwargs})
        return self._get_single_result_output(result_outputs, return_values_direct)

    async def _arun(
        self, *args: Any, return_values_direct: bool = True, **kwargs: Any
    ) -> Any:
        """Async version of ``_run``; see `_run` for assumptions."""
        root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
        # delegate to the multi-root runner with a single root
        result_outputs = await self._arun_multi({root_key: kwargs})
        return self._get_single_result_output(result_outputs, return_values_direct)

    def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None:
        """Check that the provided inputs cover exactly the DAG's root modules."""
        root_keys = self._get_root_keys()
        # if root keys don't match up with kwargs keys, raise error
        if set(root_keys) != set(module_input_dict.keys()):
            raise ValueError(
                "Expected root keys do not match up with input keys.\n"
                f"Expected root keys: {root_keys}\n"
                f"Input keys: {module_input_dict.keys()}\n"
            )

    def _process_component_output(
        self,
        queue: List[str],
        output_dict: Dict[str, Any],
        module_key: str,
        all_module_inputs: Dict[str, Dict[str, Any]],
        result_outputs: Dict[str, Any],
    ) -> List[str]:
        """Propagate one module's output along its outgoing edges.

        Leaf outputs go into ``result_outputs``; otherwise each edge's output
        is (optionally) transformed via ``input_fn``, gated via
        ``condition_fn``, and attached to the destination's inputs. Returns a
        new queue with any skipped destinations removed.
        """
        new_queue = queue.copy()
        # if there are no more edges, record the result
        if module_key in self._get_leaf_keys():
            result_outputs[module_key] = output_dict
        else:
            edge_list = list(self.dag.edges(module_key, data=True))
            for _, dest, attr in edge_list:
                output = get_output(attr.get("src_key"), output_dict)
                # if input_fn is set, use it to transform the output first
                if attr["input_fn"] is not None:
                    dest_output = attr["input_fn"](output)
                else:
                    dest_output = output
                add_edge = True
                if attr["condition_fn"] is not None:
                    conditional_val = attr["condition_fn"](output)
                    if not conditional_val:
                        add_edge = False
                if add_edge:
                    add_output_to_module_inputs(
                        attr.get("dest_key"),
                        dest_output,
                        self.module_dict[dest],
                        all_module_inputs[dest],
                    )
                else:
                    # condition failed: dest will never get its input, so
                    # remove it from the execution queue
                    new_queue.remove(dest)
        return new_queue

    def _run_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Run the pipeline for multiple roots.

        kwargs is in the form of module_dict -> input_dict
        input_dict is in the form of input_key -> input
        """
        self._validate_inputs(module_input_dict)
        queue = list(networkx.topological_sort(self.dag))
        # mapping of module_key -> dict of input_key -> input; each module's
        # input dict is populated as its upstream modules run
        all_module_inputs: Dict[str, Dict[str, Any]] = {
            module_key: {} for module_key in self.module_dict
        }
        result_outputs: Dict[str, Any] = {}
        # seed root inputs
        for module_key, module_input in module_input_dict.items():
            all_module_inputs[module_key] = module_input
        # run modules in topological order, forwarding outputs as we go
        while len(queue) > 0:
            module_key = queue.pop(0)
            module = self.module_dict[module_key]
            module_input = all_module_inputs[module_key]
            if self.verbose:
                print_debug_input(module_key, module_input)
            output_dict = module.run_component(**module_input)
            # propagate outputs; may prune conditionally-skipped modules
            queue = self._process_component_output(
                queue, output_dict, module_key, all_module_inputs, result_outputs
            )
        return result_outputs

    async def _arun_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Run the pipeline for multiple roots (async, with parallelism).

        kwargs is in the form of module_dict -> input_dict
        input_dict is in the form of input_key -> input
        """
        self._validate_inputs(module_input_dict)
        queue = list(networkx.topological_sort(self.dag))
        # mapping of module_key -> dict of input_key -> input; each module's
        # input dict is populated as its upstream modules run
        all_module_inputs: Dict[str, Dict[str, Any]] = {
            module_key: {} for module_key in self.module_dict
        }
        result_outputs: Dict[str, Any] = {}
        # seed root inputs
        for module_key, module_input in module_input_dict.items():
            all_module_inputs[module_key] = module_input
        while len(queue) > 0:
            popped_indices = set()
            popped_nodes = []
            # modules with no ancestors still in the queue have all their
            # inputs ready and can run in parallel
            for i, module_key in enumerate(queue):
                module_ancestors = networkx.ancestors(self.dag, module_key)
                if len(set(module_ancestors).intersection(queue)) == 0:
                    popped_indices.add(i)
                    popped_nodes.append(module_key)
            # remove the popped modules from the queue
            queue = [
                module_key
                for i, module_key in enumerate(queue)
                if i not in popped_indices
            ]
            if self.verbose:
                print_debug_input_multi(
                    popped_nodes,
                    [all_module_inputs[module_key] for module_key in popped_nodes],
                )
            # run the ready modules concurrently
            tasks = []
            for module_key in popped_nodes:
                module = self.module_dict[module_key]
                module_input = all_module_inputs[module_key]
                tasks.append(module.arun_component(**module_input))
            output_dicts = await run_jobs(
                tasks, show_progress=self.show_progress, workers=self.num_workers
            )
            for output_dict, module_key in zip(output_dicts, popped_nodes):
                # propagate outputs; may prune conditionally-skipped modules
                queue = self._process_component_output(
                    queue, output_dict, module_key, all_module_inputs, result_outputs
                )
        return result_outputs

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        Not used: the pipeline validates per-module instead.
        """
        raise NotImplementedError

    def validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs (no-op: pipelines validate per-module)."""
        return input

    def _validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Not used: the pipeline validates per-module instead."""
        raise NotImplementedError

    def validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component outputs.

        NOTE: overridden to do nothing; pipelines validate per-module.
        """
        return output

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run the pipeline as a component of an outer pipeline."""
        return self.run(return_values_direct=False, **kwargs)

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Async version of ``_run_component``."""
        return await self.arun(return_values_direct=False, **kwargs)

    @property
    def input_keys(self) -> InputKeys:
        """Input keys of the pipeline = input keys of its single root module."""
        root_keys = self._get_root_keys()
        if len(root_keys) != 1:
            raise ValueError("Only one root is supported.")
        root_module = self.module_dict[root_keys[0]]
        return root_module.input_keys

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys of the pipeline = output keys of its single leaf module."""
        leaf_keys = self._get_leaf_keys()
        if len(leaf_keys) != 1:
            raise ValueError("Only one leaf is supported.")
        leaf_module = self.module_dict[leaf_keys[0]]
        return leaf_module.output_keys

    @property
    def sub_query_components(self) -> List[QueryComponent]:
        """All modules in the pipeline, as a list."""
        return list(self.module_dict.values())

    @property
    def clean_dag(self) -> networkx.MultiDiGraph:
        """A copy of the DAG with non-serializable (callable) attributes removed."""
        # FIX: return annotation corrected from DiGraph to MultiDiGraph --
        # `self.dag` is a MultiDiGraph and the cleaning helper preserves type.
        return clean_graph_attributes_copy(self.dag)