/
func_wrapper.py
1630 lines (1361 loc) · 55.3 KB
/
func_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import contextlib
import ivy
import functools
import logging
import weakref
import warnings
import copy as python_copy
from types import FunctionType
from typing import Callable
import inspect
import numpy as np
# Names of the wrapping decorators defined in this module; the order of this list
# matters for wrapping (sequence matters). Each decorator marks the wrapper it
# returns with an attribute of the same name (e.g.
# ``_infer_device.infer_device = True``), which is how already-applied wrappers
# can be detected.
# NOTE(review): "outputs_to_native_arrays" is the flag attribute set by the
# decorator function `output_to_native_arrays` (singular) — confirm any lookup of
# decorators by these names accounts for that mismatch.
FN_DECORATORS = [
    "infer_device",
    "infer_dtype",
    "handle_array_function",
    "integer_arrays_to_float",
    "outputs_to_ivy_arrays",
    "outputs_to_ivy_shapes",
    "outputs_to_native_arrays",
    "inputs_to_native_arrays",
    "inputs_to_native_shapes",
    "inputs_to_ivy_arrays",
    "handle_out_argument",
    "handle_view_indexing",
    "handle_view",
    "handle_array_like_without_promotion",
    "handle_mixed_function",
    "handle_nestable",
    "handle_exceptions",
    "handle_nans",
]
# Helpers #
# --------#

# Dtype families used by the casters. Within each family the order is the
# casting hierarchy: `upcaster` moves to later entries, `downcaster` to earlier
# ones. The values are zero-argument callables, so the dtype lists are read from
# ivy only when a caster asks for them.
casting_modes_dict = {
    "uint": lambda: ivy.valid_uint_dtypes,
    # signed ints = every valid int dtype that is not an unsigned one.
    # NOTE(review): `sorted` relies on the ordering of ivy dtype objects; plain
    # strings would sort "int16" < "int32" < "int64" < "int8" — confirm.
    "int": lambda: sorted(set(ivy.valid_int_dtypes).difference(ivy.valid_uint_dtypes)),
    "float": lambda: ivy.valid_float_dtypes,
    "complex": lambda: ivy.valid_complex_dtypes,
}
def _casting_flag_enabled(flag):
    """
    Evaluate a casting-mode flag taken from the ``ivy`` namespace.

    `caster` invokes ``ivy.cast_dtypes()`` as a function, and the sibling flags
    are presumably the same kind of object. A bare function reference is always
    truthy, so callable flags are called instead of being tested directly.
    Non-callable flags are interpreted by their truthiness.
    """
    return bool(flag()) if callable(flag) else bool(flag)


def caster(dtype, intersect):
    """
    Pick a replacement dtype for `dtype` when it appears in `intersect`.

    Parameters
    ----------
    dtype
        A dtype, or an object exposing a ``dtype`` attribute (e.g. an array).
    intersect
        Collection of dtype strings that `dtype` should be moved away from
        (presumably the dtypes unsupported by the target function — confirm).

    Returns
    -------
    ret
        The first non-empty candidate produced by the casting strategies that the
        current ivy casting mode enables, or ``None`` if `dtype` is not in
        `intersect` or no strategy yields a candidate.
    """
    if hasattr(dtype, "dtype"):
        dtype = ivy.as_ivy_dtype(dtype.dtype)
    else:
        dtype = ivy.as_ivy_dtype(dtype)
    if str(dtype) not in intersect:
        return None

    # based on the enabled casting mode, choose which strategies to try, in order
    if ivy.cast_dtypes():
        # all casting types are enabled: cross-cast, then upcast, then downcast
        strategies = (
            lambda: cross_caster(intersect),
            lambda: upcaster(dtype, intersect),
            lambda: downcaster(dtype, intersect),
        )
    elif _casting_flag_enabled(ivy.crosscast_dtypes):
        strategies = (lambda: cross_caster(intersect),)
    elif _casting_flag_enabled(ivy.upcast_dtypes):
        strategies = (lambda: upcaster(dtype, intersect),)
    elif _casting_flag_enabled(ivy.downcast_dtypes):
        strategies = (lambda: downcaster(dtype, intersect),)
    else:
        strategies = ()

    for strategy in strategies:
        ret_dtype = strategy()
        if ret_dtype:
            return ret_dtype
    return None
def upcaster(dtype, intersect):
    """
    Return the nearest higher dtype of the same family that is not in `intersect`.

    Parameters
    ----------
    dtype
        The dtype to move away from; its string form decides the family
        ("uint", "int", "float" or "complex").
    intersect
        Collection of dtypes that must not be returned.

    Returns
    -------
    ret
        The first entry after `dtype` in ``casting_modes_dict[family]()`` that is
        not in `intersect`, ``""`` if there is none, or ``None`` if `dtype` does
        not belong to any known family.
    """
    dtype_str = str(dtype)
    # "uint" must be tested before "int" since "int" is a substring of "uint"
    for family in ("uint", "int", "float", "complex"):
        if family not in dtype_str:
            continue
        ordered = casting_modes_dict[family]()
        for candidate in ordered[ordered.index(dtype) + 1 :]:
            if candidate not in intersect:
                return candidate
        return ""
    return None
