Skip to content

Commit 3342f17

Browse files
davidsbatistasjrl
andauthored
feat: draw/show SuperComponents in detail, expand it and show it's internal components in the visualisation diagram (#9389)
* initial import * small fixes * adding tests * adding tests * refactoring merge graphs * updating tests * docstrings * adding release notes * adding SuperComponent name to extended components * adding colours and legend to different super components * adding missed docstring parameter * fixing tests and type checking * Update haystack/core/pipeline/base.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * forcing keyword arguments for draw() and show() * adding wrapper function and a deprecation warning * adding pylint disable - this will be removed soon * wip * adding a decorator function to test if another function is being called with positional arguments * adding a decorator function to test if another function is being called with positional arguments --------- Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
1 parent ba41696 commit 3342f17

File tree

7 files changed

+575
-17
lines changed

7 files changed

+575
-17
lines changed

haystack/core/pipeline/base.py

Lines changed: 197 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
is_any_greedy_socket_ready,
3333
is_socket_lazy_variadic,
3434
)
35-
from haystack.core.pipeline.utils import FIFOPriorityQueue, _deepcopy_with_exceptions, parse_connect_string
35+
from haystack.core.pipeline.utils import (
36+
FIFOPriorityQueue,
37+
_deepcopy_with_exceptions,
38+
args_deprecated,
39+
parse_connect_string,
40+
)
3641
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
3742
from haystack.core.type_utils import _type_name, _types_are_compatible
3843
from haystack.marshal import Marshaller, YamlMarshaller
@@ -669,7 +674,14 @@ def outputs(self, include_components_with_connected_outputs: bool = False) -> Di
669674
}
670675
return outputs
671676

672-
def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30) -> None:
677+
@args_deprecated
678+
def show(
679+
self,
680+
server_url: str = "https://mermaid.ink",
681+
params: Optional[dict] = None,
682+
timeout: int = 30,
683+
super_component_expansion: bool = False,
684+
) -> None:
673685
"""
674686
Display an image representing this `Pipeline` in a Jupyter notebook.
675687
@@ -698,20 +710,62 @@ def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] =
698710
:param timeout:
699711
Timeout in seconds for the request to the Mermaid server.
700712
713+
:param super_component_expansion:
714+
If set to True and the pipeline contains SuperComponents the diagram will show the internal structure of
715+
super-components as if they were components part of the pipeline instead of a "black-box".
716+
Otherwise, only the super-component itself will be displayed.
717+
701718
:raises PipelineDrawingError:
702719
If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
703720
"""
721+
722+
# Call the internal implementation with keyword arguments
723+
self._show_internal(
724+
server_url=server_url, params=params, timeout=timeout, super_component_expansion=super_component_expansion
725+
)
726+
727+
def _show_internal(
728+
self,
729+
*,
730+
server_url: str = "https://mermaid.ink",
731+
params: Optional[dict] = None,
732+
timeout: int = 30,
733+
super_component_expansion: bool = False,
734+
) -> None:
735+
"""
736+
Internal implementation of show() that uses keyword-only arguments.
737+
738+
ToDo: after 2.14.0 release make this the main function and remove the old one.
739+
"""
704740
if is_in_jupyter():
705741
from IPython.display import Image, display # type: ignore
706742

707-
image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
743+
if super_component_expansion:
744+
graph, super_component_mapping = self._merge_super_component_pipelines()
745+
else:
746+
graph = self.graph
747+
super_component_mapping = None
748+
749+
image_data = _to_mermaid_image(
750+
graph,
751+
server_url=server_url,
752+
params=params,
753+
timeout=timeout,
754+
super_component_mapping=super_component_mapping,
755+
)
708756
display(Image(image_data))
709757
else:
710758
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
711759
raise PipelineDrawingError(msg)
712760

