Skip to content

Commit

Permalink
[ENH] Add colors argument to plot_series (#3908)
Browse files Browse the repository at this point in the history
Add an argument `colors` to `plot_series` to allow the user to pass a list of colors. 

Also adds a small util function to check that the length is correct, and use `from matplotlib.colors import is_color_like` to check that they are all valid colors. If either condition is not met, then it emits a warning and defaults back to the current behavior, which is using `sns.color_palette("colorblind", n_colors=n_series)`.
  • Loading branch information
chillerobscuro committed Dec 19, 2022
1 parent 2129671 commit 9e093f7
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions sktime/utils/plotting.py
Expand Up @@ -7,7 +7,7 @@
__author__ = ["mloning", "RNKuhns", "Drishti Bhasin", "chillerobscuro"]

import math
from warnings import simplefilter
from warnings import simplefilter, warn

import numpy as np
import pandas as pd
Expand All @@ -22,6 +22,7 @@ def plot_series(
*series,
labels=None,
markers=None,
colors=None,
x_label=None,
y_label=None,
ax=None,
Expand All @@ -38,6 +39,8 @@ def plot_series(
markers: list, default = None
Markers of data points, if None the marker "o" is used by default.
The length of the list has to match with the number of series.
colors: list, default = None
The colors to use for plotting each series. Must contain one color per series
pred_interval: pd.DataFrame, default = None
Output of `forecaster.predict_interval()`. Contains columns for lower
and upper boundaries of confidence interval.
Expand Down Expand Up @@ -106,7 +109,9 @@ def plot_series(
if _ax_kwarg_is_none:
fig, ax = plt.subplots(1, figsize=plt.figaspect(0.25))

colors = sns.color_palette("colorblind", n_colors=n_series)
# colors
if colors is None or not _check_colors(colors, n_series):
colors = sns.color_palette("colorblind", n_colors=n_series)

# plot series
for x, y, color, label, marker in zip(xs, series, colors, labels, markers):
Expand Down Expand Up @@ -350,6 +355,18 @@ def plot_correlations(
return fig, np.array(fig.get_axes())


def _check_colors(colors, n_series):
"""Verify color list is correct length and contains only colors."""
from matplotlib.colors import is_color_like

if n_series == len(colors) and all([is_color_like(c) for c in colors]):
return True
warn(
"Color list must be same length as `series` and contain only matplotlib colors"
)
return False


def _get_windows(cv, y):
"""Generate cv split windows, utility function."""
train_windows = []
Expand Down

0 comments on commit 9e093f7

Please sign in to comment.