Skip to content

Commit

Permalink
TST: restructure internal extension arrays tests (split between /arra…
Browse files Browse the repository at this point in the history
…ys and /extension) (#22026)
  • Loading branch information
jorisvandenbossche committed Sep 6, 2018
1 parent 46abe18 commit 4612312
Show file tree
Hide file tree
Showing 10 changed files with 413 additions and 313 deletions.

Large diffs are not rendered by default.

72 changes: 72 additions & 0 deletions pandas/tests/arrays/test_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
import pytest
import numpy as np

from pandas import Index, IntervalIndex, date_range, timedelta_range
from pandas.core.arrays import IntervalArray
import pandas.util.testing as tm


@pytest.fixture(params=[
(Index([0, 2, 4]), Index([1, 3, 5])),
(Index([0., 1., 2.]), Index([1., 2., 3.])),
(timedelta_range('0 days', periods=3),
timedelta_range('1 day', periods=3)),
(date_range('20170101', periods=3), date_range('20170102', periods=3)),
(date_range('20170101', periods=3, tz='US/Eastern'),
date_range('20170102', periods=3, tz='US/Eastern'))],
ids=lambda x: str(x[0].dtype))
def left_right_dtypes(request):
"""
Fixture for building an IntervalArray from various dtypes
"""
return request.param


class TestMethods(object):

@pytest.mark.parametrize('repeats', [0, 1, 5])
def test_repeat(self, left_right_dtypes, repeats):
left, right = left_right_dtypes
result = IntervalArray.from_arrays(left, right).repeat(repeats)
expected = IntervalArray.from_arrays(
left.repeat(repeats), right.repeat(repeats))
tm.assert_extension_array_equal(result, expected)

@pytest.mark.parametrize('bad_repeats, msg', [
(-1, 'negative dimensions are not allowed'),
('foo', r'invalid literal for (int|long)\(\) with base 10')])
def test_repeat_errors(self, bad_repeats, msg):
array = IntervalArray.from_breaks(range(4))
with tm.assert_raises_regex(ValueError, msg):
array.repeat(bad_repeats)

@pytest.mark.parametrize('new_closed', [
'left', 'right', 'both', 'neither'])
def test_set_closed(self, closed, new_closed):
# GH 21670
array = IntervalArray.from_breaks(range(10), closed=closed)
result = array.set_closed(new_closed)
expected = IntervalArray.from_breaks(range(10), closed=new_closed)
tm.assert_extension_array_equal(result, expected)


class TestSetitem(object):

def test_set_na(self, left_right_dtypes):
left, right = left_right_dtypes
result = IntervalArray.from_arrays(left, right)
result[0] = np.nan

expected_left = Index([left._na_value] + list(left[1:]))
expected_right = Index([right._na_value] + list(right[1:]))
expected = IntervalArray.from_arrays(expected_left, expected_right)

tm.assert_extension_array_equal(result, expected)


def test_repr_matches():
idx = IntervalIndex.from_breaks([1, 2, 3])
a = repr(idx)
b = repr(idx.values)
assert a.replace("Index", "Array") == b
9 changes: 5 additions & 4 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,11 @@ def test_combine_add(self, data_repeated):
s1 = pd.Series(orig_data1)
s2 = pd.Series(orig_data2)
result = s1.combine(s2, lambda x1, x2: x1 + x2)
expected = pd.Series(
orig_data1._from_sequence([a + b for (a, b) in
zip(list(orig_data1),
list(orig_data2))]))
with np.errstate(over='ignore'):
expected = pd.Series(
orig_data1._from_sequence([a + b for (a, b) in
zip(list(orig_data1),
list(orig_data2))]))
self.assert_series_equal(result, expected)

val = s1.iloc[0]
Expand Down
9 changes: 5 additions & 4 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def get_op_from_name(self, op_name):
def check_opname(self, s, op_name, other, exc=NotImplementedError):
op = self.get_op_from_name(op_name)

self._check_op(s, op, other, exc)
self._check_op(s, op, other, op_name, exc)

def _check_op(self, s, op, other, exc=NotImplementedError):
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
if exc is None:
result = op(s, other)
expected = s.combine(other, op)
Expand Down Expand Up @@ -69,7 +69,8 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators):
# ndarray & other series
op_name = all_arithmetic_operators
s = pd.Series(data)
self.check_opname(s, op_name, [s.iloc[0]] * len(s), exc=TypeError)
self.check_opname(s, op_name, pd.Series([s.iloc[0]] * len(s)),
exc=TypeError)

