Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: algos.diff with datetimelike and NaT #37140

Merged
merged 3 commits into from
Oct 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions pandas/_libs/algos.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,7 @@ ctypedef fused diff_t:
ctypedef fused out_t:
float32_t
float64_t
int64_t


@cython.boundscheck(False)
Expand All @@ -1204,11 +1205,13 @@ def diff_2d(
ndarray[out_t, ndim=2] out,
Py_ssize_t periods,
int axis,
bint datetimelike=False,
):
cdef:
Py_ssize_t i, j, sx, sy, start, stop
bint f_contig = arr.flags.f_contiguous
# bint f_contig = arr.is_f_contig() # TODO(cython 3)
diff_t left, right

# Disable for unsupported dtype combinations,
# see https://github.com/cython/cython/issues/2646
Expand All @@ -1218,6 +1221,9 @@ def diff_2d(
elif (out_t is float64_t
and (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t)):
raise NotImplementedError
elif out_t is int64_t and diff_t is not int64_t:
jreback marked this conversation as resolved.
Show resolved Hide resolved
# We only have out_t of int64_t if we have datetimelike
raise NotImplementedError
else:
# We put this inside an indented else block to avoid cython build
# warnings about unreachable code
Expand All @@ -1231,15 +1237,31 @@ def diff_2d(
start, stop = 0, sx + periods
for j in range(sy):
for i in range(start, stop):
out[i, j] = arr[i, j] - arr[i - periods, j]
left = arr[i, j]
right = arr[i - periods, j]
if out_t is int64_t and datetimelike:
jreback marked this conversation as resolved.
Show resolved Hide resolved
if left == NPY_NAT or right == NPY_NAT:
out[i, j] = NPY_NAT
else:
out[i, j] = left - right
else:
out[i, j] = left - right
else:
if periods >= 0:
start, stop = periods, sy
else:
start, stop = 0, sy + periods
for j in range(start, stop):
for i in range(sx):
out[i, j] = arr[i, j] - arr[i, j - periods]
left = arr[i, j]
jreback marked this conversation as resolved.
Show resolved Hide resolved
right = arr[i, j - periods]
if out_t is int64_t and datetimelike:
if left == NPY_NAT or right == NPY_NAT:
out[i, j] = NPY_NAT
else:
out[i, j] = left - right
else:
out[i, j] = left - right
else:
if axis == 0:
if periods >= 0:
Expand All @@ -1248,15 +1270,31 @@ def diff_2d(
start, stop = 0, sx + periods
for i in range(start, stop):
for j in range(sy):
out[i, j] = arr[i, j] - arr[i - periods, j]
left = arr[i, j]
right = arr[i - periods, j]
if out_t is int64_t and datetimelike:
if left == NPY_NAT or right == NPY_NAT:
out[i, j] = NPY_NAT
else:
out[i, j] = left - right
else:
out[i, j] = left - right
else:
if periods >= 0:
start, stop = periods, sy
else:
start, stop = 0, sy + periods
for i in range(sx):
for j in range(start, stop):
out[i, j] = arr[i, j] - arr[i, j - periods]
left = arr[i, j]
right = arr[i, j - periods]
if out_t is int64_t and datetimelike:
if left == NPY_NAT or right == NPY_NAT:
out[i, j] = NPY_NAT
else:
out[i, j] = left - right
else:
out[i, j] = left - right


# generated from template
Expand Down
18 changes: 15 additions & 3 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,6 +1908,8 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):

if is_extension_array_dtype(dtype):
if hasattr(arr, f"__{op.__name__}__"):
if axis != 0:
raise ValueError(f"cannot diff {type(arr).__name__} on axis={axis}")
jreback marked this conversation as resolved.
Show resolved Hide resolved
return op(arr, arr.shift(n))
else:
warn(
Expand All @@ -1922,18 +1924,26 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
is_timedelta = False
is_bool = False
if needs_i8_conversion(arr.dtype):
dtype = np.float64
dtype = np.int64
arr = arr.view("i8")
na = iNaT
is_timedelta = True

elif is_bool_dtype(dtype):
# We have to cast in order to be able to hold np.nan
dtype = np.object_
is_bool = True

elif is_integer_dtype(dtype):
# We have to cast in order to be able to hold np.nan
dtype = np.float64

orig_ndim = arr.ndim
if orig_ndim == 1:
# reshape so we can always use algos.diff_2d
arr = arr.reshape(-1, 1)
# TODO: require axis == 0

dtype = np.dtype(dtype)
out_arr = np.empty(arr.shape, dtype=dtype)

Expand All @@ -1944,7 +1954,7 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
if arr.ndim == 2 and arr.dtype.name in _diff_special:
# TODO: can diff_2d dtype specialization troubles be fixed by defining
# out_arr inside diff_2d?
algos.diff_2d(arr, out_arr, n, axis)
algos.diff_2d(arr, out_arr, n, axis, datetimelike=is_timedelta)
else:
# To keep mypy happy, _res_indexer is a list while res_indexer is
# a tuple, ditto for lag_indexer.
Expand Down Expand Up @@ -1978,8 +1988,10 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
out_arr[res_indexer] = arr[res_indexer] - arr[lag_indexer]

if is_timedelta:
out_arr = out_arr.astype("int64").view("timedelta64[ns]")
out_arr = out_arr.view("timedelta64[ns]")

if orig_ndim == 1:
out_arr = out_arr[:, 0]
return out_arr


Expand Down
25 changes: 25 additions & 0 deletions pandas/tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2405,3 +2405,28 @@ def test_index(self):
dtype="timedelta64[ns]",
)
tm.assert_series_equal(algos.mode(idx), exp)


class TestDiff:
@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"])
def test_diff_datetimelike_nat(self, dtype):
# NaT - NaT is NaT, not 0
jreback marked this conversation as resolved.
Show resolved Hide resolved
arr = np.arange(12).astype(np.int64).view(dtype).reshape(3, 4)
arr[:, 2] = arr.dtype.type("NaT", "ns")
result = algos.diff(arr, 1, axis=0)

expected = np.ones(arr.shape, dtype="timedelta64[ns]") * 4
expected[:, 2] = np.timedelta64("NaT", "ns")
expected[0, :] = np.timedelta64("NaT", "ns")

tm.assert_numpy_array_equal(result, expected)

result = algos.diff(arr.T, 1, axis=1)
tm.assert_numpy_array_equal(result, expected.T)

def test_diff_ea_axis(self):
dta = pd.date_range("2016-01-01", periods=3, tz="US/Pacific")._data

msg = "cannot diff DatetimeArray on axis=1"
with pytest.raises(ValueError, match=msg):
algos.diff(dta, 1, axis=1)