forked from numba/numba
-
Notifications
You must be signed in to change notification settings - Fork 2
/
numpy_support.py
533 lines (436 loc) · 17.5 KB
/
numpy_support.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
from __future__ import print_function, division, absolute_import
import collections
import ctypes
import re
import numpy as np
from . import errors, types, config, npdatetime, utils
# (major, minor) of the installed NumPy, e.g. (1, 10); any further version
# components (micro number, rc/dev tags) are ignored.
version = tuple(int(component) for component in np.__version__.split('.')[:2])
# True when config.PYVERSION <= (3, 0) (presumably Python 2); per its name,
# integer division by zero yields zero there -- NOTE(review): confirm at use.
int_divbyzero_returns_zero = config.PYVERSION <= (3, 0)
# Starting from Numpy 1.10, ufuncs accept argument conversion according
# to the "same_kind" rule (used to be "unsafe").
strict_ufunc_typing = version >= (1, 10)
# NumPy scalar dtypes that have a direct Numba type counterpart, keyed by
# the np.dtype instance so lookups can use a dtype object directly.
FROM_DTYPE = dict(
    (np.dtype(dtype_name), numba_type)
    for dtype_name, numba_type in (
        ('bool', types.boolean),
        ('int8', types.int8),
        ('int16', types.int16),
        ('int32', types.int32),
        ('int64', types.int64),
        ('uint8', types.uint8),
        ('uint16', types.uint16),
        ('uint32', types.uint32),
        ('uint64', types.uint64),
        ('float32', types.float32),
        ('float64', types.float64),
        ('complex64', types.complex64),
        ('complex128', types.complex128),
    )
)
# Size in bytes of a single character of a NumPy unicode ('U') dtype.
sizeof_unicode_char = np.dtype('U1').itemsize
# Matches dtype.str of flexible types such as '<U5' or '|S3':
# group 1 is the type letter, group 2 the optional length.
re_typestr = re.compile(r'[<>=\|]([a-z])(\d+)?$', re.IGNORECASE)
# Matches dtype.str of datetime64/timedelta64 types such as '<M8[ns]':
# group 1 is 'm' or 'M', group 3 the optional unit inside brackets.
re_datetimestr = re.compile(r'[<>=\|]([mM])8?(\[([a-z]+)\])?$', re.IGNORECASE)
def _from_str_dtype(dtype):
    """
    Convert a fixed-width NumPy string dtype into a Numba sequence type.

    'U' (unicode) dtypes become types.UnicodeCharSeq and 'S' (bytes) dtypes
    become types.CharSeq, both sized by character count.  Unicode dtypes with
    a non-native byte order, and any dtype whose ``str`` does not parse as a
    'U'/'S' type string, raise NotImplementedError.
    """
    match = re_typestr.match(dtype.str)
    if not match:
        raise NotImplementedError(dtype)
    letter, declared_length = match.groups()

    if letter == 'U':
        # unicode: only native ('=') or not-applicable ('|') byte order
        if dtype.byteorder not in '=|':
            raise NotImplementedError("Does not support non-native "
                                      "byteorder")
        nchars = dtype.itemsize // sizeof_unicode_char
        # sanity check: the length encoded in dtype.str must agree
        assert nchars == int(declared_length), "Unicode char size mismatch"
        return types.UnicodeCharSeq(nchars)

    if letter == 'S':
        # bytes: one byte per character
        nchars = dtype.itemsize
        assert nchars == int(declared_length), "Char size mismatch"
        return types.CharSeq(nchars)

    raise NotImplementedError(dtype)
def _from_datetime_dtype(dtype):
    """
    Convert a NumPy datetime64 ('M') or timedelta64 ('m') dtype into the
    matching Numba type, preserving its unit ('' when the dtype is generic).
    NotImplementedError is raised if ``dtype.str`` does not parse.
    """
    match = re_datetimestr.match(dtype.str)
    if not match:
        raise NotImplementedError(dtype)
    letter, _bracketed, unit = match.groups()
    constructor = {
        'm': types.NPTimedelta,
        'M': types.NPDatetime,
    }.get(letter)
    if constructor is None:
        raise NotImplementedError(dtype)
    return constructor(unit or '')