def downcaster(dtype, intersect):
    """
    Return the nearest lower dtype of the same family that is not in `intersect`.

    Parameters
    ----------
    dtype
        The dtype to move away from; its string form decides the family
        ("uint", "int", "float" or "complex").
    intersect
        Collection of dtypes that must not be returned.

    Returns
    -------
    ret
        The closest entry before `dtype` in ``casting_modes_dict[family]()`` that
        is not in `intersect`, ``""`` if there is none, or ``None`` if `dtype`
        does not belong to any known family.
    """
    dtype_str = str(dtype)
    # "uint" must be tested before "int" since "int" is a substring of "uint"
    for family in ("uint", "int", "float", "complex"):
        if family not in dtype_str:
            continue
        # the same family list is used both for the membership check and the
        # returned value (the old uint branch checked the signed-int list)
        ordered = casting_modes_dict[family]()
        for candidate in reversed(ordered[: ordered.index(dtype)]):
            if candidate not in intersect:
                return candidate
        return ""
    return None
def cross_caster(intersect):
    """
    Suggest a dtype of the other numeric family when a whole family is excluded.

    If `intersect` equals ``ivy.valid_int_dtypes``, the default float dtype is
    suggested; if it equals ``ivy.valid_float_dtypes``, the default int dtype is
    suggested.

    Returns
    -------
    ret
        The suggested dtype as a string, or ``""`` when neither case applies.
    """
    if intersect == ivy.valid_int_dtypes:
        # integer-unsupported case: fall back to the default float dtype
        return str(ivy.default_float_dtype())
    if intersect == ivy.valid_float_dtypes:
        # float-unsupported case: fall back to the default int dtype
        return str(ivy.default_int_dtype())
    return ""
def try_array_function_override(func, overloaded_args, types, args, kwargs):
    """
    Dispatch `func` to the first overloaded argument that implements it.

    Parameters
    ----------
    func
        The ivy function being called.
    overloaded_args
        Arguments defining ``__ivy_array_function__``, in dispatch order.
    types
        The collected argument types, forwarded to each override.
    args, kwargs
        The original call arguments, forwarded to each override.

    Returns
    -------
    ret
        ``(False, None)`` if there is nothing to try, otherwise ``(True, result)``
        for the first override that does not return ``NotImplemented``.

    Raises
    ------
    IvyNotImplementedException
        If an override raises; the original exception is chained as the cause.
    TypeError
        If every override returns ``NotImplemented``.
    """
    if not overloaded_args:
        return False, None

    for overloaded_arg in overloaded_args:
        # Note that we're only calling __ivy_array_function__ on the *first*
        # occurrence of each argument type (the caller only collects one
        # argument per type). This is necessary for reasonable performance with
        # a possibly long list of overloaded arguments, for which each
        # __ivy_array_function__ implementation might reasonably need to check
        # all argument types.
        try:
            result = overloaded_arg.__ivy_array_function__(func, types, args, kwargs)
        except Exception as err:
            # keep the underlying failure visible to the caller
            raise ivy.utils.exceptions.IvyNotImplementedException from err

        if result is not NotImplemented:
            return True, result

    raise TypeError(
        f"no implementation found for {func} on types that implement "
        f"__ivy_array_function__: {list(map(type, overloaded_args))}"
    )
def _get_first_array(*args, **kwargs):
    """
    Return the first array found in the positional, then keyword, arguments.

    Keyword arguments are searched only when no array was found among the
    positional ones. Returns ``None`` if neither contains an array.
    """
    # ToDo: make this more efficient, with function ivy.nested_nth_index_where
    if not args and not kwargs:
        return None
    if args:
        found = ivy.nested_argwhere(args, ivy.is_array, stop_after_n_found=1)
        if found:
            return ivy.index_nest(args, found[0])
    found = ivy.nested_argwhere(kwargs, ivy.is_array, stop_after_n_found=1)
    if found:
        return ivy.index_nest(kwargs, found[0])
    return None
