-
Notifications
You must be signed in to change notification settings - Fork 849
/
scatter_hist.py
43 lines (36 loc) · 1.11 KB
/
scatter_hist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import seaborn as sns
import pandas as pd
import numpy as np
def scatter_hist(x, y, data):
    """
    Scatter plot, individual feature histograms along axes of scatter plot.

    Parameters
    ----------
    x : str or int
        DataFrame column name of the x-axis values or
        integer for the numpy ndarray column index.
    y : str or int
        DataFrame column name of the y-axis values or
        integer for the numpy ndarray column index.
    data : Pandas DataFrame object or NumPy ndarray.
        Source of the plotted values. A DataFrame is indexed by column
        name; an ndarray is indexed as ``data[:, x]`` / ``data[:, y]``
        (i.e. it must be at least 2-D).

    Returns
    ---------
    plot : seaborn figure object
        The object returned by ``seaborn.jointplot``.

    Raises
    ------
    AssertionError
        If `x` or `y` is not a ``str`` (DataFrame input) or not an integer
        (ndarray input). Raised explicitly rather than via ``assert`` so the
        check is not stripped under ``python -O``.
    ValueError
        If `data` is neither a pandas DataFrame nor a numpy ndarray.
    """
    if isinstance(data, pd.DataFrame):
        for i in (x, y):
            if not isinstance(i, str):
                # AssertionError kept (instead of TypeError) for backward
                # compatibility with callers/tests expecting it.
                raise AssertionError(
                    'x and y must be column names (str) when data is a '
                    'pandas.DataFrame, got %r' % (i,))
        plot = sns.jointplot(data=data, x=x, y=y)
    elif isinstance(data, np.ndarray):
        for i in (x, y):
            # np.integer also admits indices such as np.int64 (e.g. the
            # result of np.argmax), which plain `int` would reject.
            if not isinstance(i, (int, np.integer)):
                raise AssertionError(
                    'x and y must be integer column indices when data is a '
                    'numpy.ndarray, got %r' % (i,))
        # Pass the extracted columns as plain vectors and do NOT forward the
        # raw ndarray as `data`: seaborn's long-form interface (x/y assigned)
        # does not accept an ndarray as its data source in recent versions.
        plot = sns.jointplot(x=data[:, x], y=data[:, y])
    else:
        raise ValueError('df must be pandas.DataFrame or numpy.ndarray object')
    return plot