def from_dtype(dtype):
    """
    Return a Numba Type instance corresponding to the given Numpy *dtype*.

    - structured dtypes (with ``fields``) are delegated to from_struct_dtype;
    - scalar dtypes listed in FROM_DTYPE map directly;
    - 'S'/'U' string dtypes map to CharSeq/UnicodeCharSeq;
    - 'm'/'M' dtypes map to NPTimedelta/NPDatetime (unit preserved);
    - sub-array dtypes (void with a ``subdtype``) map to NestedArray.

    NotImplementedError is raised on unsupported Numpy dtypes, including
    unstructured void dtypes such as 'V8' that carry no sub-array element type.
    """
    if dtype.fields is not None:
        return from_struct_dtype(dtype)

    try:
        return FROM_DTYPE[dtype]
    except KeyError:
        pass

    char = dtype.char
    if char in 'SU':
        return _from_str_dtype(dtype)
    if char in 'mM':
        return _from_datetime_dtype(dtype)
    if char == 'V' and dtype.subdtype is not None:
        # Sub-array dtype: subdtype is (element dtype, shape).  A plain
        # void dtype has subdtype None and falls through to the error
        # below instead of failing with a TypeError on indexing None.
        subtype = from_dtype(dtype.subdtype[0])
        return types.NestedArray(subtype, dtype.shape)
    raise NotImplementedError(dtype)
# NumPy dtype-string prefixes for Numba types whose dtype is spelled as a
# letter plus a unit (datetime/timedelta) or a length (byte/unicode strings).
_as_dtype_letters = dict([
    (types.NPDatetime, 'M8'),
    (types.NPTimedelta, 'm8'),
    (types.CharSeq, 'S'),
    (types.UnicodeCharSeq, 'U'),
])
def as_dtype(nbtype):
    """
    Return a numpy dtype instance corresponding to the given Numba type.
    NotImplementedError is raised if no correspondence is known.

    Literal types are reduced with types.unliteral first.  Integer, float and
    complex types go through ``np.dtype(str(nbtype))``; datetime/timedelta
    keep their unit; CharSeq/UnicodeCharSeq keep their length; records
    return their stored dtype; enum members recurse on ``nbtype.dtype``.
    """
    nbtype = types.unliteral(nbtype)

    if isinstance(nbtype, (types.Complex, types.Integer, types.Float)):
        return np.dtype(str(nbtype))

    if nbtype is types.bool_:
        return np.dtype('?')

    if isinstance(nbtype, (types.NPDatetime, types.NPTimedelta)):
        prefix = _as_dtype_letters[type(nbtype)]
        if not nbtype.unit:
            # generic (unit-less) datetime64 / timedelta64
            return np.dtype(prefix)
        return np.dtype('%s[%s]' % (prefix, nbtype.unit))

    if isinstance(nbtype, (types.CharSeq, types.UnicodeCharSeq)):
        prefix = _as_dtype_letters[type(nbtype)]
        return np.dtype('%s%d' % (prefix, nbtype.count))

    if isinstance(nbtype, types.Record):
        return nbtype.dtype

    if isinstance(nbtype, types.EnumMember):
        return as_dtype(nbtype.dtype)

    raise NotImplementedError("%r cannot be represented as a Numpy dtype"
                              % (nbtype,))
def is_arrayscalar(val):
    """
    Return whether the Python type of *val* converts to a NumPy dtype that
    is listed in FROM_DTYPE (i.e. a scalar dtype Numba supports directly).

    Builtin Python scalar types also convert (e.g. np.dtype(float) is
    float64), so they count as well.  Types NumPy cannot interpret as a
    dtype yield False rather than propagating NumPy's TypeError.
    """
    try:
        dtype = np.dtype(type(val))
    except TypeError:
        # Recent NumPy refuses to turn arbitrary classes into a dtype;
        # map_arrayscalar_type guards the same conversion the same way.
        return False
    return dtype in FROM_DTYPE
