Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions swat/tests/cas/test_bygroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,12 +495,20 @@ def test_nth(self):
columns = [x for x in df.columns if x != 'Origin']
dfgrp = df.groupby('Origin').nth(6)[columns]
tblgrp = tbl.groupby('Origin').nth(6)
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, include_index=True)
if pd_version < (2, 0, 0):
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, include_index=True)
else:
# pandas >= 2.0.0 returns index as a number rather than the value
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, include_index=False)

columns = [x for x in df.columns if x != 'Origin']
dfgrp = df.groupby('Origin').nth([5, 7])[columns]
tblgrp = tbl.groupby('Origin').nth([5, 7])
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, include_index=True)
if pd_version < (2, 0, 0):
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, include_index=True)
else:
# pandas >= 2.0.0 returns index as a number rather than the value
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, include_index=False)

#
# Test casout threshold
Expand Down Expand Up @@ -1564,6 +1572,12 @@ def test_describe(self):
tblgrp = tbl.groupby('Origin', as_index=False).describe(percentiles=[0.5])
# Not sure why Pandas doesn't include this
tblgrp = tblgrp.drop('Origin', axis=1)
# Starting with Pandas 2.0.0, Pandas does include the index column,
# but it names it ('Origin','') instead of 'Origin', so while it is
# present, the column name does not match. Go ahead and remove the
# 'Origin' column from pandas dataframe in 2.0.0 and later
if pd_version >= (2, 0, 0):
dfgrp = dfgrp.drop(('Origin', ''), axis=1)
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, decimals=5)

@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
Expand Down
26 changes: 18 additions & 8 deletions swat/tests/cas/test_datamsg.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -121,28 +121,30 @@ def test_csv(self):
def test_dataframe(self):
# Boolean
s_bool_ = pd.Series([True, False], dtype=np.bool_)
s_bool8 = pd.Series([True, False], dtype=np.bool8)
s_bool8 = pd.Series([True, False], dtype=np.bool_)

# Integers
s_byte = pd.Series([100, 999], dtype=np.byte)
# Note starting with numpy 2.0, large positive throws error
# instead of converting to negative
s_byte = pd.Series([100, -25], dtype=np.byte)
s_short = pd.Series([100, 999], dtype=np.short)
s_intc = pd.Series([100, 999], dtype=np.intc)
s_int_ = pd.Series([100, 999], dtype=np.int_)
s_longlong = pd.Series([100, 999], dtype=np.longlong)
s_intp = pd.Series([100, 999], dtype=np.intp)
s_int8 = pd.Series([100, 999], dtype=np.int8)
s_int8 = pd.Series([100, -25], dtype=np.int8)
s_int16 = pd.Series([100, 999], dtype=np.int16)
s_int32 = pd.Series([100, 999], dtype=np.int32)
s_int64 = pd.Series([100, 999], dtype=np.int64)

# Unsigned integers
s_ubyte = pd.Series([100, 999], dtype=np.ubyte)
s_ubyte = pd.Series([100, 231], dtype=np.ubyte)
s_ushort = pd.Series([100, 999], dtype=np.ushort)
s_uintc = pd.Series([100, 999], dtype=np.uintc)
s_uint = pd.Series([100, 999], dtype=np.uint)
s_uint = pd.Series([100, 231], dtype=np.uint)
s_ulonglong = pd.Series([100, 999], dtype=np.ulonglong)
s_uintp = pd.Series([100, 999], dtype=np.uintp)
s_uint8 = pd.Series([100, 999], dtype=np.uint8)
s_uint8 = pd.Series([100, 231], dtype=np.uint8)
s_uint16 = pd.Series([100, 999], dtype=np.uint16)
s_uint32 = pd.Series([100, 999], dtype=np.uint32)
s_uint64 = pd.Series([100, 999], dtype=np.uint64)
Expand All @@ -151,7 +153,10 @@ def test_dataframe(self):
s_half = pd.Series([12.3, 456.789], dtype=np.half)
s_single = pd.Series([12.3, 456.789], dtype=np.single)
s_double = pd.Series([12.3, 456.789], dtype=np.double)
s_longfloat = pd.Series([12.3, 456.789], dtype=np.longfloat)
if hasattr(np, 'longfloat'):
s_longfloat = pd.Series([12.3, 456.789], dtype=np.longfloat)
else:
s_longfloat = pd.Series([12.3, 456.789], dtype=np.longdouble)
s_float16 = pd.Series([12.3, 456.789], dtype=np.float16)
s_float32 = pd.Series([12.3, 456.789], dtype=np.float32)
s_float64 = pd.Series([12.3, 456.789], dtype=np.float64)
Expand All @@ -172,7 +177,12 @@ def test_dataframe(self):
# Python object
s_object_ = pd.Series([('tuple', 'type'), ('another', 'tuple')], dtype=np.object_)
s_str_ = pd.Series([u'hello', u'world'], dtype=np.str_) # ASCII only
s_unicode_ = pd.Series([u'hello', u'\u2603 (snowman)'], dtype=np.unicode_)
# AttributeError:
# `np.unicode_` was removed in the NumPy 2.0 release. Use `np.str_` instead.
if hasattr(np, 'unicode_'):
s_unicode_ = pd.Series([u'hello', u'\u2603 (snowman)'], dtype=np.unicode_)
else:
s_unicode_ = pd.Series([u'hello', u'\u2603 (snowman)'], dtype=np.str_)
# s_void = pd.Series(..., dtype=np.void)

