forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 3
/
ndarray.py
3446 lines (2928 loc) · 113 KB
/
ndarray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=too-many-lines, protected-access
# pylint: disable=import-error, no-name-in-module, undefined-variable
"""NDArray API of MXNet."""
from __future__ import absolute_import
from __future__ import division
try:
from __builtin__ import slice as py_slice
except ImportError:
from builtins import slice as py_slice
from array import array as native_array
import ctypes
import warnings
import operator
from functools import reduce # pylint: disable=redefined-builtin
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call
from ..base import ctypes2buffer
from ..context import Context
from . import _internal
from . import op
from ._internal import NDArrayBase
__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
"ones", "add", "arange", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "maximum", "minimum", "moveaxis", "modulo",
"multiply", "not_equal", "onehot_encode", "power", "subtract", "true_divide",
"waitall", "_new_empty_handle"]
# Integer codes identifying an NDArray's storage type, as returned by the
# backend via _storage_type(); presumably mirror the C++ storage-type enum —
# confirm against the backend headers before changing.
_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
_STORAGE_TYPE_ROW_SPARSE = 1
_STORAGE_TYPE_CSR = 2
# pylint: disable= no-member
# Mapping from numpy scalar types to MXNet dtype integer codes, used when
# creating handles (see _new_alloc_handle / _new_from_shared_mem).
_DTYPE_NP_TO_MX = {
    None: -1,
    np.float32: 0,
    np.float64: 1,
    np.float16: 2,
    np.uint8: 3,
    np.int32: 4,
    np.int8: 5,
    np.int64: 6,
}
# Inverse of _DTYPE_NP_TO_MX: MXNet dtype integer code -> numpy scalar type.
_DTYPE_MX_TO_NP = {
    -1: None,
    0: np.float32,
    1: np.float64,
    2: np.float16,
    3: np.uint8,
    4: np.int32,
    5: np.int8,
    6: np.int64,
}
# String name -> storage-type code (and its inverse below).
_STORAGE_TYPE_STR_TO_ID = {
    'undefined': _STORAGE_TYPE_UNDEFINED,
    'default': _STORAGE_TYPE_DEFAULT,
    'row_sparse': _STORAGE_TYPE_ROW_SPARSE,
    'csr': _STORAGE_TYPE_CSR,
}
_STORAGE_TYPE_ID_TO_STR = {
    _STORAGE_TYPE_UNDEFINED: 'undefined',
    _STORAGE_TYPE_DEFAULT: 'default',
    _STORAGE_TYPE_ROW_SPARSE: 'row_sparse',
    _STORAGE_TYPE_CSR: 'csr',
}
# Gradient-request string -> backend integer code. Note 2 is skipped here;
# only the codes actually used by this frontend are listed.
_GRAD_REQ_MAP = {
    'null': 0,
    'write': 1,
    'add': 3
}
# pylint: enable= no-member
# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
_NDARRAY_BASIC_INDEXING = 0
_NDARRAY_ADVANCED_INDEXING = 1
def _new_empty_handle():
    """Creates and returns a fresh, empty `NDArray` handle.

    An empty handle carries no allocated storage; it is typically handed to
    backend calls that fill it in with their result.

    Returns
    -------
    handle
        A new empty `NDArray` handle.
    """
    empty_handle = NDArrayHandle()
    check_call(_LIB.MXNDArrayCreateNone(ctypes.byref(empty_handle)))
    return empty_handle
def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
    """Creates a handle for an array of the given shape on the given context.

    The handle itself is only used to hold results; allocation of the
    underlying storage may be deferred when `delay_alloc` is true.

    Returns
    -------
    handle
        A new empty `NDArray` handle.
    """
    out_handle = NDArrayHandle()
    type_code = _DTYPE_NP_TO_MX[np.dtype(dtype).type]
    check_call(_LIB.MXNDArrayCreateEx(
        c_array_buf(mx_uint, native_array('I', shape)),
        mx_uint(len(shape)),
        ctypes.c_int(ctx.device_typeid),
        ctypes.c_int(ctx.device_id),
        ctypes.c_int(int(delay_alloc)),
        ctypes.c_int(int(type_code)),
        ctypes.byref(out_handle)))
    return out_handle