def map_arrayscalar_type(val):
    """
    Return the Numba type for the scalar value *val*.

    NumPy scalars (np.generic) provide their own dtype; any other value is
    mapped through the dtype of its Python type.  NotImplementedError is
    raised when NumPy cannot interpret that type as a dtype, and propagates
    from from_dtype when the dtype has no Numba equivalent.
    """
    if not isinstance(val, np.generic):
        try:
            dtype = np.dtype(type(val))
        except TypeError:
            raise NotImplementedError("no corresponding numpy dtype for %r" % type(val))
    else:
        # We can't blindly call np.dtype() as it loses information
        # on some types, e.g. datetime64 and timedelta64.
        dtype = val.dtype
    return from_dtype(dtype)
def is_array(val):
    """Return True if *val* is a NumPy ndarray (subclass instances included)."""
    ndarray_cls = np.ndarray
    return isinstance(val, ndarray_cls)
def map_layout(val):
    """
    Return the Numba layout letter for the ndarray *val*:
    'C' if C-contiguous (checked first, so arrays that are both map to 'C'),
    'F' if only Fortran-contiguous, and 'A' (any) otherwise.
    """
    flags = val.flags
    if flags['C_CONTIGUOUS']:
        return 'C'
    if flags['F_CONTIGUOUS']:
        return 'F'
    return 'A'
def select_array_wrapper(inputs):
    """
    Given the array-compatible input types to an operation (e.g. ufunc),
    select the appropriate input for wrapping the operation output,
    according to each input's __array_priority__.
    An index into *inputs* is returned.

    Inputs that are not types.ArrayCompatible are ignored.  At least one
    array-compatible input is required (checked with an assert).
    """
    max_prio = float('-inf')
    selected_index = None
    for index, ty in enumerate(inputs):
        # Ties are broken by choosing the first winner, as in Numpy:
        # the strict '>' keeps the earliest input among equal priorities.
        if isinstance(ty, types.ArrayCompatible) and ty.array_priority > max_prio:
            selected_index = index
            max_prio = ty.array_priority
    assert selected_index is not None
    return selected_index
def resolve_output_type(context, inputs, formal_output):
    """
    Given the array-compatible input types to an operation (e.g. ufunc),
    and the operation's formal output type (a types.Array instance),
    resolve the actual output type using the typing *context*.

    The input chosen by select_array_wrapper() is passed with the formal
    output to the '__array_wrap__' typing resolution, mirroring Numpy's
    __array_priority__ / __array_wrap__ protocol.  Raises
    errors.TypingError when resolution fails and the selected input does
    not share types.Array's priority.
    """
    wrapper_input = inputs[select_array_wrapper(inputs)]
    args = wrapper_input, formal_output
    sig = context.resolve_function_type('__array_wrap__', args, {})
    if sig is not None:
        return sig.return_type
    if wrapper_input.array_priority == types.Array.array_priority:
        # Same priority as a regular array: return the output unchanged.
        # (we can't define __array_wrap__ explicitly for types.Buffer,
        # as that would be inherited by most array-compatible objects)
        return formal_output
    raise errors.TypingError("__array_wrap__ failed for %s" % (args,))