def _build_view(original, view, fn, args, kwargs, index=None):
    """
    Register `view` as a view produced from `original` by the function `fn`.

    Parameters
    ----------
    original
        The array the view was created from.
    view
        The array returned by `fn`; it is mutated in place and returned.
    fn
        Name of the function that produced the view.
    args, kwargs
        The arguments `fn` was called with; ``args[1:]`` is recorded.
    index
        Position of `view` in `fn`'s output when `fn` returns several arrays.

    Returns
    -------
    ret
        `view`, with its base, manipulation stack and torch bookkeeping
        attributes populated.
    """
    if not ivy.exists(original._base):
        # `original` is itself a base array
        base = original
        view._base = base
    else:
        if ivy.backend in ("jax", "tensorflow"):
            warnings.warn(
                "Creating many views will lead to overhead "
                "when performing inplace updates with this backend"
            )
        # chain onto the root array and inherit the operations that led here
        base = original._base
        view._base = base
        view._manipulation_stack = python_copy.copy(original._manipulation_stack)
    base._view_refs.append(weakref.ref(view))
    view._manipulation_stack.append((fn, args[1:], kwargs, index))

    # Handle attributes for torch functions without native view functionality
    if not ivy.exists(original._torch_base):
        view._torch_base = base
    elif ivy.exists(original._torch_manipulation):
        view._torch_base = original
    else:
        view._torch_base = original._torch_base
    if fn in _torch_non_native_view_functions:
        view._torch_manipulation = (original, (fn, args[1:], kwargs))
        view._torch_base._torch_view_refs.append(weakref.ref(view))
    return view
# Function names that `_build_view` treats as torch functions without native
# view functionality: for views they produce, it records the parent and the call
# (`_torch_manipulation`) and registers the view in `_torch_view_refs`, so that
# inplace updates can be replayed between parent and view.
_torch_non_native_view_functions = ("flip", "flipud", "rot90", "fliplr")
def _check_in_nested_sequence(sequence, value=None, _type=None):
"""
Check `sequence` for either a `value` or a value of type `_type`.
Helper to recursively check if a N-level nested `sequence` contains
either a `value` or contains a value of type `_type` and return a
boolean flag.
"""
if sequence is value or (isinstance(sequence, _type)):
# Base case - N = 0
return True
elif isinstance(sequence, (tuple, list)):
if any(isinstance(_val, _type) or _val is value for _val in sequence):
# N = 1
return True
else:
return any(
_check_in_nested_sequence(sub_sequence, value, _type)
for sub_sequence in sequence
if isinstance(sub_sequence, (tuple, list))
)
# Array Handling #
# ---------------#


def handle_array_function(fn):
    """
    Wrap a function `fn` to be passed to array_function method.

    Wrap a function to extract the relevant argument types to be passed
    to array_function method.

    The wrapper scans the positional and keyword arguments, and the leaves of any
    `ivy.Container` argument, for objects defining ``__ivy_array_function__``.
    Only the first object of each type is considered. Objects whose override is
    not ``ivy.Array.__ivy_array_function__`` (and that are not ivy/native arrays)
    are collected with subclasses placed before their base classes. If one of
    them handles the call via `try_array_function_override`, its result is
    returned; otherwise `fn` is called normally.
    """

    @functools.wraps(fn)
    def _handle_array_function(*args, **kwargs):
        # types of all candidate arguments, in order of first appearance
        overloaded_types = []
        # arguments whose overrides will be tried, subclasses before superclasses
        overloaded_args = []

        for arg in args + tuple(kwargs.values()):
            if ivy.exists(arg) and (
                not isinstance(arg, ivy.Container)
                and hasattr(arg, "__ivy_array_function__")
            ):
                # only the first argument of each type is considered
                if type(arg) not in overloaded_types:
                    overloaded_types.append(type(arg))

                    if (
                        arg.__ivy_array_function__
                        is not ivy.Array.__ivy_array_function__
                        and not isinstance(arg, (ivy.Array, ivy.NativeArray))
                    ):
                        # insert before the first collected argument whose type
                        # this argument's type subclasses
                        index = len(overloaded_args)
                        for i, old_arg in enumerate(overloaded_args):
                            if issubclass(type(arg), type(old_arg)):
                                index = i
                                break
                        overloaded_args.insert(index, arg)
            if ivy.exists(arg) and isinstance(arg, ivy.Container):
                # search the flattened container's leaves for overriders
                arg = ivy.Container.cont_flatten_key_chains(arg)
                indices = ivy.nested_argwhere(
                    arg, lambda x: hasattr(x, "__ivy_array_function__")
                )
                for a in indices:
                    if type(getattr(arg, a[0])) not in overloaded_types:
                        overloaded_types.append(type(getattr(arg, a[0])))

                        if getattr(
                            arg, a[0]
                        ).__ivy_array_function__ is not ivy.Array.__ivy_array_function__ and not isinstance(  # noqa: E501
                            getattr(arg, a[0]), (ivy.Array, ivy.NativeArray)
                        ):
                            index = len(overloaded_args)
                            for i, old_arg in enumerate(overloaded_args):
                                if issubclass(type(getattr(arg, a[0])), type(old_arg)):
                                    index = i
                                    break
                            # NOTE(review): this inserts the flattened container
                            # `arg`, not the leaf `getattr(arg, a[0])` whose type
                            # was checked above — confirm this is intended.
                            overloaded_args.insert(index, arg)

        # the public ivy function of the same name is what overrides receive
        success, value = try_array_function_override(
            ivy.__dict__[fn.__name__], overloaded_args, overloaded_types, args, kwargs
        )
        if success:
            return value
        return fn(*args, **kwargs)

    _handle_array_function.handle_array_function = True
    return _handle_array_function
