    is_any_greedy_socket_ready,
    is_socket_lazy_variadic,
)
-from haystack.core.pipeline.utils import FIFOPriorityQueue, _deepcopy_with_exceptions, parse_connect_string
+from haystack.core.pipeline.utils import (
+    FIFOPriorityQueue,
+    _deepcopy_with_exceptions,
+    args_deprecated,
+    parse_connect_string,
+)
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
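The newly imported args_deprecated decorator is what lets show() and draw() keep accepting positional arguments for now while steering callers toward keyword arguments. Its implementation is not part of this diff; the following is only a rough sketch of what such a decorator typically does, not the actual haystack.core.pipeline.utils.args_deprecated:

# Hypothetical sketch of a positional-argument deprecation decorator.
# Not the real args_deprecated; shown only to illustrate the pattern used below.
import functools
import warnings

def args_deprecated(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if args:
            warnings.warn(
                f"Calling {func.__name__} with positional arguments is deprecated; "
                "pass them as keyword arguments instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        return func(self, *args, **kwargs)

    return wrapper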
@@ -669,7 +674,14 @@ def outputs(self, include_components_with_connected_outputs: bool = False) -> Di
        }
        return outputs

-    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30) -> None:
+    @args_deprecated
+    def show(
+        self,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
        """
        Display an image representing this `Pipeline` in a Jupyter notebook.

@@ -698,20 +710,62 @@ def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] =
        :param timeout:
            Timeout in seconds for the request to the Mermaid server.

+        :param super_component_expansion:
+            If set to True and the pipeline contains SuperComponents, the diagram shows the internal structure
+            of super-components as if their components were part of the pipeline, instead of a single "black-box".
+            Otherwise, only the super-component itself is displayed.
+
        :raises PipelineDrawingError:
            If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
        """
+
+        # Call the internal implementation with keyword arguments
+        self._show_internal(
+            server_url=server_url, params=params, timeout=timeout, super_component_expansion=super_component_expansion
+        )
+
+    def _show_internal(
+        self,
+        *,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
+        """
+        Internal implementation of show() that uses keyword-only arguments.
+
+        ToDo: after 2.14.0 release make this the main function and remove the old one.
+        """
        if is_in_jupyter():
            from IPython.display import Image, display  # type: ignore

-            image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
+            if super_component_expansion:
+                graph, super_component_mapping = self._merge_super_component_pipelines()
+            else:
+                graph = self.graph
+                super_component_mapping = None
+
+            image_data = _to_mermaid_image(
+                graph,
+                server_url=server_url,
+                params=params,
+                timeout=timeout,
+                super_component_mapping=super_component_mapping,
+            )
            display(Image(image_data))
        else:
            msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
            raise PipelineDrawingError(msg)

-    def draw(
-        self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30
+    @args_deprecated
+    def draw(  # pylint: disable=too-many-positional-arguments
+        self,
+        path: Path,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
    ) -> None:
        """
        Save an image representing this `Pipeline` to the specified file path.
@@ -720,10 +774,12 @@ def draw(

        :param path:
            The file path where the generated image will be saved.
+
        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.
+
        :param params:
            Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
            Supported keys:
@@ -741,12 +797,53 @@ def draw(
        :param timeout:
            Timeout in seconds for the request to the Mermaid server.

+        :param super_component_expansion:
+            If set to True and the pipeline contains SuperComponents, the diagram shows the internal structure
+            of super-components as if their components were part of the pipeline, instead of a single "black-box".
+            Otherwise, only the super-component itself is displayed.
+
        :raises PipelineDrawingError:
            If there is an issue with rendering or saving the image.
        """
+
+        # Call the internal implementation with keyword arguments
+        self._draw_internal(
+            path=path,
+            server_url=server_url,
+            params=params,
+            timeout=timeout,
+            super_component_expansion=super_component_expansion,
+        )
+
+    def _draw_internal(
+        self,
+        *,
+        path: Path,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
+        """
+        Internal implementation of draw() that uses keyword-only arguments.
+
+        ToDo: after 2.14.0 release make this the main function and remove the old one.
+        """
        # Before drawing we edit a bit the graph, to avoid modifying the original that is
        # used for running the pipeline we copy it.
-        image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
+        if super_component_expansion:
+            graph, super_component_mapping = self._merge_super_component_pipelines()
+        else:
+            graph = self.graph
+            super_component_mapping = None
+
+        image_data = _to_mermaid_image(
+            graph,
+            server_url=server_url,
+            params=params,
+            timeout=timeout,
+            super_component_mapping=super_component_mapping,
+        )
        Path(path).write_bytes(image_data)

    def walk(self) -> Iterator[Tuple[str, Component]]:
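With draw() and show() both delegating to keyword-only internals, the new super_component_expansion flag is meant to be passed by keyword. A short usage sketch, assuming a pipeline built with the released haystack-ai package; the PromptBuilder component, file name, and template are illustrative, and rendering contacts the Mermaid server at server_url:

from pathlib import Path

from haystack import Pipeline
from haystack.components.builders import PromptBuilder

pipeline = Pipeline()
pipeline.add_component("builder", PromptBuilder(template="Answer the question: {{ query }}"))

# Keyword-only style that @args_deprecated steers callers toward.
pipeline.draw(path=Path("pipeline.png"), super_component_expansion=True, timeout=60)

# Inside a Jupyter notebook, the same flag expands any SuperComponents inline:
# pipeline.show(super_component_expansion=True)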
@@ -1175,7 +1272,7 @@ def _write_component_outputs(
        for receiver_name, sender_socket, receiver_socket in receivers:
            # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
            # that the sender did not produce an output for this socket.
-            # This allows us to track if a pre-decessor already ran but did not produce an output.
+            # This allows us to track if a predecessor already ran but did not produce an output.
            value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)

            if receiver_name not in inputs:
@@ -1239,6 +1336,99 @@ def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
            if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
                raise PipelineComponentsBlockedError()

+    def _find_super_components(self) -> list[tuple[str, Component]]:
+        """
+        Find all SuperComponents in the pipeline.
+
+        :returns:
+            A list of (component_name, component_instance) tuples, one for each SuperComponent in the pipeline.
+        """
+
+        super_components = []
+        for comp_name, comp in self.walk():
+            # a SuperComponent has a "pipeline" attribute which is itself a Pipeline instance;
+            # we don't test against SuperComponent because doing so always leads to circular imports
+            if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
+                super_components.append((comp_name, comp))
+        return super_components
+
+    def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
+        """
+        Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
+
+        This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
+        and all the internal SuperComponents' pipelines. The SuperComponents are removed and their internal
+        components are connected to the corresponding input and output sockets of the main pipeline.
+
+        :returns:
+            A tuple containing:
+            - A networkx.MultiDiGraph with the expanded structure of the main pipeline and all its SuperComponents
+            - A dictionary mapping each internal component name to the name of the SuperComponent it came from
+        """
+        merged_graph = self.graph.copy()
+        super_component_mapping: Dict[str, str] = {}
+
+        for super_name, super_component in self._find_super_components():
+            internal_pipeline = super_component.pipeline  # type: ignore
+            internal_graph = internal_pipeline.graph.copy()
+
+            # Mark all components in the internal pipeline as being part of a SuperComponent
+            for node in internal_graph.nodes():
+                super_component_mapping[node] = super_name
+
+            # edges connected to the super component
+            incoming_edges = list(merged_graph.in_edges(super_name, data=True))
+            outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
+
+            # merge the SuperComponent graph into the main graph and remove the super component node
+            # since its components are now part of the main graph
+            merged_graph = networkx.compose(merged_graph, internal_graph)
+            merged_graph.remove_node(super_name)
+
+            # get the entry and exit points of the SuperComponent internal pipeline
+            entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
+            exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
+
+            # connect the incoming edges to entry points
+            for sender, _, edge_data in incoming_edges:
+                sender_socket = edge_data["from_socket"]
+                for entry_point in entry_points:
+                    # find a matching input socket in the entry point
+                    entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
+                    for socket_name, socket in entry_point_sockets.items():
+                        if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
+                            merged_graph.add_edge(
+                                sender,
+                                entry_point,
+                                key=f"{sender_socket.name}/{socket_name}",
+                                conn_type=_type_name(sender_socket.type),
+                                from_socket=sender_socket,
+                                to_socket=socket,
+                                mandatory=socket.is_mandatory,
+                            )
+
+            # connect outgoing edges from exit points
+            for _, receiver, edge_data in outgoing_edges:
+                receiver_socket = edge_data["to_socket"]
+                for exit_point in exit_points:
+                    # find a matching output socket in the exit point
+                    exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
+                    for socket_name, socket in exit_point_sockets.items():
+                        if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
+                            merged_graph.add_edge(
+                                exit_point,
+                                receiver,
+                                key=f"{socket_name}/{receiver_socket.name}",
+                                conn_type=_type_name(socket.type),
+                                from_socket=socket,
+                                to_socket=receiver_socket,
+                                mandatory=receiver_socket.is_mandatory,
+                            )
+
+        return merged_graph, super_component_mapping
+

def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
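The merge itself is ordinary networkx graph surgery: compose() unions the two graphs, remove_node() drops the black-box SuperComponent node, and fresh edges re-attach its former neighbours to the internal entry points (in_degree == 0) and exit points (out_degree == 0). A self-contained toy sketch of that sequence, with made-up node names and none of the Haystack socket metadata:

import networkx

# Main pipeline: retriever -> super -> writer, where "super" is the black box.
main = networkx.MultiDiGraph()
main.add_edge("retriever", "super", key="docs/docs")
main.add_edge("super", "writer", key="text/text")

# Internal pipeline hidden inside "super".
internal = networkx.MultiDiGraph()
internal.add_edge("cleaner", "summarizer", key="docs/docs")

merged = networkx.compose(main, internal)   # union of both graphs
merged.remove_node("super")                 # drop the black-box node and its edges

entry = next(n for n in internal.nodes() if internal.in_degree(n) == 0)    # "cleaner"
exit_ = next(n for n in internal.nodes() if internal.out_degree(n) == 0)   # "summarizer"
merged.add_edge("retriever", entry, key="docs/docs")   # re-wire the incoming edge
merged.add_edge(exit_, "writer", key="text/text")      # re-wire the outgoing edge

print(sorted(merged.edges(keys=True)))
# [('cleaner', 'summarizer', 'docs/docs'), ('retriever', 'cleaner', 'docs/docs'), ('summarizer', 'writer', 'text/text')]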