def supported_ufunc_loop(ufunc, loop):
    """Return whether the *loop* for the *ufunc* is supported -in nopython-.

    *loop* should be a UFuncLoopSpec instance, and *ufunc* a numpy ufunc.

    Ufuncs that have an entry in ufunc_db are supported exactly when that
    entry contains a lowering definition for the loop's signature.

    Ufuncs without an entry (ufunc_db raises KeyError) fall back to a legacy
    type-based rule: every input and output dtype letter must be one of
    '?bBhHiIlLqQfd'.  New ufuncs should be added to ufunc_db instead, as it
    allows a more fine-grained, incremental support.
    """
    from .targets import ufunc_db

    loop_sig = loop.ufunc_sig
    try:
        # Not every ufunc has an entry in ufunc_db yet.
        return loop_sig in ufunc_db.get_ufunc_info(ufunc)
    except KeyError:
        pass

    # Legacy path: decide purely from the dtypes taking part in the loop.
    allowed_letters = '?bBhHiIlLqQfd'
    involved = loop.numpy_inputs + loop.numpy_outputs
    return all(dt.char in allowed_letters for dt in involved)
_UFuncLoopSpecBase = None  # placeholder removed below; keeps namespace clean


class UFuncLoopSpec(collections.namedtuple('_UFuncLoopSpec',
                                           ('inputs', 'outputs', 'ufunc_sig'))):
    """
    Immutable description of the inner types of one ufunc loop.

    Fields:
    - inputs: the inputs' Numba types
    - outputs: the outputs' Numba types
    - ufunc_sig: the string representing the ufunc's type signature, in
      Numpy format (e.g. "ii->i")
    """
    # no per-instance __dict__, like the namedtuple base
    __slots__ = ()

    @property
    def numpy_inputs(self):
        """NumPy dtypes of the inputs (as_dtype applied to each)."""
        return list(map(as_dtype, self.inputs))

    @property
    def numpy_outputs(self):
        """NumPy dtypes of the outputs (as_dtype applied to each)."""
        return list(map(as_dtype, self.outputs))


del _UFuncLoopSpecBase
def ufunc_can_cast(from_, to, has_mixed_inputs, casting='safe'):
    """
    Like np.can_cast(), except that when the operation has mixed-kind
    (integer and inexact) inputs, any integer is allowed to cast to any
    real or complex type.

    This lets `np.power(float32, int32)` be computed with SP arithmetic and
    return `float32`, while `np.sqrt(int32)` (no mixed inputs) still uses DP
    arithmetic and returns `float64`.

    *from_* and *to* are anything np.dtype() accepts; *casting* is forwarded
    to np.can_cast().
    """
    src = np.dtype(from_)
    dst = np.dtype(to)
    int_to_inexact = src.kind in 'iu' and dst.kind in 'cf'
    if has_mixed_inputs and int_to_inexact:
        return True
    return np.can_cast(src, dst, casting)
