diff --git a/docs/sources/CHANGELOG.md b/docs/sources/CHANGELOG.md index 9e1973cb8..fae9ab06a 100755 --- a/docs/sources/CHANGELOG.md +++ b/docs/sources/CHANGELOG.md @@ -23,7 +23,7 @@ The CHANGELOG for the current development version is available at - The `bias_variance_decomp` function now supports Keras estimators. ([#725](https://github.com/rasbt/mlxtend/pull/725) via [@hanzigs](https://github.com/hanzigs)) - Adds new `mlxtend.classifier.OneRClassifier` (One Rule Classfier) class, a simple rule-based classifier that is often used as a performance baseline or simple interpretable model. ([#726](https://github.com/rasbt/mlxtend/pull/726) - Adds new `create_counterfactual` method for creating counterfactuals to explain model predictions. ([#740](https://github.com/rasbt/mlxtend/pull/740)) - +- Adds new `scatter_hist` function for generating a scatter plot with histograms of the individual features along the axes. ([#596](https://github.com/rasbt/mlxtend/issues/596)) ##### Changes diff --git a/docs/sources/user_guide/plotting/scatter_hist.ipynb b/docs/sources/user_guide/plotting/scatter_hist.ipynb new file mode 100644 index 000000000..95fa1cdd3 --- /dev/null +++ b/docs/sources/user_guide/plotting/scatter_hist.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scatter Histogram" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A function to quickly produce a scatter plot with histograms of the two plotted features along the axes." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> from mlxtend.plotting import scatter_hist" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### References\n", + "\n", + "- https://matplotlib.org/gallery/lines_bars_and_markers/scatter_hist.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1 - Scatter Plot and Histograms from Pandas DataFrames" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sepal length [cm]sepal width [cm]petal length [cm]petal width [cm]
05.13.51.40.2
14.93.01.40.2
24.73.21.30.2
34.63.11.50.2
45.03.61.40.2
\n", + "
" + ], + "text/plain": [ + " sepal length [cm] sepal width [cm] petal length [cm] petal width [cm]\n", + "0 5.1 3.5 1.4 0.2\n", + "1 4.9 3.0 1.4 0.2\n", + "2 4.7 3.2 1.3 0.2\n", + "3 4.6 3.1 1.5 0.2\n", + "4 5.0 3.6 1.4 0.2" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from mlxtend.data import iris_data\n", + "from mlxtend.plotting import scatter_hist\n", + "import pandas as pd\n", + "X, y = iris_data()\n", + "df = pd.DataFrame(X)\n", + "df.columns = ['sepal length [cm]', 'sepal width [cm]', 'petal length [cm]', 'petal width [cm]']\n", + "df.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "from mlxtend.plotting import scatter_hist\n", + "\n", + "fig=scatter_hist(\"sepal length [cm]\", \"sepal width [cm]\", df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The figure shows the two variables as a scatter plot (bivariate), with a histogram of each variable (univariate) along the corresponding axis. The `x` and `y` values are simply the column names of the DataFrame that we want to plot." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2 - Scatter Plot and Histograms from NumPy Arrays" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[5.1, 3.5, 1.4, 0.2],\n", + " [4.9, 3. , 1.4, 0.2],\n", + " [4.7, 3.2, 1.3, 0.2],\n", + " [4.6, 3.1, 1.5, 0.2],\n", + " [5. , 3.6, 1.4, 0.2]])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from mlxtend.data import iris_data\n", + "from mlxtend.plotting import scatter_hist\n", + "import pandas as pd\n", + "X, y = iris_data()\n", + "X[:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For NumPy arrays, `x` and `y` are integer column indices: below, the first column (index 0) provides the `x` values and the second column (index 1) provides the `y` values." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig= scatter_hist(0, 1, X)" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + }, + "toc": { + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/docs/sources/user_guide/plotting/scatter_hist_files/scatter_hist_dataframe.png b/docs/sources/user_guide/plotting/scatter_hist_files/scatter_hist_dataframe.png new file mode 100644 index 000000000..84b8ede8c Binary files /dev/null and b/docs/sources/user_guide/plotting/scatter_hist_files/scatter_hist_dataframe.png differ diff --git a/docs/sources/user_guide/plotting/scatter_hist_files/scatter_hist_numpyarray.png b/docs/sources/user_guide/plotting/scatter_hist_files/scatter_hist_numpyarray.png new file mode 100644 index 000000000..906415581 Binary files /dev/null and b/docs/sources/user_guide/plotting/scatter_hist_files/scatter_hist_numpyarray.png differ diff --git a/mlxtend/plotting/__init__.py b/mlxtend/plotting/__init__.py index 200f2b9b5..d07d36a5a 100644 --- a/mlxtend/plotting/__init__.py +++ b/mlxtend/plotting/__init__.py @@ -19,8 +19,7 @@ from .ecdf import ecdf from .scatterplotmatrix import scatterplotmatrix from .pca_correlation_graph import plot_pca_correlation_graph - - +from .scatter_hist import scatter_hist __all__ = ["plot_learning_curves", "plot_decision_regions", 
def scatter_hist(x, y, data):
    """
    Scatter plot and individual feature histograms along axes.

    Parameters
    ----------
    x : str or int
        DataFrame column name of the x-axis values or
        integer for the numpy ndarray column index.
    y : str or int
        DataFrame column name of the y-axis values or
        integer for the numpy ndarray column index.
    data : Pandas DataFrame object or NumPy ndarray.

    Returns
    ---------
    plot : matplotlib collection object
        The object returned by the `scatter` call on the main axes
        (a `matplotlib.collections.PathCollection`). The axes and the
        figure it was drawn on are reachable via `plot.axes` and
        `plot.axes.figure`.

    Raises
    ---------
    AssertionError
        If `x` or `y` is not a str when `data` is a DataFrame, or
        not an integer when `data` is an ndarray.
    ValueError
        If `data` is neither a pandas DataFrame nor a NumPy ndarray.

    """
    # Validate everything before any figure is created, so that a bad
    # call does not leave an empty figure behind.
    is_frame = isinstance(data, pd.DataFrame)
    if is_frame:
        key_types, key_desc = (str,), 'a column name (str)'
    elif isinstance(data, np.ndarray):
        # NumPy integer scalars are valid column indices as well.
        key_types, key_desc = (int, np.integer), 'a column index (int)'
    else:
        raise ValueError('data must be pandas.DataFrame '
                         'or numpy.ndarray object')

    for name, key in (('x', x), ('y', y)):
        if not isinstance(key, key_types):
            # Raised explicitly instead of using an `assert` statement:
            # `assert` is removed under `python -O`, which would silently
            # disable this check. The exception type stays AssertionError
            # for backward compatibility with existing callers.
            raise AssertionError('%s must be %s for this type of data, '
                                 'got %r' % (name, key_desc, key))

    if is_frame:
        # Column names double as axis labels for DataFrame input.
        xlabel, ylabel = x, y
        x_values, y_values = data[x], data[y]
    else:
        xlabel = ylabel = None
        x_values, y_values = data[:, x], data[:, y]

    # Axes rectangles as [left, bottom, width, height] in figure
    # fractions: the scatter plot in the lower left, one histogram strip
    # above it and one to its right.
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.001
    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom + height + spacing, width, 0.2]
    rect_histy = [left + width + spacing, bottom, 0.2, height]

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_axes(rect_scatter)
    if is_frame:
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    # The histogram axes share the matching axis with the scatter plot
    # so the bins line up with the points.
    ax_histx = fig.add_axes(rect_histx, sharex=ax)
    ax_histy = fig.add_axes(rect_histy, sharey=ax)
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)
    ax_histx.axis("off")
    ax_histy.axis("off")
    ax_histx.hist(x_values, edgecolor='white', bins='auto')
    ax_histy.hist(y_values, edgecolor='white',
                  orientation='horizontal', bins='auto')
    plot = ax.scatter(x_values, y_values)
    return plot