Skip to content

Commit 4a5e4d3

Browse files
feat: return common type in SuperComponent type compatibility check (#9275)
* feat: return common type in SuperComponent type compatibility check
* fix test_utils
* address review comments
* update tests
* use typing module types
* refactor
* refactor
* unenforce type check
* refactor

---------

Co-authored-by: Michele Pangrazzi <xmikex83@gmail.com>
1 parent 167229f commit 4a5e4d3

File tree

5 files changed

+290
-60
lines changed

5 files changed

+290
-60
lines changed

haystack/core/super_component/super_component.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,16 +211,20 @@ def _resolve_input_types_from_mapping(
211211
aggregated_inputs[wrapper_input_name]["default"] = _delegate_default
212212
continue
213213

214-
if not _is_compatible(existing_socket_info["type"], socket_info["type"]):
214+
is_compatible, common_type = _is_compatible(existing_socket_info["type"], socket_info["type"])
215+
216+
if not is_compatible:
215217
raise InvalidMappingTypeError(
216218
f"Type conflict for input '{socket_name}' from component '{comp_name}'. "
217219
f"Existing type: {existing_socket_info['type']}, new type: {socket_info['type']}."
218220
)
219221

222+
# Use the common type for the aggregated input
223+
aggregated_inputs[wrapper_input_name]["type"] = common_type
224+
220225
# If any socket requires mandatory inputs then the aggregated input is also considered mandatory.
221226
# The aggregated type is already the common type set above, so we only remove the default value if it exists.
222227
if socket_info["is_mandatory"]:
223-
aggregated_inputs[wrapper_input_name]["type"] = socket_info["type"]
224228
aggregated_inputs[wrapper_input_name].pop("default", None)
225229

226230
return aggregated_inputs

haystack/core/super_component/utils.py

Lines changed: 109 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Annotated, Any, TypeVar, Union, get_args, get_origin
5+
from typing import Annotated, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast, get_args, get_origin
66

77
from haystack.core.component.types import HAYSTACK_GREEDY_VARIADIC_ANNOTATION, HAYSTACK_VARIADIC_ANNOTATION
88

@@ -14,33 +14,39 @@ class _delegate_default:
1414
T = TypeVar("T")
1515

1616

17-
def _is_compatible(type1: T, type2: T, unwrap_nested: bool = True) -> bool:
17+
def _is_compatible(type1: T, type2: T, unwrap_nested: bool = True) -> Tuple[bool, Optional[T]]:
    """
    Decide whether two type annotations are compatible with each other (symmetric check).

    Both annotations are unwrapped with `_unwrap_all` and the unwrapped pair is handed to
    `_types_are_compatible`, whose result is returned unchanged.

    :param type1: First type to compare
    :param type2: Second type to compare
    :param unwrap_nested: If True, recursively unwraps nested Optional and Variadic types.
        If False, only unwraps at the top level.
    :return: Tuple of (True if types are compatible, common type if compatible)
    """
    # Unwrap in argument order (type1 first, then type2) before comparing.
    unwrapped_pair = [_unwrap_all(candidate, recursive=unwrap_nested) for candidate in (type1, type2)]
    return _types_are_compatible(*unwrapped_pair)
3131

3232

33-
def _types_are_compatible(type1: T, type2: T) -> bool:
33+
def _types_are_compatible(type1: T, type2: T) -> Tuple[bool, Optional[T]]:
3434
"""
3535
Core type compatibility check implementing symmetric matching.
3636
3737
:param type1: First unwrapped type to compare
3838
:param type2: Second unwrapped type to compare
3939
:return: Tuple of (True if types are compatible, common type if compatible); (False, None) otherwise
4040
"""
41-
# Handle Any type and direct equality
42-
if type1 is Any or type2 is Any or type1 == type2:
43-
return True
41+
# Handle Any type
42+
if type1 is Any:
43+
return True, _convert_to_typing_type(type2)
44+
if type2 is Any:
45+
return True, _convert_to_typing_type(type1)
46+
47+
# Direct equality
48+
if type1 == type2:
49+
return True, _convert_to_typing_type(type1)
4450

4551
type1_origin = get_origin(type1)
4652
type2_origin = get_origin(type2)
@@ -53,34 +59,84 @@ def _types_are_compatible(type1: T, type2: T) -> bool:
5359
return _check_non_union_compatibility(type1, type2, type1_origin, type2_origin)
5460

5561

56-
def _check_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> bool:
62+
def _check_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> Tuple[bool, Optional[T]]:
    """
    Handle all Union type compatibility cases.

    Every member of `type1` is compared against every member of `type2` with
    `_types_are_compatible`, and the common types of all compatible pairs are collected.
    This function is meant to be called when at least one of the two origins is `Union`.

    :param type1: First type to compare
    :param type2: Second type to compare
    :param type1_origin: Origin of `type1`; its members are expanded only when this is `Union`
    :param type2_origin: Origin of `type2`; its members are expanded only when this is `Union`
    :return: `(True, common)` when at least one pair is compatible, where `common` is the single
        collected type or a `Union` of all collected types; `(False, None)` otherwise
    """
    # A non-Union side is treated as a one-member sequence so that the three cases
    # (Union/plain, plain/Union, Union/Union) share one pairwise comparison instead
    # of three copies of the same loop.
    type1_members = get_args(type1) if type1_origin is Union else (type1,)
    type2_members = get_args(type2) if type2_origin is Union else (type2,)

    compatible_types = []
    for member1 in type1_members:
        for member2 in type2_members:
            is_compat, common = _types_are_compatible(member1, member2)
            # NOTE(review): a compatible pair whose common type is `None` is skipped here,
            # so it cannot be told apart from "no common type" - confirm this is intended.
            if is_compat and common is not None:
                compatible_types.append(common)

    if not compatible_types:
        return False, None

    # The constructed Union or single type must be cast to Optional[T]
    # to satisfy mypy, as T is specific to this function's call context.
    result_type = Union[tuple(compatible_types)] if len(compatible_types) > 1 else compatible_types[0]
    return True, cast(Optional[T], result_type)
106+
107+
108+
def _check_non_union_compatibility(
    type1: T, type2: T, type1_origin: Any, type2_origin: Any
) -> Tuple[bool, Optional[T]]:
    """
    Handle non-Union type compatibility cases.

    :param type1: First type to compare
    :param type2: Second type to compare
    :param type1_origin: Origin of `type1` (falsy when it has none)
    :param type2_origin: Origin of `type2` (falsy when it has none)
    :return: `(True, common type)` when the types are compatible, `(False, None)` otherwise
    """
    # Neither side has an origin: only plain equality can make them compatible.
    if not type1_origin and not type2_origin:
        return (True, type1) if type1 == type2 else (False, None)

    # From here on both sides need an origin, and the origins have to match.
    if not type1_origin or not type2_origin or type1_origin != type2_origin:
        return False, None

    args_of_first, args_of_second = get_args(type1), get_args(type2)
    if len(args_of_first) != len(args_of_second):
        return False, None

    # Compare the generic arguments position by position and stop at the first mismatch.
    merged_args = []
    for left_arg, right_arg in zip(args_of_first, args_of_second):
        compatible, merged = _types_are_compatible(left_arg, right_arg)
        if not compatible:
            return False, None
        merged_args.append(merged)

    # Rebuild the generic by subscripting the (converted) origin with the common arguments.
    # NOTE(review): the origin passed here is an unsubscripted class such as `list`;
    # confirm `_convert_to_typing_type` maps it to the intended `typing` alias.
    rebuilt = _convert_to_typing_type(type1_origin)[tuple(merged_args)]
    return True, cast(Optional[T], rebuilt)
84140

85141

86142
def _unwrap_all(t: T, recursive: bool) -> T:
@@ -167,3 +223,37 @@ def _unwrap_optionals(t: T, recursive: bool) -> T:
167223
if recursive:
168224
return _unwrap_all(result, recursive) # type: ignore
169225
return result # type: ignore
226+
227+
228+
def _convert_to_typing_type(t: Any) -> Any:
    """
    Convert generic aliases of built-in containers to their typing equivalents.

    Only types whose origin is `list`, `dict`, `set` or `tuple` are rewritten; their arguments
    are converted recursively. Anything else is returned unchanged.

    :param t: Type to convert
    :return: The type using typing module types
    """
    typing_aliases = {list: List, dict: Dict, set: Set, tuple: Tuple}

    origin = get_origin(t)
    if origin not in typing_aliases:
        return t

    args = get_args(t)
    if not args:
        # Unparameterized alias: hand back the bare typing generic.
        return typing_aliases[origin]

    if origin is list:
        return List[_convert_to_typing_type(args[0])]  # type: ignore
    if origin is dict:
        return Dict[_convert_to_typing_type(args[0]), _convert_to_typing_type(args[1])]  # type: ignore
    if origin is set:
        return Set[_convert_to_typing_type(args[0])]  # type: ignore
    # Only tuple is left: every argument is converted, not just the first one.
    return Tuple[tuple(_convert_to_typing_type(arg) for arg in args)]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Enhance SuperComponent's type compatibility check to return the detected common type between two input types.

test/core/super_component/test_super_component.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from typing import List
4+
from typing import Any, List, Optional, Union
55

66
import pytest
77
from haystack import Document, SuperComponent, Pipeline, AsyncPipeline, component, super_component
@@ -366,3 +366,70 @@ def test_draw_with_default_parameters(self, mock_draw, sample_super_component, t
366366

367367
sample_super_component.draw(path=path)
368368
mock_draw.assert_called_once_with(path=path, server_url="https://mermaid.ink", params=None, timeout=30)
369+
370+
def test_input_types_reconciliation(self):
    """Check that an `int` socket and an `Any` socket mapped to one wrapper input reconcile to `int`."""

    @component
    class TypeTestComponent:
        @component.output_types(result_int=int, result_any=Any)
        def run(self, input_int: int, input_any: Any):
            return {"result_int": input_int, "result_any": input_any}

    pipe = Pipeline()
    pipe.add_component("test1", TypeTestComponent())
    pipe.add_component("test2", TypeTestComponent())

    # "number" feeds both an int-typed socket and an Any-typed socket.
    super_comp = SuperComponent(
        pipeline=pipe,
        input_mapping={"number": ["test1.input_int", "test2.input_any"]},
        output_mapping={"test2.result_int": "result_int"},
    )

    sockets = super_comp.__haystack_input__._sockets_dict
    assert "number" in sockets
    assert sockets["number"].type == int
390+
391+
def test_union_type_reconciliation(self):
    """Check that two overlapping Union-typed inputs reconcile to their shared member type."""

    @component
    class UnionTypeComponent1:
        @component.output_types(result=Union[int, str])
        def run(self, input: Union[int, str]):
            return {"result": input}

    @component
    class UnionTypeComponent2:
        @component.output_types(result=Union[float, str])
        def run(self, input: Union[float, str]):
            return {"result": input}

    pipe = Pipeline()
    pipe.add_component("test1", UnionTypeComponent1())
    pipe.add_component("test2", UnionTypeComponent2())

    super_comp = SuperComponent(
        pipeline=pipe,
        input_mapping={"data": ["test1.input", "test2.input"]},
        output_mapping={"test2.result": "result"},
    )

    sockets = super_comp.__haystack_input__._sockets_dict
    assert "data" in sockets
    # `str` is the only member the two Unions share; a single-member `Union[str]` is just `str`.
    assert sockets["data"].type == str
417+
418+
def test_input_types_with_any(self):
    """Check that a `str` socket and an `Any` socket of the same component reconcile to `str`."""

    @component
    class AnyTypeComponent:
        @component.output_types(result=str)
        def run(self, specific: str, generic: Any):
            return {"result": specific}

    pipe = Pipeline()
    pipe.add_component("test", AnyTypeComponent())

    # Both sockets of the single component are driven by the same wrapper input.
    super_comp = SuperComponent(pipeline=pipe, input_mapping={"text": ["test.specific", "test.generic"]})

    sockets = super_comp.__haystack_input__._sockets_dict
    assert "text" in sockets
    assert sockets["text"].type == str

0 commit comments

Comments
 (0)