def ufunc_find_matching_loop(ufunc, arg_types):
    """Find the appropriate loop to be used for a ufunc based on the types
    of the operands

    ufunc        - The ufunc we want to check
    arg_types    - The tuple of arguments to the ufunc, including any
                   explicit output(s).
    return value - A UFuncLoopSpec identifying the loop, or None
                   if no matching loop is found.

    Like NumPy, ufunc.types is scanned first to last and the first viable
    loop wins.  A loop is viable when every input can be safely cast to the
    loop's input type.  When strict_ufunc_typing is set (NumPy 1.10+), the
    loop's outputs must also cast back to any explicit output under
    "same_kind".  Datetime64/timedelta64 operands need an exact letter match
    so their units are retained.
    """
    n_in = ufunc.nin
    # Separate logical input from explicit output arguments
    input_types = arg_types[:n_in]
    output_types = arg_types[n_in:]
    assert(len(input_types) == n_in)

    try:
        np_input_types = [as_dtype(x) for x in input_types]
        np_output_types = [as_dtype(x) for x in output_types]
    except NotImplementedError:
        return None

    # Whether the inputs are mixed integer / floating-point
    input_kinds = [dt.kind for dt in np_input_types]
    has_mixed_inputs = (any(k in 'iu' for k in input_kinds) and
                        any(k in 'cf' for k in input_kinds))

    def input_matches(outer, inner):
        # outer is a dtype instance, inner is a ufunc type letter.
        if outer.char in 'mM' or inner in 'mM':
            # For datetime64 and timedelta64, we want to retain precise
            # typing (i.e. the units); therefore we look for an exact match.
            return outer.char == inner
        return ufunc_can_cast(outer.char, inner, has_mixed_inputs, 'safe')

    def output_matches(outer, inner):
        # Can the loop's inner result be cast to the explicit output type?
        if outer.char in 'mM':
            return True
        return ufunc_can_cast(inner, outer.char, has_mixed_inputs,
                              'same_kind')

    def choose_types(numba_types, ufunc_letters):
        """
        Return a list of Numba types representing *ufunc_letters*,
        except when the letter designates a datetime64 or timedelta64,
        in which case the type is taken from *numba_types*.
        """
        assert len(ufunc_letters) >= len(numba_types)
        chosen = [tp if letter in 'mM' else from_dtype(np.dtype(letter))
                  for tp, letter in zip(numba_types, ufunc_letters)]
        # Add missing types (presumably implicit outputs)
        chosen.extend(from_dtype(np.dtype(letter))
                      for letter in ufunc_letters[len(numba_types):])
        return chosen

    for candidate in ufunc.types:
        ufunc_inputs = candidate[:n_in]
        ufunc_outputs = candidate[-ufunc.nout:]
        if 'O' in ufunc_inputs:
            # Skip object arrays
            continue
        if not all(input_matches(outer, inner)
                   for outer, inner in zip(np_input_types, ufunc_inputs)):
            continue
        if strict_ufunc_typing and not all(
                output_matches(outer, inner)
                for outer, inner in zip(np_output_types, ufunc_outputs)):
            continue
        try:
            inputs = choose_types(input_types, ufunc_inputs)
            outputs = choose_types(output_types, ufunc_outputs)
        except NotImplementedError:
            # One of the selected dtypes isn't supported by Numba
            # (e.g. float16), try other candidates
            continue
        return UFuncLoopSpec(inputs, outputs, candidate)

    return None
def _is_aligned_struct(struct):
return struct.isalignedstruct
def from_struct_dtype(dtype):
    """
    Convert a structured NumPy *dtype* into a types.Record.

    Each field maps to ``(numba_type, byte_offset)``.  The record also
    receives ``str(dtype.descr)``, ``dtype.itemsize``, the aligned-struct
    flag and the original dtype.  Raises TypeError if the dtype contains
    Python objects.
    """
    if dtype.hasobject:
        raise TypeError("Do not support dtype containing object")

    # Each dtype.fields value is (dtype, offset) or (dtype, offset, title);
    # only the first two entries are needed.
    fields = {
        name: (from_dtype(info[0]), info[1])
        for name, info in dtype.fields.items()
    }

    # Note: dtype.alignment is not consistent.
    # It is different after passing into a recarray.
    # recarray(N, dtype=mydtype).dtype.alignment != mydtype.alignment
    return types.Record(str(dtype.descr), fields, dtype.itemsize,
                        _is_aligned_struct(dtype), dtype)
def _get_bytes_buffer(ptr, nbytes):
"""
Get a ctypes array of *nbytes* starting at *ptr*.
"""
if isinstance(ptr, ctypes.c_void_p):
ptr = ptr.value
arrty = ctypes.c_byte * nbytes
return arrty.from_address(ptr)
def _get_array_from_ptr(ptr, nbytes, dtype):
    """Return a 1-D array of *dtype* viewing the *nbytes* bytes at *ptr*."""
    raw_bytes = _get_bytes_buffer(ptr, nbytes)
    return np.frombuffer(raw_bytes, dtype)