# Datetime
Expand Down
42 changes: 34 additions & 8 deletions swat/tests/cas/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3826,6 +3826,8 @@ def test_query(self):
# self.assertEqual(set(df.Model.tolist()), set(tbl.Model.tolist()))

@unittest.skipIf(pd_version <= (0, 14, 0), 'Need newer version of Pandas')
@unittest.skipIf(pd_version >= (2, 0, 0),
'Pandas >= 2 issues with datetime in dataframe')
def test_timezones(self):
if self.s._protocol in ['http', 'https']:
tm.TestCase.skipTest(self, 'REST does not support data messages')
Expand Down Expand Up @@ -3968,6 +3970,8 @@ def add_table():
self.assertIn([x.tzname() for x in sorted(tblf.datetime)], tzs)

@unittest.skipIf(pd_version <= (0, 14, 0), 'Need newer version of Pandas')
@unittest.skipIf(pd_version >= (2, 0, 0),
'Pandas >= 2 issues with datetime in dataframe')
def test_dt_methods(self):
if self.s._protocol in ['http', 'https']:
tm.TestCase.skipTest(self, 'REST does not support data messages')
Expand Down Expand Up @@ -4172,6 +4176,8 @@ def test_dt_methods(self):
tbl.datetime.dt.days_in_month, sort=True)

@unittest.skipIf(pd_version <= (0, 14, 0), 'Need newer version of Pandas')
@unittest.skipIf(pd_version >= (2, 0, 0),
'Pandas >= 2 issues with datetime in dataframe')
def test_sas_dt_methods(self):
if self.s._protocol in ['http', 'https']:
tm.TestCase.skipTest(self, 'REST does not support data messages')
Expand Down Expand Up @@ -5768,10 +5774,20 @@ def test_all(self):
self.assertColsEqual(df.all(), tbl.all())
self.assertColsEqual(df.all(skipna=True), tbl.all(skipna=True))

# When skipna=False, pandas doesn't use booleans anymore
self.assertColsEqual(
df.all(skipna=False).apply(lambda x: pd.isnull(x) and x or bool(x)),
tbl.all(skipna=False))
if pd_version < (1, 2, 0):
# When skipna=False, pandas doesn't use booleans anymore
self.assertColsEqual(
df.all(skipna=False).apply(lambda x: pd.isnull(x) and x or bool(x)),
tbl.all(skipna=False))
else:
# Starting with pandas 1.2.0, When skipna=False, pandas does use booleans;
# However it returns "True" if the column is all na,
# not NaN as was previously returned
# SASDataFrame will return True/False/NaN,
# so convert NaN to True to match new pandas
self.assertColsEqual(
df.all(skipna=False),
tbl.all(skipna=False).apply(lambda x: pd.isna(x) or bool(x)))

