Skip to content

Commit

Permalink
Add plot_correlations() to plot series and acf/pacf (#850)
Browse files Browse the repository at this point in the history
* Add plot_correlations to plot series and acf/pacf

* Fixed output of plot_series to handle kwarg ax

* Added input check to plot_correlations

* Updated docstrings to match pydocstyle conventions

* Updated plot_series docstring default args

* Add kwargs to set axes titles to plot_series

* Added dependency on pytest-mply plot tests

* Add baseline images for plot comparison unit tests

* Updated setup.cfg to exlude tests from pydocstyle

* Added plot_correlations to docs

* Removed pytest-mpl dependency and tests

* Update plotting unit tests

* Fixed plot_correlations docs

Co-authored-by: Markus Löning <markus.loning@gmail.com>
  • Loading branch information
RNKuhns and Markus Löning committed Jun 7, 2021
1 parent d606f71 commit ce73505
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/source/api_reference/utils.rst
Expand Up @@ -15,6 +15,7 @@ Plotting
:template: function.rst

plot_series
plot_correlations

Data Processing
---------------
Expand Down
156 changes: 145 additions & 11 deletions sktime/utils/plotting.py
@@ -1,26 +1,47 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__all__ = ["plot_series"]
__author__ = ["Markus Löning"]
"""Common timeseries plotting functionality.
Functions
---------
plot_series(*series, labels=None, markers=None, ax=None)
plot_correlations(
series,
lags=24,
alpha=0.05,
zero_lag=True,
acf_fft=False,
acf_adjusted=True,
pacf_method="ywadjusted",
suptitle=None,
series_title=None,
acf_title="Autocorrelation",
pacf_title="Partial Autocorrelation",
)
"""
__all__ = ["plot_series", "plot_correlations"]
__author__ = ["Markus Löning", "Ryan Kuhns"]

import numpy as np

from sktime.utils.validation._dependencies import _check_soft_dependencies
from sktime.utils.validation.forecasting import check_y
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf


def plot_series(*series, labels=None, markers=None):
"""Plot one or more time series
def plot_series(
*series, labels=None, markers=None, x_label=None, y_label=None, ax=None
):
"""Plot one or more time series.
Parameters
----------
series : pd.Series
One or more time series
labels : list, optional (default=None)
labels : list, default = None
Names of series, will be displayed in figure legend
markers: list, optional (default=None)
markers: list, default = None
Markers of data points, if None the marker "o" is used by default.
Lenght of list has to match with number of series
Expand All @@ -39,7 +60,7 @@ def plot_series(*series, labels=None, markers=None):
check_y(y)

n_series = len(series)

_ax_kwarg_is_none = True if ax is None else False
# labels
if labels is not None:
if n_series != len(labels):
Expand Down Expand Up @@ -76,8 +97,10 @@ def plot_series(*series, labels=None, markers=None):
# generate integer x-values
xs = [np.argwhere(index.isin(y.index)).ravel() for y in series]

# create figure
fig, ax = plt.subplots(1, figsize=plt.figaspect(0.25))
# create figure if no Axe provided for plotting
if _ax_kwarg_is_none:
fig, ax = plt.subplots(1, figsize=plt.figaspect(0.25))

colors = sns.color_palette("colorblind", n_colors=n_series)

# plot series
Expand Down Expand Up @@ -105,7 +128,118 @@ def format_fn(tick_val, tick_pos):
ax.xaxis.set_major_formatter(FuncFormatter(format_fn))
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

# Label the x and y axes
if x_label is not None:
ax.set_xlabel(x_label)

_y_label = y_label if y_label is not None else series[0].name
ax.set_ylabel(_y_label)

if legend:
ax.legend()
if _ax_kwarg_is_none:
return fig, ax
else:
return ax


def plot_correlations(
series,
lags=24,
alpha=0.05,
zero_lag=True,
acf_fft=False,
acf_adjusted=True,
pacf_method="ywadjusted",
suptitle=None,
series_title=None,
acf_title="Autocorrelation",
pacf_title="Partial Autocorrelation",
):
"""Plot series and its ACF and PACF values.
Parameters
----------
series : pd.Series
A time series.
lags : int, default = 24
Number of lags to include in ACF and PACF plots
alpha : int, default = 0.05
Alpha value used to set confidence intervals. Alpha = 0.05 results in
95% confidence interval with standard deviation calculated via
Bartlett's formula.
zero_lag : bool, default = True
If True, start ACF and PACF plots at 0th lag
acf_fft : bool, = False
Whether to compute ACF via FFT.
acf_adjusted : bool, default = True
If True, denonimator of ACF calculations uses n-k instead of n, where
n is number of observations and k is the lag.
pacf_method : str, default = 'ywadjusted'
Method to use in calculation of PACF.
suptitle : str, default = None
The text to use as the Figure's suptitle.
series_title : str, default = None
Used to set the title of the series plot if provided. Otherwise, series
plot has no title.
acf_title : str, default = 'Autocorrelation'
Used to set title of ACF plot.
pacf_title : str, default = 'Partial Autocorrelation'
Used to set title of PACF plot.
Returns
-------
fig : matplotlib.figure.Figure
axes : np.ndarray
Array of the figure's Axe objects
"""
_check_soft_dependencies("matplotlib")
import matplotlib.pyplot as plt

return fig, ax
series = check_y(series)

# Setup figure for plotting
fig = plt.figure(constrained_layout=True, figsize=(12, 8))
gs = fig.add_gridspec(2, 2)
f_ax1 = fig.add_subplot(gs[0, :])
if series_title is not None:
f_ax1.set_title(series_title)
f_ax2 = fig.add_subplot(gs[1, 0])
f_ax3 = fig.add_subplot(gs[1, 1])

# Create expected plots on their respective Axes
plot_series(series, ax=f_ax1)
plot_acf(
series,
ax=f_ax2,
lags=lags,
zero=zero_lag,
alpha=alpha,
title=acf_title,
adjusted=acf_adjusted,
fft=acf_fft,
)
plot_pacf(
series,
ax=f_ax3,
lags=lags,
zero=zero_lag,
alpha=alpha,
title=pacf_title,
method=pacf_method,
)
if suptitle is not None:
fig.suptitle(suptitle, size="xx-large")

return fig, np.array(fig.get_axes())
167 changes: 167 additions & 0 deletions sktime/utils/tests/test_plotting.py
@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
import pytest

from sktime.datasets import load_airline
from sktime.utils.plotting import plot_series, plot_correlations
from sktime.utils.validation._dependencies import _check_soft_dependencies
from sktime.utils.validation.series import VALID_DATA_TYPES

ALLOW_NUMPY = False
y_airline = load_airline()
y_airline_true = y_airline.iloc[y_airline.index < "1960-01"]
y_airline_test = y_airline.iloc[y_airline.index >= "1960-01"]
series_to_test = [y_airline, (y_airline_true, y_airline_test)]
invalid_input_types = [y_airline.values, pd.DataFrame(y_airline), "this_is_a_string"]


# Need to use _plot_series to make it easy for test cases to pass either a
# single series or a tuple of multiple series to be unpacked as argss
def _plot_series(series, ax=None, **kwargs):
if isinstance(series, tuple):
return plot_series(*series, ax=ax, **kwargs)
else:
return plot_series(series, ax=ax, **kwargs)


@pytest.fixture
def valid_data_types():
valid_data_types = tuple(
filter(
lambda x: x is not np.ndarray and x is not pd.DataFrame, VALID_DATA_TYPES
)
)
return valid_data_types


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_runs_without_error(series_to_plot):
_check_soft_dependencies("matplotlib")
import matplotlib.pyplot as plt

_plot_series(series_to_plot)
plt.gcf().canvas.draw_idle()


@pytest.mark.parametrize("series_to_plot", invalid_input_types)
def test_plot_series_invalid_input_type_raises_error(series_to_plot, valid_data_types):
# TODO: Is it possible to dynamically create the matching str if it includes
# characters that need to be escaped (like .)
# match = f"Data must be a one of {valid_data_types}, but found type: {type(Z)}"
with pytest.raises((TypeError, ValueError)):
_plot_series(series_to_plot)


@pytest.mark.parametrize(
"series_to_plot", [(y_airline_true, y_airline_test.reset_index(drop=True))]
)
def test_plot_series_with_unequal_index_type_raises_error(
series_to_plot, valid_data_types
):
match = "Found series with different index types."
with pytest.raises(TypeError, match=match):
_plot_series(series_to_plot)


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_invalid_marker_kwarg_len_raises_error(series_to_plot):
match = """There must be one marker for each time series,
but found inconsistent numbers of series and
markers."""
with pytest.raises(ValueError, match=match):
# Generate error by creating list of markers with length that does
# not match input number of input series
if isinstance(series_to_plot, pd.Series):
markers = ["o", "o"]
elif isinstance(series_to_plot, tuple):
markers = ["o" for _ in range(len(series_to_plot) - 1)]

_plot_series(series_to_plot, markers=markers)


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_invalid_label_kwarg_len_raises_error(series_to_plot):
match = """There must be one label for each time series,
but found inconsistent numbers of series and
labels."""
with pytest.raises(ValueError, match=match):
# Generate error by creating list of labels with length that does
# not match input number of input series
if isinstance(series_to_plot, pd.Series):
labels = ["Series 1", "Series 2"]
elif isinstance(series_to_plot, tuple):
labels = [f"Series {i}" for i in range(len(series_to_plot) - 1)]

_plot_series(series_to_plot, labels=labels)


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_output_type(series_to_plot):
_check_soft_dependencies("matplotlib")
import matplotlib.pyplot as plt

# Test output case where kwarg ax=None
fig, ax = _plot_series(series_to_plot)

is_fig_figure = isinstance(fig, plt.Figure)
is_ax_axis = isinstance(ax, plt.Axes)

assert is_fig_figure and is_ax_axis, "".join(
[
"plot_series with kwarg ax=None should return plt.Figure and plt.Axes,",
f"but returned: {type(fig)} and {type(ax)}",
]
)

# Test output case where an existing plt.Axes object is passed to kwarg ax
fig, ax = plt.subplots(1, figsize=plt.figaspect(0.25))
ax = _plot_series(series_to_plot, ax=ax)

is_ax_axis = isinstance(ax, plt.Axes)

assert is_ax_axis, "".join(
[
"plot_series with plt.Axes object passed to kwarg ax",
f"should return plt.Axes, but returned: {type(ax)}",
]
)


@pytest.mark.parametrize("series_to_plot", [y_airline])
def test_plot_correlations_runs_without_error(series_to_plot):
_check_soft_dependencies("matplotlib")
import matplotlib.pyplot as plt

plot_correlations(series_to_plot)
plt.gcf().canvas.draw_idle()


@pytest.mark.parametrize("series_to_plot", invalid_input_types)
def test_plot_correlations_invalid_input_type_raises_error(
series_to_plot, valid_data_types
):
# TODO: Is it possible to dynamically create the matching str if it includes
# characters that need to be escaped (like .)
# match = f"Data must be a one of {valid_data_types}, but found type: {type(Z)}"
with pytest.raises((TypeError, ValueError)):
plot_correlations(series_to_plot)


@pytest.mark.parametrize("series_to_plot", [y_airline])
def test_plot_correlations_output_type(series_to_plot):
_check_soft_dependencies("matplotlib")
import matplotlib.pyplot as plt

fig, ax = plot_correlations(series_to_plot)

is_fig_figure = isinstance(fig, plt.Figure)
is_ax_array = isinstance(ax, np.ndarray)
is_ax_array_axis = all([isinstance(ax_, plt.Axes) for ax_ in ax])

assert is_fig_figure and is_ax_array and is_ax_array_axis, "".join(
[
"plot_correlations should return plt.Figure and array of plt.Axes,",
f"but returned: {type(fig)} and {type(ax)}",
]
)

0 comments on commit ce73505

Please sign in to comment.