def _new_from_shared_mem(shared_pid, shared_id, shape, dtype):
    """Creates an NDArray handle backed by an existing shared-memory region,
    identified by the owning process id and shared-memory id."""
    out_handle = NDArrayHandle()
    type_code = _DTYPE_NP_TO_MX[np.dtype(dtype).type]
    check_call(_LIB.MXNDArrayCreateFromSharedMem(
        ctypes.c_int(shared_pid),
        ctypes.c_int(shared_id),
        c_array(mx_uint, shape),
        mx_uint(len(shape)),
        ctypes.c_int(int(type_code)),
        ctypes.byref(out_handle)))
    return out_handle
def waitall():
    """Blocks until every pending asynchronous operation in MXNet completes.

    This function is used for benchmarking only.
    """
    check_call(_LIB.MXNDArrayWaitAll())
def _storage_type(handle):
    """Queries the backend for the storage-type code of `handle`.

    Returns one of the _STORAGE_TYPE_* integer codes.
    """
    stype = ctypes.c_int(0)
    check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(stype)))
    return stype.value
class NDArray(NDArrayBase):
    """An array object representing a multidimensional, homogeneous array of
    fixed-size items.
    """
    # No per-instance __dict__; all state lives in the C handle (NDArrayBase).
    __slots__ = []
    # make numpy functions return NDArray instead of numpy object array
    __array_priority__ = 1000.0
    # Extension type code for TVM function.
    # See C++ side of definition(kTVMNDArrayTypeCode) at include/mxmet/tensor_blob.h
    _tvm_tcode = 19
    # pylint: disable= no-member, undefined-variable

    @property
    def _tvm_handle(self):
        # Raw integer value of the underlying C handle, exposed for TVM interop.
        return self.handle.value
def __repr__(self):
"""Returns a string representation of the array."""
shape_info = 'x'.join(['%d' % x for x in self.shape])
return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()),
self.__class__.__name__,
shape_info, self.context)
    def __reduce__(self):
        # Pickle support: recreate an empty NDArray (constructor arg None),
        # then restore its contents via __setstate__ with the raw-bytes dict
        # produced by __getstate__.
        return NDArray, (None,), self.__getstate__()