# By groups
self.assertTablesEqual(df.groupby('Origin').all(),
Expand All @@ -5794,10 +5810,20 @@ def test_any(self):
self.assertColsEqual(df.any(), tbl.any())
self.assertColsEqual(df.any(skipna=True), tbl.any(skipna=True))

# When skipna=False, pandas doesn't use booleans anymore
self.assertColsEqual(
df.any(skipna=False).apply(lambda x: pd.isnull(x) and x or bool(x)),
tbl.any(skipna=False))
if pd_version < (1, 2, 0):
# When skipna=False, pandas doesn't use booleans anymore
self.assertColsEqual(
df.any(skipna=False).apply(lambda x: pd.isnull(x) and x or bool(x)),
tbl.any(skipna=False))
else:
# Starting with pandas 1.2.0, When skipna=False, pandas does use booleans;
# However it returns "True" if the column is all na,
# not NaN as was previously returned
# SASDataFrame will return True/False/NaN,
# so convert NaN to True to match new pandas
self.assertColsEqual(
df.any(skipna=False),
tbl.any(skipna=False).apply(lambda x: pd.isna(x) or bool(x)))

# By groups
self.assertTablesEqual(df.groupby('Origin').any(),
Expand Down
16 changes: 10 additions & 6 deletions swat/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def setUp(self):
swat.options.cas.print_messages = False
swat.options.interactive_mode = False

pd.reset_option('display.max_columns')
pd.reset_option('display.notebook.repr_html')

self.s = swat.CAS(HOST, PORT, USER, PASSWD, protocol=PROTOCOL)

if type(self).server_type is None:
Expand Down Expand Up @@ -800,35 +803,35 @@ def test_apply_formats(self):

f = ['Acura', '3.5', 'RL', '4dr', 'Sedan', 'Asia', 'Front', '$43,755',
'$39,014', '3.5', '6', '225', '18', '24', '3880', '115', '197']
ft = ['Acura', '3.5', 'RL', '4dr', 'Sedan', 'Asia', 'Front', '...',
ft = ['Acura', '3.5', 'RL', '4dr', 'Sedan', 'Asia', 'Front',
'18', '24', '3880', '115', '197']

# __str__
pd.set_option('display.max_columns', 10000)
s = [re.split(r'\s+', x[1:].strip())
for x in str(out).split('\n') if x.startswith('0')]
s = [item for sublist in s for item in sublist]
s = [item for sublist in s for item in sublist if item != '\\' and item != '...']
self.assertEqual(s, f)

# truncated __str__
pd.set_option('display.max_columns', 10)
s = [re.split(r'\s+', x[1:].strip())
for x in str(out).split('\n') if x.startswith('0')]
s = [item for sublist in s for item in sublist]
s = [item for sublist in s for item in sublist if item != '\\' and item != '...']
self.assertEqual(s, ft)

pd.set_option('display.max_columns', 10000)

# __repr__
s = [re.split(r'\s+', x[1:].strip())
for x in repr(out).split('\n') if x.startswith('0')]
s = [item for sublist in s for item in sublist]
s = [item for sublist in s for item in sublist if item != '\\' and item != '...']
self.assertEqual(s, f)

# to_string
s = [re.split(r'\s+', x[1:].strip())
for x in out.to_string().split('\n') if x.startswith('0')]
s = [item for sublist in s for item in sublist]
s = [item for sublist in s for item in sublist if item != '\\' and item != '...']
self.assertEqual(s, f)

f = ('''<tr> <td>0</td> <td>Acura</td> <td>3.5 RL 4dr</td> <td>Sedan</td> '''
Expand Down Expand Up @@ -898,7 +901,8 @@ def test_round(self):
self.assertEqual(result.iloc[2, 2], 1.1)
self.assertEqual(result.iloc[3, 3], 3.0)
self.assertEqual(result.iloc[4, 1], 18851.0)
self.assertEqual(result.iloc[4, 2], 2.3)
self.assertEqual(result.iloc[4, 0], 20329.5)
self.assertEqual(result.iloc[3, 2], 1.3)
self.assertEqual(result.iloc[5, 7], 3474.5)
self.assertEqual(result.iloc[6, 2], 3.9)
self.assertEqual(result.iloc[7, 9], 238.0)
Expand Down