def carray(ptr, shape, dtype=None):
    """
    Return a Numpy array view over the data pointed to by *ptr* with the
    given *shape*, in C order.  If *dtype* is given, it is used as the
    array's dtype, otherwise the array's dtype is inferred from *ptr*'s type.

    *ptr* may be a ctypes pointer, a c_void_p, or any object exposing
    ``_as_parameter_``.  Raises TypeError when a void pointer comes without
    a dtype, when *dtype* disagrees with a typed pointer's element type,
    or when *ptr* is not a ctypes pointer.
    """
    from .typing.ctypes_utils import from_ctypes

    try:
        # Use ctypes parameter protocol if available
        ptr = ptr._as_parameter_
    except AttributeError:
        pass

    # Normalize dtype, to accept e.g. "int64" or np.int64
    if dtype is not None:
        dtype = np.dtype(dtype)

    if isinstance(ptr, ctypes.c_void_p):
        if dtype is None:
            raise TypeError("explicit dtype required for void* argument")
        p = ptr
    elif isinstance(ptr, ctypes._Pointer):
        ptrty = from_ctypes(ptr.__class__)
        assert isinstance(ptrty, types.CPointer)
        ptr_dtype = as_dtype(ptrty.dtype)
        if dtype is not None and dtype != ptr_dtype:
            raise TypeError("mismatching dtype '%s' for pointer %s"
                            % (dtype, ptr))
        dtype = ptr_dtype
        p = ctypes.cast(ptr, ctypes.c_void_p)
    else:
        raise TypeError("expected a ctypes pointer, got %r" % (ptr,))

    # np.prod, not np.product: the latter was deprecated in NumPy 1.25 and
    # removed in NumPy 2.0.  int() hands ctypes a plain Python length.
    nbytes = dtype.itemsize * int(np.prod(shape, dtype=np.intp))
    return _get_array_from_ptr(p, nbytes, dtype).reshape(shape)
def farray(ptr, shape, dtype=None):
    """
    Return a Numpy array view over the data pointed to by *ptr* with the
    given *shape*, in Fortran order.  If *dtype* is given, it is used as the
    array's dtype, otherwise the array's dtype is inferred from *ptr*'s type.

    Implemented as the transpose of a C-ordered view with the reversed
    shape; an integer *shape* (1-D case) is passed through unchanged.
    """
    if isinstance(shape, utils.INT_TYPES):
        c_shape = shape
    else:
        c_shape = shape[::-1]
    return carray(ptr, c_shape, dtype).T
def is_contiguous(dims, strides, itemsize):
    """Is the given shape, strides, and itemsize of C layout?

    Axes of extent 0 or 1 at either end are skipped.  The innermost
    remaining axis must have stride *itemsize*, and each outer axis' stride
    must equal the next-inner stride times that inner axis' extent.

    Note: The code is usable as a numba-compiled function
    """
    nd = len(dims)
    # Scan backwards for the innermost axis with extent > 1.
    last = nd - 1
    while last >= 0 and dims[last] <= 1:
        last -= 1
    if last < 0:
        # Every axis has extent 0 or 1
        return True
    if strides[last] != itemsize:
        return False
    # Scan forwards for the outermost axis with extent > 1.
    first = 0
    while first < last and dims[first] <= 1:
        first += 1
    for ax in range(last, first, -1):
        if strides[ax - 1] != strides[ax] * dims[ax]:
            return False
    return True
def is_fortran(dims, strides, itemsize):
    """Is the given shape, strides, and itemsize of F layout?

    Axes of extent 0 or 1 at either end are skipped.  The first remaining
    axis must have stride *itemsize*, and each following axis' stride must
    equal the previous stride times the previous axis' extent.

    Note: The code is usable as a numba-compiled function
    """
    nd = len(dims)
    # Scan forwards for the first axis with extent > 1.
    first = 0
    while first < nd and dims[first] <= 1:
        first += 1
    if first >= nd:
        # Every axis has extent 0 or 1
        return True
    if strides[first] != itemsize:
        return False
    # Scan backwards for the last axis with extent > 1.
    last = nd - 1
    while last > first and dims[last] <= 1:
        last -= 1
    for ax in range(first, last):
        if strides[ax + 1] != strides[ax] * dims[ax]:
            return False
    return True