Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -793,9 +793,9 @@ Period

Plotting
^^^^^^^^
- Bug in :func:`scatter_matrix` where the ``grid`` parameter was ignored (:issue:`50818`)
- Bug in :meth:`Series.plot` when invoked with ``color=None`` (:issue:`51953`)
- Fixed UserWarning in :meth:`DataFrame.plot.scatter` when invoked with ``c="b"`` (:issue:`53908`)
-

Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
13 changes: 9 additions & 4 deletions pandas/plotting/_matplotlib/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,16 @@ def scatter_matrix(

ax.set_xlabel(b)
ax.set_ylabel(a)
if i != j:
ax.grid(grid)

if j != 0:
ax.yaxis.set_visible(False)
if i != n - 1:
ax.xaxis.set_visible(False)
if j != 0: # if its not on the left
ax.set_ylabel("")
ax.set_yticklabels([])

if i != n - 1: # if its not on the bottom
ax.set_xlabel("")
ax.set_xticklabels([])

if len(df.columns) > 1:
lim1 = boundaries_list[0]
Expand Down
14 changes: 5 additions & 9 deletions pandas/plotting/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,11 @@ def scatter_matrix(

>>> df = pd.DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
>>> pd.plotting.scatter_matrix(df, alpha=0.2)
array([[<Axes: xlabel='A', ylabel='A'>, <Axes: xlabel='B', ylabel='A'>,
<Axes: xlabel='C', ylabel='A'>, <Axes: xlabel='D', ylabel='A'>],
[<Axes: xlabel='A', ylabel='B'>, <Axes: xlabel='B', ylabel='B'>,
<Axes: xlabel='C', ylabel='B'>, <Axes: xlabel='D', ylabel='B'>],
[<Axes: xlabel='A', ylabel='C'>, <Axes: xlabel='B', ylabel='C'>,
<Axes: xlabel='C', ylabel='C'>, <Axes: xlabel='D', ylabel='C'>],
[<Axes: xlabel='A', ylabel='D'>, <Axes: xlabel='B', ylabel='D'>,
<Axes: xlabel='C', ylabel='D'>, <Axes: xlabel='D', ylabel='D'>]],
dtype=object)
array([[<Axes: ylabel='A'>, <Axes: >, <Axes: >, <Axes: >],
[<Axes: ylabel='B'>, <Axes: >, <Axes: >, <Axes: >],
[<Axes: ylabel='C'>, <Axes: >, <Axes: >, <Axes: >],
[<Axes: xlabel='A', ylabel='D'>, <Axes: xlabel='B'>,
<Axes: xlabel='C'>, <Axes: xlabel='D'>]], dtype=object)
"""
plot_backend = _get_plot_backend("matplotlib")
return plot_backend.scatter_matrix(
Expand Down