Skip to content

Commit

Permalink
Moves format_kwarg_dictionaries to utils and adds tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Mar 14, 2018
1 parent a04ba03 commit 3f5304d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 51 deletions.
41 changes: 4 additions & 37 deletions mlxtend/plotting/decision_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from itertools import cycle
import matplotlib.pyplot as plt
import numpy as np
from ..utils import check_Xy
from ..utils import check_Xy, format_kwarg_dictionaries
import warnings


Expand Down Expand Up @@ -215,7 +215,7 @@ def plot_decision_regions(X, y, clf,
# Plot decisoin region
# Make sure contourf_kwargs has backwards compatible defaults
contourf_kwargs_default = {'alpha': 0.3, 'antialiased': True}
contourf_kwargs = format_plotting_kwargs(
contourf_kwargs = format_kwarg_dictionaries(
default_kwargs=contourf_kwargs_default,
user_kwargs=contourf_kwargs,
protected_keys=['colors', 'levels'])
Expand All @@ -229,7 +229,7 @@ def plot_decision_regions(X, y, clf,
# Scatter training data samples
# Make sure scatter_kwargs has backwards compatible defaults
scatter_kwargs_default = {'alpha': 0.8, 'edgecolor': 'black'}
scatter_kwargs = format_plotting_kwargs(
scatter_kwargs = format_kwarg_dictionaries(
default_kwargs=scatter_kwargs_default,
user_kwargs=scatter_kwargs,
protected_keys=['c', 'marker', 'label'])
Expand Down Expand Up @@ -288,7 +288,7 @@ def plot_decision_regions(X, y, clf,
'linewidths': 1,
'marker': 'o',
's': 80}
scatter_highlight_kwargs = format_plotting_kwargs(
scatter_highlight_kwargs = format_kwarg_dictionaries(
default_kwargs=scatter_highlight_defaults,
user_kwargs=scatter_highlight_kwargs)
ax.scatter(x_data,
Expand All @@ -304,36 +304,3 @@ def plot_decision_regions(X, y, clf,
framealpha=0.3, scatterpoints=1, loc=legend)

return ax


def format_plotting_kwargs(default_kwargs=None, user_kwargs=None,
protected_keys=None):
"""Function to combine default and user specified plotting kwargs
Parameters
----------
default_kwargs : dict, optional
Default kwargs (default is None).
user_kwargs : dict, optional
User specified kwargs (default is None).
protected_keys : array_like, optional
Sequence of keys to be removed from the returned dictionary
(default is None).
Returns
-------
plotting_dict : dict
Formatted plotting dictionary.
"""
plotting_dict = {}
for d in [default_kwargs, user_kwargs]:
if not isinstance(d, (dict, type(None))):
raise TypeError('d must be of type dict or None, but '
'got {} instead'.format(type(d)))
if d is not None:
plotting_dict.update(d)
if protected_keys is not None:
for key in protected_keys:
plotting_dict.pop(key, None)

return plotting_dict
5 changes: 3 additions & 2 deletions mlxtend/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .counter import Counter
from .testing import assert_raises
from .checking import check_Xy
from .checking import check_Xy, format_kwarg_dictionaries

__all__ = ["Counter", "assert_raises", "check_Xy"]
__all__ = ["Counter", "assert_raises", "check_Xy",
"format_kwarg_dictionaries"]
33 changes: 33 additions & 0 deletions mlxtend/utils/checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,36 @@ def check_Xy(X, y, y_int=True):
if y.shape[0] != X.shape[0]:
raise ValueError('y and X must contain the same number of samples. '
'Got y: %d, X: %d' % (y.shape[0], X.shape[0]))


def format_kwarg_dictionaries(default_kwargs=None, user_kwargs=None,
protected_keys=None):
"""Function to combine default and user specified kwargs dictionaries
Parameters
----------
default_kwargs : dict, optional
Default kwargs (default is None).
user_kwargs : dict, optional
User specified kwargs (default is None).
protected_keys : array_like, optional
Sequence of keys to be removed from the returned dictionary
(default is None).
Returns
-------
formatted_kwargs : dict
Formatted kwargs dictionary.
"""
formatted_kwargs = {}
for d in [default_kwargs, user_kwargs]:
if not isinstance(d, (dict, type(None))):
raise TypeError('d must be of type dict or None, but '
'got {} instead'.format(type(d)))
if d is not None:
formatted_kwargs.update(d)
if protected_keys is not None:
for key in protected_keys:
formatted_kwargs.pop(key, None)

return formatted_kwargs
74 changes: 62 additions & 12 deletions mlxtend/utils/tests/test_checking_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@
# License: BSD 3 clause

from mlxtend.utils import assert_raises
from mlxtend.utils import check_Xy
from mlxtend.utils import check_Xy, format_kwarg_dictionaries
import numpy as np
import sys
import os

y = np.array([1, 2, 3, 4])
X = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])

d_default = {'key1': 1, 'key2': 2}
d_user = {'key3': 3, 'key4': 4}
protected_keys = ['key1', 'key4']

def test_ok():

def test_check_Xy_ok():
check_Xy(X, y)


def test_invalid_type_X():
def test_check_Xy_invalid_type_X():
expect = "X must be a NumPy array. Found <class 'list'>"
if (sys.version_info < (3, 0)):
expect = expect.replace('class', 'type')
Expand All @@ -29,15 +33,15 @@ def test_invalid_type_X():
y)


def test_float16_X():
def test_check_Xy_float16_X():
check_Xy(X.astype(np.float16), y)


def test_float16_y():
def test_check_Xy_float16_y():
check_Xy(X, y.astype(np.int16))


def test_invalid_type_y():
def test_check_Xy_invalid_type_y():
expect = "y must be a NumPy array. Found <class 'list'>"
if (sys.version_info < (3, 0)):
expect = expect.replace('class', 'type')
Expand All @@ -48,15 +52,15 @@ def test_invalid_type_y():
[1, 2, 3, 4])


def test_invalid_dtype_X():
def test_check_Xy_invalid_dtype_X():
assert_raises(ValueError,
'X must be an integer or float array. Found object.',
check_Xy,
X.astype('object'),
y)


def test_invalid_dtype_y():
def test_check_Xy_invalid_dtype_y():

if (sys.version_info > (3, 0)):
expect = ('y must be an integer array. Found <U1. '
Expand All @@ -71,7 +75,7 @@ def test_invalid_dtype_y():
np.array(['a', 'b', 'c', 'd']))


def test_invalid_dim_y():
def test_check_Xy_invalid_dim_y():
if sys.version_info[:2] == (2, 7) and os.name == 'nt':
s = 'y must be a 1D array. Found (4L, 2L)'
else:
Expand All @@ -83,7 +87,7 @@ def test_invalid_dim_y():
X.astype(np.integer))


def test_invalid_dim_X():
def test_check_Xy_invalid_dim_X():
if sys.version_info[:2] == (2, 7) and os.name == 'nt':
s = 'X must be a 2D array. Found (4L,)'
else:
Expand All @@ -95,7 +99,7 @@ def test_invalid_dim_X():
y)