def handle_array_like_without_promotion(fn: Callable) -> Callable:
    """
    Wrap `fn` so array-like inputs for array-annotated parameters become arrays.

    A parameter is treated as array-typed when its annotation string mentions
    "rray" or "Tensor" and none of "Sequence", "List", "Tuple", "float", "int"
    or "bool"; the `out` parameter is never converted. Matching positional and
    keyword arguments that are not already arrays are converted with
    ``ivy.array`` (no dtype promotion is requested). If `fn` has no inspectable
    signature it is called unchanged.
    """

    @functools.wraps(fn)
    def _handle_array_like_without_promotion(*args, **kwargs):
        args = list(args)
        num_args = len(args)
        try:
            type_hints = inspect.signature(fn).parameters
        except (TypeError, ValueError):
            # no inspectable signature: call through unchanged
            return fn(*args, **kwargs)

        # Iterate over *all* parameters (not only those covered by positional
        # args) so that array-likes passed by keyword are converted as well.
        for i, (parameter, param) in enumerate(type_hints.items()):
            annotation_str = str(param.annotation)
            is_array_param = (
                ("rray" in annotation_str or "Tensor" in annotation_str)
                and parameter != "out"
                and all(
                    sq not in annotation_str
                    for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"]
                )
            )
            if not is_array_param:
                continue
            if i < num_args:
                arg = args[i]
                # Fix for ellipsis, slices for numpy's __getitem__
                # No need to try and convert them into arrays
                # since asarray throws unpredictable bugs
                if _check_in_nested_sequence(arg, value=Ellipsis, _type=slice):
                    continue
                if not ivy.is_array(arg):
                    args[i] = ivy.array(arg)
            elif parameter in kwargs:
                kwarg = kwargs[parameter]
                # keyword arguments often carry an explicit None for optional
                # arrays; leave those, and indexing queries, untouched
                if kwarg is None or _check_in_nested_sequence(
                    kwarg, value=Ellipsis, _type=slice
                ):
                    continue
                if not ivy.is_array(kwarg):
                    kwargs[parameter] = ivy.array(kwarg)
        return fn(*args, **kwargs)

    _handle_array_like_without_promotion.handle_array_like_without_promotion = True
    return _handle_array_like_without_promotion