def _to_shared_mem(self):
shared_pid = ctypes.c_int()
shared_id = ctypes.c_int()
check_call(_LIB.MXNDArrayGetSharedMemHandle(
self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
return shared_pid.value, shared_id.value, self.shape, self.dtype
def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
return add(self, other)
def __iadd__(self, other):
"""x.__iadd__(y) <=> x+=y """
if not self.writable:
raise ValueError('trying to add to a readonly NDArray')
if isinstance(other, NDArray):
return op.broadcast_add(self, other, out=self)
elif isinstance(other, numeric_types):
return _internal._plus_scalar(self, float(other), out=self)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
"""x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y) """
return subtract(self, other)
def __isub__(self, other):
"""x.__isub__(y) <=> x-=y """
if not self.writable:
raise ValueError('trying to subtract from a readonly NDArray')
if isinstance(other, NDArray):
return op.broadcast_sub(self, other, out=self)
elif isinstance(other, numeric_types):
return _internal._minus_scalar(self, float(other), out=self)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rsub__(self, other):
"""x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x) """
return subtract(other, self)
def __mul__(self, other):
"""x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y) """
return multiply(self, other)
def __neg__(self):
"""x.__neg__(y) <=> -x """
return _internal._mul_scalar(self, -1.0)
def __imul__(self, other):
"""x.__imul__(y) <=> x*=y """
if not self.writable:
raise ValueError('trying to multiply to a readonly NDArray')
if isinstance(other, NDArray):
return op.broadcast_mul(self, other, out=self)
elif isinstance(other, numeric_types):
return _internal._mul_scalar(self, float(other), out=self)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rmul__(self, other):
return self.__mul__(other)
def __div__(self, other):
"""x.__div__(y) <=> x/y <=> mx.nd.divide(x, y) """
return divide(self, other)
def __rdiv__(self, other):
"""x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x) """
return divide(other, self)
def __idiv__(self, other):
"""x.__rdiv__(y) <=> x/=y """
if not self.writable:
raise ValueError('trying to divide from a readonly NDArray')
if isinstance(other, NDArray):
return op.broadcast_div(self, other, out=self)
elif isinstance(other, numeric_types):
return _internal._div_scalar(self, float(other), out=self)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
return divide(self, other)
def __rtruediv__(self, other):
return divide(other, self)
def __itruediv__(self, other):
return self.__idiv__(other)
def __mod__(self, other):
"""x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """
return modulo(self, other)
def __rmod__(self, other):
"""x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """
return modulo(other, self)
def __imod__(self, other):
"""x.__rmod__(y) <=> x%=y """
if not self.writable:
raise ValueError('trying to take modulo from a readonly NDArray')
if isinstance(other, NDArray):
return op.broadcast_mod(self, other, out=self)
elif isinstance(other, numeric_types):
return _internal._mod_scalar(self, float(other), out=self)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __pow__(self, other):
"""x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """
return power(self, other)
def __rpow__(self, other):
"""x.__pow__(y) <=> y**x <=> mx.nd.power(y,x) """
return power(other, self)
def __eq__(self, other):
"""x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y) """
return equal(self, other)
def __hash__(self):
"""Default hash function."""
return id(self)//16
def __ne__(self, other):
"""x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y) """
return not_equal(self, other)
def __gt__(self, other):
"""x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y) """
return greater(self, other)
def __ge__(self, other):
"""x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y) """
return greater_equal(self, other)
def __lt__(self, other):
"""x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y) """
return lesser(self, other)
def __le__(self, other):
"""x.__le__(y) <=> x<=y <=> mx.nd.less_equal(x, y) """
return lesser_equal(self, other)
def __bool__(self):
num_elements = reduce(operator.mul, self.shape, 1)
if num_elements == 0:
return False
elif num_elements == 1:
return bool(self.asscalar())
else:
raise ValueError("The truth value of an NDArray with multiple elements " \
"is ambiguous.")
__nonzero__ = __bool__
def __len__(self):
"""Number of element along the first axis."""
return self.shape[0]
def __getstate__(self):
handle = self.handle
this = {'handle' : None}
if handle is not None:
length = ctypes.c_size_t()
cptr = ctypes.POINTER(ctypes.c_char)()
check_call(_LIB.MXNDArraySaveRawBytes(self.handle,
ctypes.byref(length),
ctypes.byref(cptr)))
this['handle'] = ctypes2buffer(cptr, length.value)
return this
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
buf = handle
handle = NDArrayHandle()
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
length = ctypes.c_size_t(len(buf))
check_call(_LIB.MXNDArrayLoadFromRawBytes(ptr, length, ctypes.byref(handle)))
self.handle = handle
else:
self.handle = None
# pylint: disable=line-too-long
def __setitem__(self, key, value):
"""x.__setitem__(i, y) <=> x[i]=y
Sets value to self[key]. This functions supports advanced indexing defined in the following reference with
some restrictions.
https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
- If key is a list type, only a list of integers is supported, e.g. key=[1, 2] is supported,
while not for key=[[1, 2]].
- Ellipsis (...) and np.newaxis are not supported.
- Boolean array indexing is not supported.
Parameters
----------
key : int, slice, list, np.ndarray, NDArray, or tuple of all previous types
The indexing key.
value : scalar or array-like object that can be broadcast to the shape of self[key]
The value to set.
Examples
--------
>>> x = mx.nd.zeros((2,3))
>>> x[:] = 1
>>> x.asnumpy()
array([[ 1., 1., 1.],
[ 1., 1., 1.]], dtype=float32)
>>> x.asnumpy()
array([[ 1., 1., 1.],
[ 1., 1., 1.]], dtype=float32)
>>> x[:,1:2] = 2
>>> x.asnumpy()
array([[ 1., 2., 1.],
[ 1., 2., 1.]], dtype=float32)
>>> x[1:2,1:] = 3
>>> x.asnumpy()
array([[ 1., 2., 1.],
[ 1., 3., 3.]], dtype=float32)
>>> x[1:,0:2] = mx.nd.zeros((1,2))
>>> x.asnumpy()
array([[ 1., 2., 1.],
[ 0., 0., 3.]], dtype=float32)
>>> x[1,2] = 4
>>> x.asnumpy()
array([[ 1., 2., 1.],
[ 0., 0., 4.]], dtype=float32)
>>> x[[0], [1, 2]] = 5
>>> x.asnumpy()
array([[ 1., 5., 5.],
[ 0., 0., 4.]], dtype=float32)
>>> x[::-1, 0:2:2] = [6]
>>> x.asnumpy()
array([[ 6., 5., 5.],
[ 6., 0., 4.]], dtype=float32)
"""
indexing_dispatch_code = _get_indexing_dispatch_code(key)
if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
self._set_nd_basic_indexing(key, value)
elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
self._set_nd_advanced_indexing(key, value)
else:
raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
% (str(key), str(type(key))))
# pylint: enable=line-too-long
# pylint: disable=line-too-long
def __getitem__(self, key):
"""x.__getitem__(i) <=> x[i]
Returns a sliced view of this array if the elements fetched are contiguous in memory;
otherwise, returns a newly created NDArray.
This functions supports advanced indexing defined in the following reference with
some restrictions.
https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
- If key is a list type, only a list of integers is supported, e.g. key=[1, 2] is supported,
while not for key=[[1, 2]].
- Ellipsis (...) and np.newaxis are not supported.
- Boolean array indexing is not supported.
Parameters
----------
key : int, slice, list, np.ndarray, NDArray, or tuple of all previous types
Indexing key.
Examples
--------
>>> x = mx.nd.arange(0,6).reshape((2,3))
>>> x.asnumpy()
array([[ 0., 1., 2.],
[ 3., 4., 5.]], dtype=float32)
>>> x[1].asnumpy()
array([ 3., 4., 5.], dtype=float32)
>>> y = x[0:1]
>>> y[:] = 2
>>> x.asnumpy()
array([[ 2., 2., 2.],
[ 3., 4., 5.]], dtype=float32)
>>> x = mx.nd.arange(0, 8, dtype='int32').reshape((2, 2, 2))
>>> x[[0, 1]]
[[[0 1]
[2 3]]
[[4 5]
[6 7]]]
>>> x[1:, [0, 1]]
[[[4 5]
[6 7]]]
>>> y = np.array([0, 1], dtype='int32')
>>> x[1:, y]
[[[4 5]
[6 7]]]
>>> y = mx.nd.array([0, 1], dtype='int32')
>>> x[1:, y]
[[[4 5]
[6 7]]]
"""
indexing_dispatch_code = _get_indexing_dispatch_code(key)
if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
return self._get_nd_basic_indexing(key)
elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
return self._get_nd_advanced_indexing(key)
else:
raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
% (str(key), str(type(key))))
# pylint: enable=line-too-long
    def _get_index_nd(self, key):
        """Returns an index array for use in scatter_nd and gather_nd.

        The key is normalized to a tuple, every entry is converted to an
        int32 index NDArray, all entries are broadcast to a common shape when
        necessary, and the result is stacked into one coordinate array.
        """
        def _is_advanced_index(index):
            """The definition of advanced index here includes integers as well, while
            integers are considered as basic index type when the key contains only
            slices and integers."""
            return not isinstance(index, py_slice)

        if isinstance(key, (NDArray, np.ndarray, list, integer_types, py_slice)):
            key = (key,)
        assert isinstance(key, tuple),\
            'index=%s must be a NDArray, or np.ndarray, or list, or tuple ' \
            ' type to use advanced indexing, received type=%s' % (str(key), str(type(key)))
        assert len(key) > 0, "Cannot slice with empty indices"
        shape = self.shape
        assert len(shape) >= len(key),\
            "Slicing dimensions exceeds array dimensions, %d vs %d" % (len(key), len(shape))
        indices = []
        dtype = 'int32'  # index data type passed to gather_nd op
        need_broadcast = (len(key) != 1)
        advanced_indices = []  # include list, NDArray, np.ndarray, integer
        basic_indices = []  # include only slices
        advanced_index_bshape = None  # final advanced index shape
        # Convert every key entry into an int32 index NDArray and record
        # whether it is an advanced index or a plain slice.
        for i, idx_i in enumerate(key):
            is_advanced_index = True
            if isinstance(idx_i, (np.ndarray, list, tuple)):
                idx_i = array(idx_i, ctx=self.context, dtype=dtype)
                advanced_indices.append(i)
            elif isinstance(idx_i, py_slice):
                start, stop, step = _get_index_range(idx_i.start, idx_i.stop, shape[i], idx_i.step)
                idx_i = arange(start, stop, step, ctx=self.context, dtype=dtype)
                basic_indices.append(i)
                is_advanced_index = False
            elif isinstance(idx_i, integer_types):
                start, stop, step = _get_index_range(idx_i, idx_i+1, shape[i], 1)
                idx_i = arange(start, stop, step, ctx=self.context, dtype=dtype)
                advanced_indices.append(i)
            elif isinstance(idx_i, NDArray):
                if dtype != idx_i.dtype:
                    idx_i = idx_i.astype(dtype)
                advanced_indices.append(i)
            else:
                raise IndexError('Indexing NDArray with index=%s of type=%s is not supported'
                                 % (str(key), str(type(key))))
            if is_advanced_index:
                # Track the common broadcast shape of all advanced indices.
                if advanced_index_bshape is None:
                    advanced_index_bshape = idx_i.shape
                elif advanced_index_bshape != idx_i.shape:
                    need_broadcast = True
                    advanced_index_bshape = _get_broadcast_shape(advanced_index_bshape, idx_i.shape)
            indices.append(idx_i)
        # Get final index shape for gather_nd. See the following reference
        # for determining the output array shape.
        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing  # pylint: disable=line-too-long
        if len(advanced_indices) == 0:
            raise ValueError('Advanced index tuple must contain at least one of the following types:'
                             ' list, tuple, NDArray, np.ndarray, integer, received index=%s' % key)
        # determine the output array's shape by checking whether advanced_indices are all adjacent
        # or separated by slices
        advanced_indices_adjacent = True
        for i in range(0, len(advanced_indices)-1):
            if advanced_indices[i] + 1 != advanced_indices[i+1]:
                advanced_indices_adjacent = False
                break
        index_bshape_list = []  # index broadcasted shape
        if advanced_indices_adjacent:
            # Adjacent advanced indices: their broadcast shape stays in place,
            # with the slice dimensions on either side.
            for i in range(0, advanced_indices[0]):
                index_bshape_list.extend(indices[i].shape)
                if not need_broadcast and indices[i].shape != advanced_index_bshape:
                    need_broadcast = True
            index_bshape_list.extend(advanced_index_bshape)
            for i in range(advanced_indices[-1]+1, len(indices)):
                if not need_broadcast and indices[i].shape != advanced_index_bshape:
                    need_broadcast = True
                index_bshape_list.extend(indices[i].shape)
        else:
            # Separated advanced indices: their broadcast shape moves to the
            # front, followed by all the slice dimensions.
            index_bshape_list.extend(advanced_index_bshape)
            for i in basic_indices:
                index_bshape_list.extend(indices[i].shape)
                if not need_broadcast and indices[i].shape != advanced_index_bshape:
                    need_broadcast = True
        index_bshape = tuple(index_bshape_list)
        # Need to broadcast all ndarrays in indices to the final shape.
        # For example, suppose an array has shape=(5, 6, 7, 8) and
        # key=(slice(1, 5), [[1, 2]], slice(2, 5), [1]).
        # Since key[1] and key[3] are two advanced indices here and they are
        # separated by basic indices key[0] and key[2], the output shape
        # is (1, 2, 4, 3), where the first two elements come from the shape
        # that key[1] and key[3] should broadcast to, which is (1, 2), and
        # the last two elements come from the shape of two basic indices.
        # In order to broadcast all basic and advanced indices to the output shape,
        # we need to reshape them based on their axis. For example, to broadcast key[0],
        # with shape=(4,), we first need to reshape it into (1, 1, 4, 1), and then
        # broadcast the reshaped array to (1, 2, 4, 3); to broadcast key[1], we first
        # reshape it into (1, 2, 1, 1), then broadcast the reshaped array to (1, 2, 4, 3).
        if need_broadcast:
            broadcasted_indices = []
            idx_rshape = [1] * len(index_bshape)
            if advanced_indices_adjacent:
                advanced_index_bshape_start = advanced_indices[0]  # start index of advanced_index_bshape in index_shape
                advanced_index_bshape_stop = advanced_index_bshape_start + len(advanced_index_bshape)
                for i, idx in enumerate(key):
                    if _is_advanced_index(idx):
                        k = advanced_index_bshape_stop
                        # find the reshaped shape for indices[i]
                        for dim_size in indices[i].shape[::-1]:
                            k -= 1
                            idx_rshape[k] = dim_size
                    else:
                        if i < advanced_indices[0]:  # slice is on the left side of advanced indices
                            idx_rshape[i] = indices[i].shape[0]
                        elif i > advanced_indices[-1]:  # slice is on the right side of advanced indices
                            idx_rshape[i-len(key)] = indices[i].shape[0]
                        else:
                            raise ValueError('basic index i=%d cannot be between advanced index i=%d and i=%d'
                                             % (i, advanced_indices[0], advanced_indices[-1]))
                    # broadcast current index to the final shape
                    broadcasted_indices.append(indices[i].reshape(tuple(idx_rshape)).broadcast_to(index_bshape))
                    # reset idx_rshape to ones
                    for j, _ in enumerate(idx_rshape):
                        idx_rshape[j] = 1
            else:
                basic_index_offset = len(advanced_index_bshape)
                for i, idx in enumerate(key):
                    if _is_advanced_index(idx):
                        k = len(advanced_index_bshape)
                        for dim_size in indices[i].shape[::-1]:
                            k -= 1
                            idx_rshape[k] = dim_size
                    else:
                        idx_rshape[basic_index_offset] = indices[i].shape[0]
                        basic_index_offset += 1
                    # broadcast current index to the final shape
                    broadcasted_indices.append(indices[i].reshape(tuple(idx_rshape)).broadcast_to(index_bshape))
                    # reset idx_rshape to ones
                    for j, _ in enumerate(idx_rshape):
                        idx_rshape[j] = 1
            indices = broadcasted_indices
        # Stack all per-axis index arrays into one coordinate array.
        return op.stack(*indices)
