2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
- from typing import Annotated , Any , TypeVar , Union , get_args , get_origin
5
+ from typing import Annotated , Any , Dict , List , Optional , Set , Tuple , TypeVar , Union , cast , get_args , get_origin
6
6
7
7
from haystack .core .component .types import HAYSTACK_GREEDY_VARIADIC_ANNOTATION , HAYSTACK_VARIADIC_ANNOTATION
8
8
@@ -14,33 +14,39 @@ class _delegate_default:
14
14
T = TypeVar ("T" )
15
15
16
16
17
- def _is_compatible (type1 : T , type2 : T , unwrap_nested : bool = True ) -> bool :
17
+ def _is_compatible (type1 : T , type2 : T , unwrap_nested : bool = True ) -> Tuple [ bool , Optional [ T ]] :
18
18
"""
19
19
Check if two types are compatible (bidirectional/symmetric check).
20
20
21
21
:param type1: First type to compare
22
22
:param type2: Second type to compare
23
23
:param unwrap_nested: If True, recursively unwraps nested Optional and Variadic types.
24
24
If False, only unwraps at the top level.
25
- :return: True if types are compatible, False otherwise
25
+ :return: Tuple of ( True if types are compatible, common type if compatible)
26
26
"""
27
27
type1_unwrapped = _unwrap_all (type1 , recursive = unwrap_nested )
28
28
type2_unwrapped = _unwrap_all (type2 , recursive = unwrap_nested )
29
29
30
30
return _types_are_compatible (type1_unwrapped , type2_unwrapped )
31
31
32
32
33
- def _types_are_compatible (type1 : T , type2 : T ) -> bool :
33
+ def _types_are_compatible (type1 : T , type2 : T ) -> Tuple [ bool , Optional [ T ]] :
34
34
"""
35
35
Core type compatibility check implementing symmetric matching.
36
36
37
37
:param type1: First unwrapped type to compare
38
38
:param type2: Second unwrapped type to compare
39
39
:return: True if types are compatible, False otherwise
40
40
"""
41
- # Handle Any type and direct equality
42
- if type1 is Any or type2 is Any or type1 == type2 :
43
- return True
41
+ # Handle Any type
42
+ if type1 is Any :
43
+ return True , _convert_to_typing_type (type2 )
44
+ if type2 is Any :
45
+ return True , _convert_to_typing_type (type1 )
46
+
47
+ # Direct equality
48
+ if type1 == type2 :
49
+ return True , _convert_to_typing_type (type1 )
44
50
45
51
type1_origin = get_origin (type1 )
46
52
type2_origin = get_origin (type2 )
@@ -53,34 +59,84 @@ def _types_are_compatible(type1: T, type2: T) -> bool:
53
59
return _check_non_union_compatibility (type1 , type2 , type1_origin , type2_origin )
54
60
55
61
56
- def _check_union_compatibility (type1 : T , type2 : T , type1_origin : Any , type2_origin : Any ) -> bool :
62
+ def _check_union_compatibility (type1 : T , type2 : T , type1_origin : Any , type2_origin : Any ) -> Tuple [ bool , Optional [ T ]] :
57
63
"""Handle all Union type compatibility cases."""
58
64
if type1_origin is Union and type2_origin is not Union :
59
- return any (_types_are_compatible (union_arg , type2 ) for union_arg in get_args (type1 ))
60
- if type2_origin is Union and type1_origin is not Union :
61
- return any (_types_are_compatible (type1 , union_arg ) for union_arg in get_args (type2 ))
62
- # Both are Union types. Check all type combinations are compatible.
63
- return any (any (_types_are_compatible (arg1 , arg2 ) for arg2 in get_args (type2 )) for arg1 in get_args (type1 ))
64
-
65
+ # Find all compatible types from the union
66
+ compatible_types = []
67
+ for union_arg in get_args (type1 ):
68
+ is_compat , common = _types_are_compatible (union_arg , type2 )
69
+ if is_compat and common is not None :
70
+ compatible_types .append (common )
71
+ if compatible_types :
72
+ # The constructed Union or single type must be cast to Optional[T]
73
+ # to satisfy mypy, as T is specific to this function's call context.
74
+ result_type = Union [tuple (compatible_types )] if len (compatible_types ) > 1 else compatible_types [0 ]
75
+ return True , cast (Optional [T ], result_type )
76
+ return False , None
65
77
66
- def _check_non_union_compatibility (type1 : T , type2 : T , type1_origin : Any , type2_origin : Any ) -> bool :
78
+ if type2_origin is Union and type1_origin is not Union :
79
+ # Find all compatible types from the union
80
+ compatible_types = []
81
+ for union_arg in get_args (type2 ):
82
+ is_compat , common = _types_are_compatible (type1 , union_arg )
83
+ if is_compat and common is not None :
84
+ compatible_types .append (common )
85
+ if compatible_types :
86
+ # The constructed Union or single type must be cast to Optional[T]
87
+ # to satisfy mypy, as T is specific to this function's call context.
88
+ result_type = Union [tuple (compatible_types )] if len (compatible_types ) > 1 else compatible_types [0 ]
89
+ return True , cast (Optional [T ], result_type )
90
+ return False , None
91
+
92
+ # Both are Union types
93
+ compatible_types = []
94
+ for arg1 in get_args (type1 ):
95
+ for arg2 in get_args (type2 ):
96
+ is_compat , common = _types_are_compatible (arg1 , arg2 )
97
+ if is_compat and common is not None :
98
+ compatible_types .append (common )
99
+
100
+ if compatible_types :
101
+ # The constructed Union or single type must be cast to Optional[T]
102
+ # to satisfy mypy, as T is specific to this function's call context.
103
+ result_type = Union [tuple (compatible_types )] if len (compatible_types ) > 1 else compatible_types [0 ]
104
+ return True , cast (Optional [T ], result_type )
105
+ return False , None
106
+
107
+
108
+ def _check_non_union_compatibility (
109
+ type1 : T , type2 : T , type1_origin : Any , type2_origin : Any
110
+ ) -> Tuple [bool , Optional [T ]]:
67
111
"""Handle non-Union type compatibility cases."""
68
112
# If no origin, compare types directly
69
113
if not type1_origin and not type2_origin :
70
- return type1 == type2
114
+ if type1 == type2 :
115
+ return True , type1
116
+ return False , None
71
117
72
118
# Both must have origins and they must be equal
73
119
if not (type1_origin and type2_origin and type1_origin == type2_origin ):
74
- return False
120
+ return False , None
75
121
76
122
# Compare generic type arguments
77
123
type1_args = get_args (type1 )
78
124
type2_args = get_args (type2 )
79
125
80
126
if len (type1_args ) != len (type2_args ):
81
- return False
127
+ return False , None
128
+
129
+ # Check if all arguments are compatible
130
+ common_args = []
131
+ for t1_arg , t2_arg in zip (type1_args , type2_args ):
132
+ is_compat , common = _types_are_compatible (t1_arg , t2_arg )
133
+ if not is_compat :
134
+ return False , None
135
+ common_args .append (common )
82
136
83
- return all (_types_are_compatible (t1_arg , t2_arg ) for t1_arg , t2_arg in zip (type1_args , type2_args ))
137
+ # Reconstruct the type with common arguments
138
+ typing_type = _convert_to_typing_type (type1_origin )
139
+ return True , cast (Optional [T ], typing_type [tuple (common_args )])
84
140
85
141
86
142
def _unwrap_all (t : T , recursive : bool ) -> T :
@@ -167,3 +223,37 @@ def _unwrap_optionals(t: T, recursive: bool) -> T:
167
223
if recursive :
168
224
return _unwrap_all (result , recursive ) # type: ignore
169
225
return result # type: ignore
226
+
227
+
228
+ def _convert_to_typing_type (t : Any ) -> Any :
229
+ """
230
+ Convert built-in Python types to their typing equivalents.
231
+
232
+ :param t: Type to convert
233
+ :return: The type using typing module types
234
+ """
235
+ origin = get_origin (t )
236
+ args = get_args (t )
237
+
238
+ # Mapping of built-in types to their typing equivalents
239
+ type_converters = {
240
+ list : lambda : List if not args else List [Any ],
241
+ dict : lambda : Dict if not args else Dict [Any , Any ],
242
+ set : lambda : Set if not args else Set [Any ],
243
+ tuple : lambda : Tuple if not args else Tuple [Any , ...],
244
+ }
245
+
246
+ # Recursive argument handling
247
+ if origin in type_converters :
248
+ result = type_converters [origin ]()
249
+ if args :
250
+ if origin == list :
251
+ return List [_convert_to_typing_type (args [0 ])] # type: ignore
252
+ if origin == dict :
253
+ return Dict [_convert_to_typing_type (args [0 ]), _convert_to_typing_type (args [1 ])] # type: ignore
254
+ if origin == set :
255
+ return Set [_convert_to_typing_type (args [0 ])] # type: ignore
256
+ if origin == tuple :
257
+ return Tuple [tuple (_convert_to_typing_type (arg ) for arg in args )]
258
+ return result
259
+ return t
0 commit comments