713-
def draw(
714-
self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30
761+
@args_deprecated
762+
def draw( # pylint: disable=too-many-positional-arguments
763+
self,
764+
path: Path,
765+
server_url: str = "https://mermaid.ink",
766+
params: Optional[dict] = None,
767+
timeout: int = 30,
768+
super_component_expansion: bool = False,
715769
) -> None:
716770
"""
717771
Save an image representing this `Pipeline` to the specified file path.
@@ -720,10 +774,12 @@ def draw(
720774
721775
:param path:
722776
The file path where the generated image will be saved.
777+
723778
:param server_url:
724779
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
725780
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
726781
info on how to set up your own Mermaid server.
782+
727783
:param params:
728784
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
729785
Supported keys:
@@ -741,12 +797,53 @@ def draw(
741797
:param timeout:
742798
Timeout in seconds for the request to the Mermaid server.
743799
800+
:param super_component_expansion:
801+
If set to True and the pipeline contains SuperComponents the diagram will show the internal structure of
802+
super-components as if they were components part of the pipeline instead of a "black-box".
803+
Otherwise, only the super-component itself will be displayed.
804+
744805
:raises PipelineDrawingError:
745806
If there is an issue with rendering or saving the image.
746807
"""
808+
809+
# Call the internal implementation with keyword arguments
810+
self._draw_internal(
811+
path=path,
812+
server_url=server_url,
813+
params=params,
814+
timeout=timeout,
815+
super_component_expansion=super_component_expansion,
816+
)
817+
818+
def _draw_internal(
819+
self,
820+
*,
821+
path: Path,
822+
server_url: str = "https://mermaid.ink",
823+
params: Optional[dict] = None,
824+
timeout: int = 30,
825+
super_component_expansion: bool = False,
826+
) -> None:
827+
"""
828+
Internal implementation of draw() that uses keyword-only arguments.
829+
830+
ToDo: after 2.14.0 release make this the main function and remove the old one.
831+
"""
747832
# Before drawing we edit a bit the graph, to avoid modifying the original that is
748833
# used for running the pipeline we copy it.
749-
image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
834+
if super_component_expansion:
835+
graph, super_component_mapping = self._merge_super_component_pipelines()
836+
else:
837+
graph = self.graph
838+
super_component_mapping = None
839+
840+
image_data = _to_mermaid_image(
841+
graph,
842+
server_url=server_url,
843+
params=params,
844+
timeout=timeout,
845+
super_component_mapping=super_component_mapping,
846+
)
750847
Path(path).write_bytes(image_data)
751848

752849
def walk(self) -> Iterator[Tuple[str, Component]]:
@@ -1175,7 +1272,7 @@ def _write_component_outputs(
11751272
for receiver_name, sender_socket, receiver_socket in receivers:
11761273
# We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
11771274
# that the sender did not produce an output for this socket.
1178-
# This allows us to track if a pre-decessor already ran but did not produce an output.
1275+
# This allows us to track if a predecessor already ran but did not produce an output.
11791276
value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
11801277

11811278
if receiver_name not in inputs:
@@ -1239,6 +1336,99 @@ def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
12391336
if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
12401337
raise PipelineComponentsBlockedError()
12411338

1339+
def _find_super_components(self) -> list[tuple[str, Component]]:
1340+
"""
1341+
Find all SuperComponents in the pipeline.
1342+
1343+
:returns:
1344+
List of tuples containing (component_name, component_instance) representing a SuperComponent.
1345+
"""
1346+
1347+
super_components = []
1348+
for comp_name, comp in self.walk():
1349+
# a SuperComponent has a "pipeline" attribute which itself a Pipeline instance
1350+
# we don't test against SuperComponent because doing so always lead to circular imports
1351+
if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
1352+
super_components.append((comp_name, comp))
1353+
return super_components
1354+
1355+
def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
1356+
"""
1357+
Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
1358+
1359+
This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
1360+
and all the internal SuperComponents' pipelines. The SuperComponents are removed and their internal
1361+
components are connected to corresponding input and output sockets of the main pipeline.
1362+
1363+
:returns:
1364+
A tuple containing:
1365+
- A networkx.MultiDiGraph with the expanded structure of the main pipeline and all it's SuperComponents
1366+
- A dictionary mapping component names to boolean indicating that this component was part of a
1367+
SuperComponent
1368+
- A dictionary mapping component names to their SuperComponent name
1369+
"""
1370+
merged_graph = self.graph.copy()
1371+
super_component_mapping: Dict[str, str] = {}
1372+
1373+
for super_name, super_component in self._find_super_components():
1374+
internal_pipeline = super_component.pipeline # type: ignore
1375+
internal_graph = internal_pipeline.graph.copy()
1376+
1377+
# Mark all components in the internal pipeline as being part of a SuperComponent
1378+
for node in internal_graph.nodes():
1379+
super_component_mapping[node] = super_name
1380+
1381+
# edges connected to the super component
1382+
incoming_edges = list(merged_graph.in_edges(super_name, data=True))
1383+
outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
1384+
1385+
# merge the SuperComponent graph into the main graph and remove the super component node
1386+
# since its components are now part of the main graph
1387+
merged_graph = networkx.compose(merged_graph, internal_graph)
1388+
merged_graph.remove_node(super_name)
1389+
1390+
# get the entry and exit points of the SuperComponent internal pipeline
1391+
entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
1392+
exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
1393+
1394+
# connect the incoming edges to entry points
1395+
for sender, _, edge_data in incoming_edges:
1396+
sender_socket = edge_data["from_socket"]
1397+
for entry_point in entry_points:
1398+
# find a matching input socket in the entry point
1399+
entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
1400+
for socket_name, socket in entry_point_sockets.items():
1401+
if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
1402+
merged_graph.add_edge(
1403+
sender,
1404+
entry_point,
1405+
key=f"{sender_socket.name}/{socket_name}",
1406+
conn_type=_type_name(sender_socket.type),
1407+
from_socket=sender_socket,
1408+
to_socket=socket,
1409+
mandatory=socket.is_mandatory,
1410+
)
1411+
1412+
# connect outgoing edges from exit points
1413+
for _, receiver, edge_data in outgoing_edges:
1414+
receiver_socket = edge_data["to_socket"]
1415+
for exit_point in exit_points:
1416+
# find a matching output socket in the exit point
1417+
exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
1418+
for socket_name, socket in exit_point_sockets.items():
1419+
if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
1420+
merged_graph.add_edge(
1421+
exit_point,
1422+
receiver,
1423+
key=f"{socket_name}/{receiver_socket.name}",
1424+
conn_type=_type_name(socket.type),
1425+
from_socket=socket,
1426+
to_socket=receiver_socket,
1427+
mandatory=receiver_socket.is_mandatory,
1428+
)
1429+
1430+
return merged_graph, super_component_mapping
1431+
12421432

12431433
def _connections_status(
12441434
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]

0 commit comments

Comments
 (0)