def _prepare_value_nd(self, value, vshape):
"""Given value and vshape, create an `NDArray` from value with the same
context and dtype as the current one and broadcast it to vshape."""
if isinstance(value, numeric_types):
value_nd = full(shape=vshape, val=value, ctx=self.context, dtype=self.dtype)
elif isinstance(value, NDArray):
value_nd = value.as_in_context(self.context)
if value_nd.dtype != self.dtype:
value_nd = value_nd.astype(self.dtype)
else:
try:
value_nd = array(value, ctx=self.context, dtype=self.dtype)
except:
raise TypeError('NDArray does not support assignment with non-array-like'
' object %s of type %s' % (str(value), str(type(value))))
if value_nd.shape != vshape:
value_nd = value_nd.broadcast_to(vshape)
return value_nd
    def _set_nd_basic_indexing(self, key, value):
        """This function is called by __setitem__ when key is a basic index, i.e.
        an integer, or a slice, or a tuple of integers and slices. No restrictions
        on the values of slices' steps."""
        shape = self.shape
        if isinstance(key, integer_types):
            # Single integer: assign through the 1-lower-dim view at that row.
            sliced_arr = self._at(key)
            sliced_arr[:] = value
            return
        elif isinstance(key, py_slice):
            if key.step is None or key.step == 1:  # trivial step
                if key.start is not None or key.stop is not None:
                    sliced_arr = self._slice(key.start, key.stop)
                    sliced_arr[:] = value
                    return
                # assign value to the whole NDArray
                # may need to broadcast first
                if isinstance(value, NDArray):
                    if value.handle is not self.handle:
                        value.copyto(self)
                elif isinstance(value, numeric_types):
                    _internal._full(shape=shape, ctx=self.context,
                                    dtype=self.dtype, value=float(value), out=self)
                elif isinstance(value, (np.ndarray, np.generic)):
                    if isinstance(value, np.generic) or value.shape != shape:
                        value = np.broadcast_to(value, shape)
                    self._sync_copyfrom(value)
                else:  # value might be a list or a tuple
                    value_nd = self._prepare_value_nd(value, shape)
                    value_nd.copyto(self)
                return
            else:  # non-trivial step, use _slice_assign or _slice_assign_scalar
                key = (key,)
        # From here on, key is a tuple of integers/slices; express it as
        # parallel begin/end/step lists for the _slice_assign* operators.
        assert isinstance(key, tuple), "key=%s must be a tuple of slices and integers" % str(key)
        assert len(key) <= len(shape), "Indexing dimensions exceed array dimensions, %d vs %d"\
            % (len(key), len(shape))
        begin = []
        end = []
        steps = []
        oshape = []  # output shape of slice using key
        vshape = []  # value shape of data[key]
        for i, slice_i in enumerate(key):
            dim_size = 1
            if isinstance(slice_i, py_slice):
                begin.append(slice_i.start)
                end.append(slice_i.stop)
                steps.append(slice_i.step)
                start, stop, step = _get_index_range(slice_i.start, slice_i.stop,
                                                     shape[i], slice_i.step)
                dim_size = _get_dim_size(start, stop, step)
                vshape.append(dim_size)
            elif isinstance(slice_i, integer_types):
                begin.append(slice_i)
                end.append(slice_i+1)
                steps.append(1)
            else:
                raise ValueError("basic indexing does not support index=%s of type=%s"
                                 % (str(slice_i), str(type(slice_i))))
            oshape.append(dim_size)
        oshape.extend(shape[len(key):])
        vshape.extend(shape[len(key):])
        # if key contains all integers, vshape should be (1,)
        if len(vshape) == 0:
            vshape.append(1)
        oshape = tuple(oshape)
        vshape = tuple(vshape)
        if isinstance(value, numeric_types):
            _internal._slice_assign_scalar(self, out=self, begin=begin, end=end,
                                           step=steps, scalar=float(value))
        else:
            value_nd = self._prepare_value_nd(value, vshape)
            if vshape != oshape:
                # Integer-indexed axes were dropped from vshape; restore them
                # as size-1 dims so the assignment shape matches the slice.
                value_nd = value_nd.reshape(oshape)
            _internal._slice_assign(self, value_nd, begin, end, steps, out=self)