def inputs_to_native_arrays(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def _inputs_to_native_arrays(*args, **kwargs):
        """
        Call `fn` with every `ivy.Array` in its arguments replaced by the
        corresponding `ivy.NativeArray`.

        Nothing is converted when array mode is off. The `out` keyword argument,
        if given, is passed through unconverted.

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with native arrays passed in the arguments.
        """
        if not ivy.get_array_mode():
            return fn(*args, **kwargs)
        # set `out` aside so it is not converted
        has_out = "out" in kwargs
        out = kwargs.pop("out", None)
        native_args, native_kwargs = ivy.args_to_native(*args, **kwargs)
        if has_out:
            native_kwargs["out"] = out
        return fn(*native_args, **native_kwargs)

    _inputs_to_native_arrays.inputs_to_native_arrays = True
    return _inputs_to_native_arrays
def inputs_to_ivy_arrays(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def _inputs_to_ivy_arrays(*args, **kwargs):
        """
        Convert all `ivy.NativeArray` instances in both the positional and keyword
        arguments into `ivy.Array` instances, and then call the function with the
        updated arguments.

        The `out` keyword argument, if given, is passed through unconverted.

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with ivy arrays passed in the arguments.
        """
        # Set `out` aside before converting, as `inputs_to_native_arrays` does:
        # it is restored unchanged afterwards, so converting it would be wasted.
        has_out = "out" in kwargs
        out = kwargs.pop("out", None)
        # convert all arrays in the inputs to ivy.Array instances
        ivy_args, ivy_kwargs = ivy.args_to_ivy(
            *args, **kwargs, include_derived={tuple: True}
        )
        if has_out:
            ivy_kwargs["out"] = out
        return fn(*ivy_args, **ivy_kwargs)

    _inputs_to_ivy_arrays.inputs_to_ivy_arrays = True
    return _inputs_to_ivy_arrays
def inputs_to_native_shapes(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def _inputs_to_native_shapes(*args, **kwargs):
        """
        Call `fn` after replacing every `ivy.Shape` in the inputs by its
        ``.shape`` attribute, when array mode is on.
        """

        def _unwrap_shape(leaf):
            if isinstance(leaf, ivy.Shape) and ivy.get_array_mode():
                return leaf.shape
            return leaf

        args, kwargs = ivy.nested_map([args, kwargs], _unwrap_shape)
        return fn(*args, **kwargs)

    _inputs_to_native_shapes.inputs_to_native_shapes = True
    return _inputs_to_native_shapes
def outputs_to_ivy_shapes(fn: Callable) -> Callable:
    """
    Wrap `fn`; note that, as written, the wrapper transforms `fn`'s *inputs*.

    NOTE(review): this body is identical to `inputs_to_native_shapes`. It
    replaces every `ivy.Shape` in the inputs by its ``.shape`` attribute (when
    array mode is on) and returns `fn`'s output unchanged. Nothing converts
    outputs to `ivy.Shape`, so `to_native_shapes_and_back` does not currently
    convert its return value "back". Confirm the intended
    ``ivy.NativeShape`` -> ``ivy.Shape`` output conversion before relying on it.
    """

    @functools.wraps(fn)
    def _outputs_to_ivy_shapes(*args, **kwargs):
        # replaces input `ivy.Shape`s (not outputs) — see NOTE(review) above
        args, kwargs = ivy.nested_map(
            [args, kwargs],
            lambda x: (
                x.shape if isinstance(x, ivy.Shape) and ivy.get_array_mode() else x
            ),
        )
        return fn(*args, **kwargs)

    _outputs_to_ivy_shapes.outputs_to_ivy_shapes = True
    return _outputs_to_ivy_shapes
def to_native_shapes_and_back(fn: Callable) -> Callable:
    """
    Make `fn` receive `ivy.NativeShape` and return `ivy.Shape`.

    Wrap `fn` with `inputs_to_native_shapes` (innermost) and then
    `outputs_to_ivy_shapes`, so that input shapes are converted to
    `ivy.NativeShape` instances and return shapes to `ivy.Shape` instances.
    """
    native_input_fn = inputs_to_native_shapes(fn)
    return outputs_to_ivy_shapes(native_input_fn)
def outputs_to_ivy_arrays(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def _outputs_to_ivy_arrays(*args, **kwargs):
        """
        Call `fn` and, when array mode is on, convert every `ivy.NativeArray` in
        its return into an `ivy.Array` (tuples and their subclasses included).

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with native arrays as ivy arrays.
        """
        result = fn(*args, **kwargs)
        if not ivy.get_array_mode():
            return result
        return ivy.to_ivy(result, nested=True, include_derived={tuple: True})

    _outputs_to_ivy_arrays.outputs_to_ivy_arrays = True
    return _outputs_to_ivy_arrays
def output_to_native_arrays(fn: Callable) -> Callable:
    """
    Wrap `fn` so every `ivy.Array` in its return becomes an `ivy.NativeArray`.

    The returned wrapper takes the same arguments as `fn`, calls it, and converts
    the result with ``ivy.to_native`` (nested, tuples and their subclasses
    included). The wrapper is flagged with the attribute
    ``outputs_to_native_arrays``.
    """

    @functools.wraps(fn)
    def _output_to_native_arrays(*args, **kwargs):
        result = fn(*args, **kwargs)
        return ivy.to_native(result, nested=True, include_derived={tuple: True})

    # the flag uses the plural name that appears in FN_DECORATORS
    _output_to_native_arrays.outputs_to_native_arrays = True
    return _output_to_native_arrays
def to_ivy_arrays_and_back(fn: Callable) -> Callable:
    """
    Make `fn` receive `ivy.Array` and return `ivy.NativeArray`.

    `inputs_to_ivy_arrays` is applied first (innermost), then
    `output_to_native_arrays`, so input arrays become `ivy.Array` instances and
    returned arrays become `ivy.NativeArray` instances.
    """
    ivy_input_fn = inputs_to_ivy_arrays(fn)
    return output_to_native_arrays(ivy_input_fn)
def to_native_arrays_and_back(fn: Callable) -> Callable:
    """
    Make `fn` receive `ivy.NativeArray` and return `ivy.Array`.

    `inputs_to_native_arrays` is applied first (innermost), then
    `outputs_to_ivy_arrays`, so input arrays become `ivy.NativeArray` instances
    and returned arrays become `ivy.Array` instances.
    """
    native_input_fn = inputs_to_native_arrays(fn)
    return outputs_to_ivy_arrays(native_input_fn)
def frontend_outputs_to_ivy_arrays(fn: Callable) -> Callable:
    """
    Wrap `fn` and convert all frontend arrays in its return to ivy arrays.

    Used when a frontend function receives a callable (frontend function)
    argument: to use that callable inside a composition of ivy functions, its
    outputs must be ivy arrays. Any returned leaf exposing an ``ivy_array``
    attribute is replaced by that attribute; the nest is rebuilt
    (``shallow=False``) rather than modified in place.
    """

    def _unwrap_frontend(leaf):
        if hasattr(leaf, "ivy_array"):
            return leaf.ivy_array
        return leaf

    @functools.wraps(fn)
    def _outputs_to_ivy_arrays(*args, **kwargs):
        return ivy.nested_map(fn(*args, **kwargs), _unwrap_frontend, shallow=False)

    return _outputs_to_ivy_arrays
def handle_view(fn: Callable) -> Callable:
    """
    Wrap `fn` and perform view handling if copy is False.

    Used for functional backends (Jax and TensorFlow). Checks if the first arg
    is a view or original array by checking if the ._base attribute is
    populated. If it's original it adds the returned array to its view
    references, then the returned array adds the operation to its manipulation
    stack and stores the original as its base. If the first arg is a view, then
    the returned array copies its base and manipulation stack, appends the new
    operation to the manipulation stack and appends its reference to the base
    array's view_refs attribute.

    When `fn` returns a list or tuple, each element is registered as a view,
    with its position recorded as the index.
    """

    @functools.wraps(fn)
    def _handle_view(*args, **kwargs):
        ret = fn(*args, **kwargs)
        # no view bookkeeping when a copy was requested, when no positional
        # input was given, or when the input is not an ivy.Array
        if kwargs.get("copy") or not args or not ivy.is_ivy_array(args[0]):
            return ret
        original = args[0]
        if isinstance(ret, (list, tuple)):
            # `_build_view` registers each output in place and returns the same
            # object, so no item assignment is needed (which also works for
            # immutable tuples)
            for i, view in enumerate(ret):
                _build_view(original, view, fn.__name__, args, kwargs, i)
        else:
            ret = _build_view(original, ret, fn.__name__, args, kwargs, None)
        return ret

    _handle_view.handle_view = True
    return _handle_view
def handle_view_indexing(fn: Callable) -> Callable:
    """
    Wrap `fn` and perform view handling specifically for indexing.

    As with NumPy, no view is recorded (the result is a copy) when advanced
    indexing is performed, i.e. when the query contains anything other than
    ints and slices. Otherwise the result is registered as a view of the first
    argument under the function name "get_item", following the same base /
    manipulation-stack bookkeeping as `handle_view`. Used for functional
    backends (Jax and TensorFlow).
    """

    @functools.wraps(fn)
    def _handle_view_indexing(*args, **kwargs):
        ret = fn(*args, **kwargs)
        copy_requested = "copy" in kwargs and kwargs["copy"]
        if copy_requested or not ivy.is_ivy_array(args[0]):
            return ret
        query = kwargs["query"] if "query" in kwargs else args[1]
        if not isinstance(query, tuple):
            query = (query,)
        # advanced indexing: anything other than ints and slices gives a copy
        if any(not isinstance(item, (slice, int)) for item in query):
            return ret
        # ToDo: Remove hard coding of only function with this wrapper
        # Need general way to convert special method to function found in ivy.__dict__
        return _build_view(args[0], ret, "get_item", args, kwargs)

    _handle_view_indexing.handle_view_indexing = True
    return _handle_view_indexing
def _convert_numpy_arrays_to_backend_specific(*args):
    """
    Convert every ``numpy.ndarray`` found in `args` into a native backend array.

    `args` is always a tuple here (it is collected with ``*args``), so the check
    has to look *inside* the nest; testing ``isinstance(args, np.ndarray)`` on
    the tuple itself can never succeed. When no numpy array is present, `args`
    is returned as-is. Otherwise a rebuilt nest is returned (``shallow=False``,
    so lists passed by the caller are not modified in place).

    Returns
    -------
    ret
        The positional arguments, to be unpacked into the wrapped function.
    """

    def _to_backend(leaf):
        if isinstance(leaf, np.ndarray):
            return ivy.array(leaf).to_native()
        return leaf

    if not ivy.nested_any(args, lambda x: isinstance(x, np.ndarray)):
        return args
    return ivy.nested_map(args, _to_backend, shallow=False)
def handle_numpy_arrays_in_specific_backend(fn: Callable) -> Callable:
    """
    Wrap `fn` so numpy positional inputs are converted before `fn` runs.

    Used for functional backends (PyTorch), where `numpy.ndarray` inputs are
    meant to become `torch.Tensor` instances. The positional arguments go through
    `_convert_numpy_arrays_to_backend_specific`; keyword arguments are forwarded
    unchanged.
    """

    @functools.wraps(fn)
    def _handle_numpy_array_in_torch(*args, **kwargs):
        converted_args = _convert_numpy_arrays_to_backend_specific(*args)
        return fn(*converted_args, **kwargs)

    _handle_numpy_array_in_torch.handle_numpy_arrays_in_specific_backend = True
    return _handle_numpy_array_in_torch
# Data Type Handling #
# -------------------#


def infer_dtype(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def _infer_dtype(*args, dtype=None, **kwargs):
        """
        Resolve `dtype` and call `fn` with it passed explicitly.

        If no `dtype` is given, the first array among the arguments (if any) is
        handed to ``ivy.default_dtype`` to infer it; the resolved native dtype is
        checked with ``ivy.utils.assertions._check_jax_x64_flag``.

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        dtype
            The data type for the function.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with `dtype` passed explicitly.
        """
        if ivy.exists(dtype):
            source_array = None
        else:
            # no explicit dtype: look for an array to infer it from
            source_array = _get_first_array(*args, **kwargs)
        dtype = ivy.default_dtype(dtype=dtype, item=source_array, as_native=True)
        ivy.utils.assertions._check_jax_x64_flag(dtype)
        return fn(*args, dtype=dtype, **kwargs)

    _infer_dtype.infer_dtype = True
    return _infer_dtype
def integer_arrays_to_float(fn: Callable) -> Callable:
    def _promote(leaf):
        # only integer-dtype arrays are promoted; everything else passes through
        if ivy.is_array(leaf) and ivy.is_int_dtype(leaf.dtype):
            if ivy.is_ivy_array(leaf):
                return ivy.asarray(leaf, dtype=ivy.default_float_dtype())
            return ivy.native_array(leaf, dtype=ivy.default_float_dtype(as_native=True))
        return leaf

    @functools.wraps(fn)
    def _integer_arrays_to_float(*args, **kwargs):
        """
        Promote every integer array passed to the function, positionally or by
        keyword, to the default float dtype, keeping ivy arrays as ivy arrays and
        native arrays as native arrays.

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with integer array arguments
        promoted to default float dtype.
        """
        args = ivy.nested_map(args, _promote, to_mutable=True)
        kwargs = ivy.nested_map(kwargs, _promote, to_mutable=True)
        return fn(*args, **kwargs)

    _integer_arrays_to_float.integer_arrays_to_float = True
    return _integer_arrays_to_float
# Device Handling #
# ----------------#


def infer_device(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def _infer_device(*args, device=None, **kwargs):
        """
        Resolve `device` and call `fn` with it passed explicitly.

        If no `device` is given, the first array among the arguments (if any) is
        handed to ``ivy.default_device`` to infer it.

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        device
            The device for the function.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with `device` passed explicitly.
        """
        if ivy.exists(device):
            source_array = None
        else:
            # no explicit device: look for an array to take it from
            source_array = _get_first_array(*args, **kwargs)
        device = ivy.default_device(device, item=source_array, as_native=True)
        return fn(*args, device=device, **kwargs)

    _infer_device.infer_device = True
    return _infer_device
# Inplace Update Handling #
# ------------------------#


def handle_out_argument(fn: Callable) -> Callable:
    """
    Wrap `fn` so that a passed `out` array receives the result inplace.

    If `fn` has a ``support_native_out`` attribute, the native form of `out` is
    handed to `fn`, and the result is then written into ``out.data`` (per element
    when `fn` returns a tuple/list), refreshing dependent torch views on the torch
    backend. Otherwise `fn` is called without `out`, and the result is cast to
    `out`'s dtype and written with ``ivy.inplace_update``.
    """
    # decided once, at decoration time
    handle_out_in_backend = hasattr(fn, "support_native_out")

    @functools.wraps(fn)
    def _handle_out_argument(*args, out=None, **kwargs):
        """
        Call `fn` with the `out` argument handled correctly for performing an inplace
        update.

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        out
            The array to write the result to.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with `out` handled correctly for
        inplace updates.
        """
        if out is None:
            return fn(*args, out=out, **kwargs)
        if handle_out_in_backend:
            # extract underlying native array for out
            native_out = ivy.to_native(out)
            # compute return, with backend inplace update handled by
            # the backend function
            ret = fn(*args, out=native_out, **kwargs)
            if isinstance(ret, (tuple, list)):
                for i in range(len(ret)):
                    out[i].data = ivy.to_native(ret[i])
                    if ivy.backend == "torch":
                        # propagate the update to torch views related to out[i]
                        _update_torch_views(out[i])
            else:
                out.data = ivy.to_native(ret)
                if ivy.backend == "torch":
                    # propagate the update to torch views related to out
                    _update_torch_views(out)
            return out
        # compute return, and then handle the inplace update explicitly
        ret = fn(*args, **kwargs)
        # return output matches the dtype of the out array to match numpy and torch
        if not ivy.is_array(ret) and not ivy.is_ivy_container(ret):
            # nested return: update each leaf of `out` from the matching leaf of
            # `ret` (x presumably holds the pair (out_leaf, ret_leaf) — confirm)
            return ivy.nested_multi_map(
                lambda x, _: ivy.inplace_update(
                    x[0], ivy.astype(x[1], ivy.dtype(x[0]))
                ),
                [out, ret],
            )
        return ivy.inplace_update(out, ivy.astype(ret, ivy.dtype(out)))

    _handle_out_argument.handle_out_argument = True
    return _handle_out_argument
def _update_torch_views(x, visited_view=None):
    """
    Propagate an inplace update of `x` to related arrays on the torch backend.

    1. Views registered on `x` (``x._torch_view_refs``) are refreshed, except
       `visited_view`.
    2. If `x` was produced by a recorded non-native-view function
       (``x._torch_manipulation`` is set), that function is applied to `x` with
       ``copy=True`` and the result is written into the recorded parent;
       ``rot90`` is applied with ``k`` negated to undo the rotation, the other
       recorded functions are re-applied unchanged.
    3. The update then continues from ``x._torch_base``, with `x` passed as the
       already-visited view.
    """
    if x._torch_view_refs != []:
        _update_torch_references(x, visited_view)
    if ivy.exists(x._torch_manipulation):
        parent_tensor, (fn, args, kwargs) = x._torch_manipulation
        # work on a copy so the recorded manipulation is not mutated
        kwargs = {**kwargs, "copy": True}
        if fn == "rot90":
            # NOTE(review): when `k` was not passed by keyword, 1 is assumed to
            # be ivy.rot90's default (the old code raised KeyError) — confirm
            kwargs["k"] = -kwargs.get("k", 1)
        parent_tensor.data[()] = ivy.__dict__[fn](x, *args, **kwargs).data
    if ivy.exists(x._torch_base):
        _update_torch_views(x._torch_base, visited_view=x)
def _update_torch_references(x, visited_view=None):
for ref in x._torch_view_refs:
view = ref()
if ivy.exists(view) and view is not visited_view:
parent_tensor, fn_args_kwargs = view._torch_manipulation
fn, args, kwargs = fn_args_kwargs
kwargs["copy"] = True
view.data[()] = ivy.__dict__[fn](parent_tensor, *args, **kwargs).data
if view._torch_view_refs != []:
_update_torch_references(view)
# Nestable Handling #
# ------------------#


def handle_nestable(fn: Callable) -> Callable:
    fn_name = fn.__name__

    @functools.wraps(fn)
    def _handle_nestable(*args, **kwargs):
        """
        Call `fn` with its *nestable* property handled.

        When nestable mode is on and any (possibly nested) argument is an
        `ivy.Container`, the call is routed to ``ivy.Container._static_<name>``
        if it exists, else to ``ivy.Container.cont_multi_map_in_function(fn, ...)``,
        which maps `fn` over the container leaves. Otherwise `fn` is called
        directly.

        Parameters
        ----------
        args
            The arguments to be passed to the function.
        kwargs
            The keyword arguments to be passed to the function.

        Returns
        -------
        The return of the function, with the nestable property handled correctly.
        """
        containers_passed = ivy.get_nestable_mode() and (
            ivy.nested_any(args, ivy.is_ivy_container, check_nests=True)
            or ivy.nested_any(kwargs, ivy.is_ivy_container, check_nests=True)
        )
        if not containers_passed:
            # no containers: call the function directly
            return fn(*args, **kwargs)
        static_name = "_static_" + fn_name
        if hasattr(ivy.Container, static_name):
            return getattr(ivy.Container, static_name)(*args, **kwargs)
        return ivy.Container.cont_multi_map_in_function(fn, *args, **kwargs)

    _handle_nestable.handle_nestable = True
    return _handle_nestable
# Functions #
def _wrap_function(
key: str, to_wrap: Callable, original: Callable, compositional: bool = False
) -> Callable:
"""
Apply wrapping to backend implementation `to_wrap` if the original implementation
`original` is also wrapped, and if `to_wrap` is not already wrapped. Attributes
`handle_nestable`, `infer_device` etc are set during wrapping, hence indicate to us
whether a certain function has been wrapped or not. Also handles wrapping of the
`linalg` namespace.
Parameters
----------
to_wrap
the new implementation to potentially wrap
original
the original implementation of `to_wrap` which tells us which wrappers we need.
compositional
indicates whether the function being wrapped is compositional
(Default Value = ``False``).
Returns
-------
ret
`to_wrap` appropriately wrapped if `to_wrap` is a function, otherwise just the
input is returned.
"""
if key == "linalg":
for linalg_k, linalg_v in to_wrap.__dict__.items():
if (
isinstance(linalg_v, FunctionType)
and linalg_k.lower() != "namedtuple"
and linalg_k != "with_unsupported_dtypes"
and not linalg_k.startswith("_")
):