def test_unequal_length_X():
def test_check_Xy_unequal_length_X():
assert_raises(ValueError,
('y and X must contain the same number of samples. '
'Got y: 4, X: 3'),
Expand All @@ -104,10 +108,56 @@ def test_unequal_length_X():
y)


def test_unequal_length_y():
def test_check_Xy_unequal_length_y():
assert_raises(ValueError,
('y and X must contain the same number of samples. '
'Got y: 3, X: 4'),
check_Xy,
X,
y[1:])


def test_format_kwarg_dictionaries_defaults_empty():
empty = format_kwarg_dictionaries()
assert isinstance(empty, dict)
assert len(empty) == 0


def test_format_kwarg_dictionaries_protected_keys():
formatted_kwargs = format_kwarg_dictionaries(
default_kwargs=d_default,
user_kwargs=d_user,
protected_keys=protected_keys)

for key in protected_keys:
assert key not in formatted_kwargs


def test_format_kwarg_dictionaries_no_default_kwargs():
formatted_kwargs = format_kwarg_dictionaries(user_kwargs=d_user)
assert formatted_kwargs == d_user


def test_format_kwarg_dictionaries_no_user_kwargs():
formatted_kwargs = format_kwarg_dictionaries(default_kwargs=d_default)
assert formatted_kwargs == d_default


def test_format_kwarg_dictionaries_default_kwargs_invalid_type():
invalid_kwargs = 'not a dictionary'
message = ('d must be of type dict or None, but got '
'{} instead'.format(type(invalid_kwargs)))
assert_raises(TypeError,
message,
format_kwarg_dictionaries,
default_kwargs=invalid_kwargs)


def test_format_kwarg_dictionaries_user_kwargs_invalid_type():
invalid_kwargs = 'not a dictionary'
message = ('d must be of type dict or None, but got '
'{} instead'.format(type(invalid_kwargs)))
assert_raises(TypeError,
message,
format_kwarg_dictionaries,
user_kwargs=invalid_kwargs)

0 comments on commit 3f5304d

Please sign in to comment.