def _set_nd_advanced_indexing(self, key, value):
"""This function is called by __setitem__ when key is an advanced index."""
indices = self._get_index_nd(key)
vshape = _get_oshape_of_gather_nd_op(self.shape, indices.shape)
value_nd = self._prepare_value_nd(value, vshape)
_internal._scatter_set_nd(data=value_nd, indices=indices, shape=self.shape, out=self)
def _get_nd_basic_indexing(self, key):
"""This function is called when key is a slice, or an integer,
or a tuple of slices or integers"""
shape = self.shape
if isinstance(key, integer_types):
if key > shape[0] - 1:
raise IndexError(
'index {} is out of bounds for axis 0 with size {}'.format(
key, shape[0]))
return self._at(key)
elif isinstance(key, py_slice):
if key.step is not None and key.step != 1:
if key.step == 0:
raise ValueError("slice step cannot be zero")
return op.slice(self, begin=(key.start,), end=(key.stop,), step=(key.step,))
elif key.start is not None or key.stop is not None:
return self._slice(key.start, key.stop)
else:
return self
if not isinstance(key, tuple):
raise ValueError('index=%s must be a slice, or an ineger, or a tuple'
' of slices and integers to use basic indexing, received type=%s'
% (str(key), str(type(key))))
assert len(key) != 0, 'basic index cannot be an empty tuple'
begin = []
end = []
step = []
kept_axes = [] # axes where slice_i is a slice
i = -1
for i, slice_i in enumerate(key):
if isinstance(slice_i, integer_types):
begin.append(slice_i)
end.append(slice_i+1)
step.append(1)
elif isinstance(slice_i, py_slice):
if slice_i.step == 0:
raise ValueError('basic index=%s cannot have slice=%s with step = 0'
% (str(key), str(slice_i)))
begin.append(slice_i.start)
end.append(slice_i.stop)
step.append(slice_i.step)
kept_axes.append(i)
else:
raise ValueError('basic_indexing does not support slicing with '
'index=%s of type=%s.' % (str(slice_i), str(type(slice_i))))
kept_axes.extend(range(i+1, len(shape)))
sliced_nd = op.slice(self, begin, end, step)
if len(kept_axes) == len(shape):
return sliced_nd
# squeeze sliced_shape to remove the axes indexed by integers
oshape = []
sliced_shape = sliced_nd.shape
for axis in kept_axes:
oshape.append(sliced_shape[axis])
# if key is a tuple of integers, still need to keep 1 dim
# while in Numpy, the output will become an value instead of an ndarray
if len(oshape) == 0:
oshape.append(1)
oshape = tuple(oshape)
assert np.prod(oshape) == np.prod(sliced_shape), 'oshape=%s has different size'\
' than sliced_shape=%s'\
% (oshape, sliced_shape)
return sliced_nd.reshape(oshape)
    def _get_nd_advanced_indexing(self, key):
        """Get item when key is a tuple of any objects of the following types:
        NDArray, np.ndarray, list, tuple, slice, and integer."""
        # Build an int32 coordinate array from the key and gather along it.
        return op.gather_nd(self, self._get_index_nd(key))