def test_divmod(self, data):
s = pd.Series(data)
Expand Down Expand Up @@ -113,5 +114,5 @@ def test_compare_scalar(self, data, all_compare_operators):
def test_compare_array(self, data, all_compare_operators):
op_name = all_compare_operators
s = pd.Series(data)
other = [0] * len(data)
other = pd.Series([data[0]] * len(data))
self._compare_other(s, data, op_name, other)
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
"""
This file contains a minimal set of tests for compliance with the extension
array interface test suite, and should contain no other tests.
The test suite for the full functionality of the array is located in
`pandas/tests/arrays/`.
The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overwriting a
parent method).
Additional tests should either be added to one of the BaseExtensionTests
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""
import string

import pytest
Expand Down Expand Up @@ -204,10 +219,14 @@ class TestComparisonOps(base.BaseComparisonOpsTests):
def _compare_other(self, s, data, op_name, other):
op = self.get_op_from_name(op_name)
if op_name == '__eq__':
assert not op(data, other).all()
result = op(s, other)
expected = s.combine(other, lambda x, y: x == y)
assert (result == expected).all()

elif op_name == '__ne__':
assert op(data, other).all()
result = op(s, other)
expected = s.combine(other, lambda x, y: x != y)
assert (result == expected).all()

else:
with pytest.raises(TypeError):
Expand Down
229 changes: 229 additions & 0 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""
This file contains a minimal set of tests for compliance with the extension
array interface test suite, and should contain no other tests.
The test suite for the full functionality of the array is located in
`pandas/tests/arrays/`.
The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overwriting a
parent method).
Additional tests should either be added to one of the BaseExtensionTests
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""
import numpy as np
import pandas as pd
import pytest

from pandas.tests.extension import base
from pandas.core.dtypes.common import is_extension_array_dtype

from pandas.core.arrays import IntegerArray, integer_array
from pandas.core.arrays.integer import (
Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype,
UInt8Dtype, UInt16Dtype, UInt32Dtype, UInt64Dtype)


def make_data():
return (list(range(1, 9)) + [np.nan] + list(range(10, 98))
+ [np.nan] + [99, 100])


@pytest.fixture(params=[Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype,
UInt8Dtype, UInt16Dtype, UInt32Dtype, UInt64Dtype])
def dtype(request):
return request.param()


@pytest.fixture
def data(dtype):
return integer_array(make_data(), dtype=dtype)


@pytest.fixture
def data_missing(dtype):
return integer_array([np.nan, 1], dtype=dtype)


@pytest.fixture
def data_repeated(data):
def gen(count):
for _ in range(count):
yield data
yield gen


@pytest.fixture
def data_for_sorting(dtype):
return integer_array([1, 2, 0], dtype=dtype)


@pytest.fixture
def data_missing_for_sorting(dtype):
return integer_array([1, np.nan, 0], dtype=dtype)


@pytest.fixture
def na_cmp():
# we are np.nan
return lambda x, y: np.isnan(x) and np.isnan(y)


@pytest.fixture
def na_value():
return np.nan


@pytest.fixture
def data_for_grouping(dtype):
b = 1
a = 0
c = 2
na = np.nan
return integer_array([b, b, na, na, a, a, b, c], dtype=dtype)


class TestDtype(base.BaseDtypeTests):

@pytest.mark.skip(reason="using multiple dtypes")
def test_is_dtype_unboxes_dtype(self):
# we have multiple dtypes, so skip
pass

def test_array_type_with_arg(self, data, dtype):
assert dtype.construct_array_type() is IntegerArray


class TestArithmeticOps(base.BaseArithmeticOpsTests):

def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
super(TestArithmeticOps, self).check_opname(s, op_name,
other, exc=None)

def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
if exc is None:
if s.dtype.is_unsigned_integer and (op_name == '__rsub__'):
# TODO see https://github.com/pandas-dev/pandas/issues/22023
pytest.skip("unsigned subtraction gives negative values")

if (hasattr(other, 'dtype')
and not is_extension_array_dtype(other.dtype)
and pd.api.types.is_integer_dtype(other.dtype)):
# other is np.int64 and would therefore always result in
# upcasting, so keeping other as same numpy_dtype
other = other.astype(s.dtype.numpy_dtype)

result = op(s, other)
expected = s.combine(other, op)

if op_name == '__rdiv__':
# combine is not giving the correct result for this case
pytest.skip("skipping reverse div in python 2")
elif op_name in ('__rtruediv__', '__truediv__', '__div__'):
expected = expected.astype(float)
if op_name == '__rtruediv__':
# TODO reverse operators result in object dtype
result = result.astype(float)
elif op_name.startswith('__r'):
# TODO reverse operators result in object dtype
# see https://github.com/pandas-dev/pandas/issues/22024
expected = expected.astype(s.dtype)
result = result.astype(s.dtype)
else:
# combine method result in 'biggest' (int64) dtype
expected = expected.astype(s.dtype)
pass
if (op_name == '__rpow__') and isinstance(other, pd.Series):
# TODO pow on Int arrays gives different result with NA
# see https://github.com/pandas-dev/pandas/issues/22022
result = result.fillna(1)

self.assert_series_equal(result, expected)
else:
with pytest.raises(exc):
op(s, other)

def _check_divmod_op(self, s, op, other, exc=None):
super(TestArithmeticOps, self)._check_divmod_op(s, op, other, None)

@pytest.mark.skip(reason="intNA does not error on ops")
def test_error(self, data, all_arithmetic_operators):
# other specific errors tested in the integer array specific tests
pass


class TestComparisonOps(base.BaseComparisonOpsTests):

def check_opname(self, s, op_name, other, exc=None):
super(TestComparisonOps, self).check_opname(s, op_name,
other, exc=None)

def _compare_other(self, s, data, op_name, other):
self.check_opname(s, op_name, other)


class TestInterface(base.BaseInterfaceTests):
pass


class TestConstructors(base.BaseConstructorsTests):
pass


class TestReshaping(base.BaseReshapingTests):
pass

# for test_concat_mixed_dtypes test
# concat of an Integer and Int coerces to object dtype
# TODO(jreback) once integrated this would


class TestGetitem(base.BaseGetitemTests):
pass


class TestMissing(base.BaseMissingTests):
pass


class TestMethods(base.BaseMethodsTests):

@pytest.mark.parametrize('dropna', [True, False])
def test_value_counts(self, all_data, dropna):
all_data = all_data[:10]
if dropna:
other = np.array(all_data[~all_data.isna()])
else:
other = all_data

result = pd.Series(all_data).value_counts(dropna=dropna).sort_index()
expected = pd.Series(other).value_counts(
dropna=dropna).sort_index()
expected.index = expected.index.astype(all_data.dtype)

self.assert_series_equal(result, expected)


class TestCasting(base.BaseCastingTests):
pass


class TestGroupby(base.BaseGroupbyTests):

@pytest.mark.xfail(reason="groupby not working", strict=True)
def test_groupby_extension_no_sort(self, data_for_grouping):
super(TestGroupby, self).test_groupby_extension_no_sort(
data_for_grouping)

@pytest.mark.parametrize('as_index', [
pytest.param(True,
marks=pytest.mark.xfail(reason="groupby not working",
strict=True)),
False
])
def test_groupby_extension_agg(self, as_index, data_for_grouping):
super(TestGroupby, self).test_groupby_extension_agg(
as_index, data_for_grouping)

0 comments on commit 4612312

Please sign in to comment.