Skip to content

Commit

Permalink
Merge b7e8e52 into e5c906f
Browse files Browse the repository at this point in the history
  • Loading branch information
Maitreyee1 committed Nov 28, 2020
2 parents e5c906f + b7e8e52 commit 25f4766
Show file tree
Hide file tree
Showing 7 changed files with 391 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/sources/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ The CHANGELOG for the current development version is available at
- The `bias_variance_decomp` function now supports Keras estimators. ([#725](https://github.com/rasbt/mlxtend/pull/725) via [@hanzigs](https://github.com/hanzigs))
- Adds new `mlxtend.classifier.OneRClassifier` (One Rule Classifier) class, a simple rule-based classifier that is often used as a performance baseline or simple interpretable model. ([#726](https://github.com/rasbt/mlxtend/pull/726))
- Adds new `create_counterfactual` method for creating counterfactuals to explain model predictions. ([#740](https://github.com/rasbt/mlxtend/pull/740))

- Adds new `scatter_hist` function for generating a scatter plot with histograms of the individual features along the axes. ([#596](https://github.com/rasbt/mlxtend/issues/596))

##### Changes

Expand Down
285 changes: 285 additions & 0 deletions docs/sources/user_guide/plotting/scatter_hist.ipynb

Large diffs are not rendered by default.

Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions mlxtend/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from .ecdf import ecdf
from .scatterplotmatrix import scatterplotmatrix
from .pca_correlation_graph import plot_pca_correlation_graph


from .scatter_hist import scatter_hist
__all__ = ["plot_learning_curves",
"plot_decision_regions",
"plot_confusion_matrix",
Expand All @@ -34,4 +33,5 @@
"checkerboard_plot",
"ecdf",
"scatterplotmatrix",
"plot_pca_correlation_graph"]
"plot_pca_correlation_graph",
"scatter_hist"]
71 changes: 71 additions & 0 deletions mlxtend/plotting/scatter_hist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def scatter_hist(x, y, data):
    """
    Scatter plot and individual feature histograms along axes.

    Parameters
    ----------
    x : str or int
        DataFrame column name of the x-axis values or
        integer for the numpy ndarray column index.
    y : str or int
        DataFrame column name of the y-axis values or
        integer for the numpy ndarray column index.
    data : Pandas DataFrame object or NumPy ndarray.
        Source of the plotted values. A DataFrame is indexed as
        `data[x]` / `data[y]`; an ndarray is indexed as
        `data[:, x]` / `data[:, y]`.

    Returns
    ---------
    plot : object returned by `Axes.scatter`
        The scatter artist drawn on the main axes (a
        `matplotlib.collections.PathCollection` according to the
        matplotlib documentation), not the figure itself.

    Raises
    ---------
    ValueError
        If `data` is neither a pandas DataFrame nor a NumPy ndarray.
    AssertionError
        If `x` or `y` is not a str (DataFrame input) or not an
        integer (ndarray input).

    """
    if isinstance(data, pd.DataFrame):
        is_frame = True
        accepted, expectation = str, 'a column name (str)'
    elif isinstance(data, np.ndarray):
        is_frame = False
        # NumPy integer scalars are valid column indices as well.
        accepted, expectation = (int, np.integer), 'a column index (int)'
    else:
        raise ValueError('data must be pandas.DataFrame '
                         'or numpy.ndarray object')

    # Raised explicitly instead of via `assert` so that the validation is
    # not stripped when Python runs with -O. The AssertionError type is
    # kept so that existing callers catching it keep working.
    for arg_name, arg_value in (('x', x), ('y', y)):
        if not isinstance(arg_value, accepted):
            raise AssertionError(
                '%s must be %s for %s data, got %r'
                % (arg_name, expectation, type(data).__name__, arg_value))

    if is_frame:
        xlabel, ylabel = x, y
        x_values, y_values = data[x], data[y]
    else:
        x_values, y_values = data[:, x], data[:, y]

    # Figure-relative rectangles [left, bottom, width, height]: the scatter
    # axes in the lower-left, one histogram strip above it and one to its
    # right, separated by a small gap.
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.001
    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom + height + spacing, width, 0.2]
    rect_histy = [left + width + spacing, bottom, 0.2, height]

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_axes(rect_scatter)
    if is_frame:
        # Only DataFrame input carries names usable as axis labels.
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    # The histogram axes share the corresponding scatter axis.
    ax_histx = fig.add_axes(rect_histx, sharex=ax)
    ax_histy = fig.add_axes(rect_histy, sharey=ax)
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)
    ax_histx.axis("off")
    ax_histy.axis("off")
    ax_histx.hist(x_values, edgecolor='white', bins='auto')
    ax_histy.hist(y_values, edgecolor='white',
                  orientation='horizontal', bins='auto')
    plot = ax.scatter(x_values, y_values)
    return plot
31 changes: 31 additions & 0 deletions mlxtend/plotting/tests/test_scatter_hist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from mlxtend.data import iris_data
from mlxtend.plotting import scatter_hist
import pandas as pd
import pytest

# Shared fixtures: the iris measurements as a raw NumPy array (X) and the
# same values wrapped in a DataFrame with named columns (df).
X, y = iris_data()
df = pd.DataFrame(X, columns=['sepal length [cm]',
                              'sepal width [cm]',
                              'petal length [cm]',
                              'petal width [cm]'])


def test_pass_data_as_dataframe():
    """Smoke test: string column names with DataFrame input do not raise."""
    x_column, y_column = "sepal length [cm]", "sepal width [cm]"
    scatter_hist(x_column, y_column, df)


def test_pass_data_as_numpy_array():
    """Smoke test: integer column indices with ndarray input do not raise."""
    x_index, y_index = 0, 1
    scatter_hist(x_index, y_index, X)


def test_incorrect_x_or_y_data_as_dataframe():
    """An integer column reference is rejected when data is a DataFrame."""
    # Only the exception type is checked. Python 3 exceptions have no
    # `.message` attribute, and the message text is not part of the
    # contract being tested here.
    with pytest.raises(AssertionError):
        scatter_hist(0, "sepal width [cm]", df)


def test_incorrect_x_or_y_data_as_numpy_array():
    """A string column reference is rejected when data is an ndarray."""
    # Only the exception type is checked. Python 3 exceptions have no
    # `.message` attribute, and the message text is not part of the
    # contract being tested here.
    with pytest.raises(AssertionError):
        scatter_hist("sepal length [cm]", 1, X)

0 comments on commit 25f4766

Please sign in to comment.