def _sync_copyfrom(self, source_array):
"""Performs a synchronized copy from the `source_array` to the current array.
This is called through ``x[:] = source_array``, where the `source_array`
is a `numpy.ndarray` or array-like object.
This function blocks until all the pending read/write operations with respect
to the current `NDArray` are finished and carry out the copy operation to the
current NDArray.
Parameters
----------
source_array : array_like
The data source we would like to copy from.
Example
-------
>>> a = mx.nd.array([1, 2])
>>> a.asnumpy()
array([ 1., 2.], dtype=float32)
>>> a[:] = np.array([3, 4])
>> a.asnumpy()
array([ 3., 4.], dtype=float32)
"""
if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=self.dtype)
except:
raise TypeError('array must consist of array-like data,' +
'type %s is not supported' % str(type(array)))
source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
if source_array.shape != self.shape:
raise ValueError('Shape inconsistent: expected %s vs got %s'%(
str(self.shape), str(source_array.shape)))
check_call(_LIB.MXNDArraySyncCopyFromCPU(
self.handle,
source_array.ctypes.data_as(ctypes.c_void_p),
ctypes.c_size_t(source_array.size)))
def _slice(self, start, stop):
"""Returns a sliced NDArray that shares memory with the current one.
This is called through ``x[start:stop]``.
Parameters
----------
start : int
Starting inclusive index of slice in the first dim.
stop : int
Finishing exclusive index of slice in the first dim.
Returns
-------
`NDArray` sharing the memory with the current one sliced from
start to stop in the first dim.
Examples:
>>> a = mx.nd.array([[1,2], [3, 4], [5, 6], [7, 8]])
>>> a[1:2].asnumpy()
array([[ 3., 4.]], dtype=float32)
>>> a[1:1].asnumpy()
array([], shape=(0, 2), dtype=float32)
"""
handle = NDArrayHandle()
start, stop, _ = _get_index_range(start, stop, self.shape[0])
check_call(_LIB.MXNDArraySlice(
self.handle, mx_uint(start), mx_uint(stop), ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)
def _at(self, idx):
"""Returns a view of the array sliced at `idx` in the first dim.
This is called through ``x[idx]``.
Parameters
----------
idx : int
index for slicing the `NDArray` in the first dim.
Returns
-------
NDArray
`NDArray` sharing the memory with the current one sliced at `idx` in the first dim.
Examples
--------
>>> a = mx.nd.array([[1,2], [3, 4]])
>>> a[1].asnumpy()
array([ 3., 4.], dtype=float32)
>>> b = mx.nd.array([1, 2, 3, 4])
>>> b[0].asnumpy()
array([ 1.], dtype=float32)
"""
handle = NDArrayHandle()
if idx < 0:
length = self.shape[0]
idx += length
if idx < 0:
raise IndexError('index %d is out of bounds for axis 0 with size %d'
% (idx-length, length))
check_call(_LIB.MXNDArrayAt(
self.handle, mx_uint(idx), ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)
def reshape(self, *shape, **kwargs):
"""Returns a **view** of this array with a new shape without altering any data.
Parameters
----------
shape : tuple of int, or n ints
The new shape should not change the array size, namely
``np.prod(new_shape)`` should be equal to ``np.prod(self.shape)``.
One dimension can be -1. In this case, the value is inferred
from the length of the array and remaining dimensions.
0 Dimensions in shape will be copied from original shape, i.e.
if x.shape == (3, 4, 5), x.reshape((0, 20)).shape will be (3, 20).
Returns
-------
NDArray
An array with desired shape that shares data with this array.
Examples
--------
>>> x = mx.nd.arange(0,6).reshape((2,3))
>>> x.asnumpy()
array([[ 0., 1., 2.],
[ 3., 4., 5.]], dtype=float32)
>>> y = x.reshape((3,2))
>>> y.asnumpy()
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
>>> y = x.reshape((3,-1))
>>> y.asnumpy()
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
>>> y = x.reshape(3,2)
>>> y.asnumpy()
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
>>> y[:] = -1
>>> x.asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
"""
if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
shape = shape[0]
elif not shape:
shape = kwargs.get('shape')
assert shape, "Shape must be provided."
if len(kwargs) != 1:
raise TypeError("Only 'shape' is supported as keyword argument. Got: {}."
.format(', '.join(kwargs.keys())))
else:
assert not kwargs,\
"Specifying both positional and keyword arguments is not allowed in reshape."
handle = NDArrayHandle()
# Actual reshape
check_call(_LIB.MXNDArrayReshape(self.handle,
len(shape),
c_array_buf(ctypes.c_int, native_array('i', shape)),
ctypes.byref(handle)))