Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot_correlations() to plot series and acf/pacf #850

Merged
merged 20 commits into from Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
61e71fd
Add plot_correlations to plot series and acf/pacf
RNKuhns Apr 30, 2021
f9c0815
Merge branch 'main' into corr_plot
mloning May 14, 2021
aafb592
Fixed output of plot_series to handle kwarg ax
RNKuhns May 14, 2021
05bb227
Added input check to plot_correlations
RNKuhns May 14, 2021
85bd24b
Merge branch 'corr_plot' of https://github.com/RNKuhns/sktime into co…
RNKuhns May 14, 2021
41a124a
Updated docstrings to match pydocstyle conventions
RNKuhns May 14, 2021
11d8cea
Updated plot_series docstring default args
RNKuhns May 14, 2021
a5f9d63
Add kwargs to set axes titles to plot_series
RNKuhns May 16, 2021
dad6e16
Added dependency on pytest-mply plot tests
RNKuhns May 16, 2021
b8b7b81
Add baseline images for plot comparison unit tests
RNKuhns May 16, 2021
bbbfb62
Merge pull request #2 from alan-turing-institute/main
RNKuhns May 16, 2021
466abed
Merge branch 'corr_plot' of https://github.com/RNKuhns/sktime into co…
RNKuhns May 16, 2021
d709a1d
Merge branch 'main' into corr_plot
RNKuhns May 20, 2021
611ee31
Updated setup.cfg to exlude tests from pydocstyle
RNKuhns May 20, 2021
13ade37
Added plot_correlations to docs
RNKuhns May 20, 2021
f97f79a
Removed pytest-mpl dependency and tests
RNKuhns Jun 4, 2021
9a9ef2b
Merge branch 'main' into corr_plot
RNKuhns Jun 4, 2021
251dce4
Update plotting unit tests
RNKuhns Jun 5, 2021
dd4d7c1
Fixed plot_correlations docs
RNKuhns Jun 5, 2021
ef64191
Merge branch 'corr_plot' of https://github.com/RNKuhns/sktime into co…
RNKuhns Jun 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_reference/utils.rst
Expand Up @@ -15,6 +15,7 @@ Plotting
:template: function.rst

plot_series
plot_correlations

Data Processing
---------------
Expand Down
156 changes: 145 additions & 11 deletions sktime/utils/plotting.py
@@ -1,26 +1,47 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__all__ = ["plot_series"]
__author__ = ["Markus Löning"]
"""Common timeseries plotting functionality.

Functions
---------
plot_series(*series, labels=None, markers=None, ax=None)
plot_correlations(
series,
lags=24,
alpha=0.05,
zero_lag=True,
acf_fft=False,
acf_adjusted=True,
pacf_method="ywadjusted",
suptitle=None,
series_title=None,
acf_title="Autocorrelation",
pacf_title="Partial Autocorrelation",
)
"""
__all__ = ["plot_series", "plot_correlations"]
__author__ = ["Markus Löning", "Ryan Kuhns"]

import numpy as np

from sktime.utils.validation._dependencies import _check_soft_dependencies
from sktime.utils.validation.forecasting import check_y
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf


def plot_series(*series, labels=None, markers=None):
"""Plot one or more time series
def plot_series(
*series, labels=None, markers=None, x_label=None, y_label=None, ax=None
):
"""Plot one or more time series.

Parameters
----------
series : pd.Series
One or more time series
labels : list, optional (default=None)
labels : list, default = None
Names of series, will be displayed in figure legend
markers: list, optional (default=None)
markers: list, default = None
Markers of data points, if None the marker "o" is used by default.
Length of list has to match with number of series

Expand All @@ -39,7 +60,7 @@ def plot_series(*series, labels=None, markers=None):
check_y(y)

n_series = len(series)

_ax_kwarg_is_none = True if ax is None else False
# labels
if labels is not None:
if n_series != len(labels):
Expand Down Expand Up @@ -76,8 +97,10 @@ def plot_series(*series, labels=None, markers=None):
# generate integer x-values
xs = [np.argwhere(index.isin(y.index)).ravel() for y in series]

# create figure
fig, ax = plt.subplots(1, figsize=plt.figaspect(0.25))
# create figure if no Axe provided for plotting
if _ax_kwarg_is_none:
fig, ax = plt.subplots(1, figsize=plt.figaspect(0.25))

colors = sns.color_palette("colorblind", n_colors=n_series)

# plot series
Expand Down Expand Up @@ -105,7 +128,118 @@ def format_fn(tick_val, tick_pos):
ax.xaxis.set_major_formatter(FuncFormatter(format_fn))
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

# Label the x and y axes
if x_label is not None:
ax.set_xlabel(x_label)

_y_label = y_label if y_label is not None else series[0].name
ax.set_ylabel(_y_label)

if legend:
ax.legend()
if _ax_kwarg_is_none:
return fig, ax
else:
return ax


def plot_correlations(
    series,
    lags=24,
    alpha=0.05,
    zero_lag=True,
    acf_fft=False,
    acf_adjusted=True,
    pacf_method="ywadjusted",
    suptitle=None,
    series_title=None,
    acf_title="Autocorrelation",
    pacf_title="Partial Autocorrelation",
):
    """Plot series and its ACF and PACF values.

    The figure uses a 2x2 grid: the series spans the top row, the ACF plot
    sits bottom-left and the PACF plot bottom-right.

    Parameters
    ----------
    series : pd.Series
        A time series.

    lags : int, default = 24
        Number of lags to include in ACF and PACF plots.

    alpha : float, default = 0.05
        Alpha value used to set confidence intervals. Alpha = 0.05 results in
        95% confidence interval with standard deviation calculated via
        Bartlett's formula.

    zero_lag : bool, default = True
        If True, start ACF and PACF plots at 0th lag.

    acf_fft : bool, default = False
        Whether to compute ACF via FFT.

    acf_adjusted : bool, default = True
        If True, denominator of ACF calculations uses n-k instead of n, where
        n is number of observations and k is the lag.

    pacf_method : str, default = 'ywadjusted'
        Method to use in calculation of PACF.

    suptitle : str, default = None
        The text to use as the Figure's suptitle.

    series_title : str, default = None
        Used to set the title of the series plot if provided. Otherwise, series
        plot has no title.

    acf_title : str, default = 'Autocorrelation'
        Used to set title of ACF plot.

    pacf_title : str, default = 'Partial Autocorrelation'
        Used to set title of PACF plot.

    Returns
    -------
    fig : matplotlib.figure.Figure

    axes : np.ndarray
        Array of the figure's Axes objects, as returned by ``fig.get_axes()``.
        The axes are created in the order series, ACF, PACF.
    """
    _check_soft_dependencies("matplotlib")
    import matplotlib.pyplot as plt

    # Validate before any figure is created so invalid input leaves no
    # half-built figure behind.
    series = check_y(series)

    # Setup figure for plotting
    fig = plt.figure(constrained_layout=True, figsize=(12, 8))
    gs = fig.add_gridspec(2, 2)
    f_ax1 = fig.add_subplot(gs[0, :])
    if series_title is not None:
        f_ax1.set_title(series_title)
    f_ax2 = fig.add_subplot(gs[1, 0])
    f_ax3 = fig.add_subplot(gs[1, 1])

    # Create expected plots on their respective Axes
    plot_series(series, ax=f_ax1)
    plot_acf(
        series,
        ax=f_ax2,
        lags=lags,
        zero=zero_lag,
        alpha=alpha,
        title=acf_title,
        adjusted=acf_adjusted,
        fft=acf_fft,
    )
    plot_pacf(
        series,
        ax=f_ax3,
        lags=lags,
        zero=zero_lag,
        alpha=alpha,
        title=pacf_title,
        method=pacf_method,
    )
    if suptitle is not None:
        fig.suptitle(suptitle, size="xx-large")

    return fig, np.array(fig.get_axes())
167 changes: 167 additions & 0 deletions sktime/utils/tests/test_plotting.py
@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
import pytest

from sktime.datasets import load_airline
from sktime.utils.plotting import plot_series, plot_correlations
from sktime.utils.validation._dependencies import _check_soft_dependencies
from sktime.utils.validation.series import VALID_DATA_TYPES

# NOTE(review): not referenced anywhere in the visible part of this module.
ALLOW_NUMPY = False
# Airline dataset used as the reference series for all plotting tests.
y_airline = load_airline()
# Split at 1960-01: observations before the cutoff vs. from the cutoff onward.
y_airline_true = y_airline.iloc[y_airline.index < "1960-01"]
y_airline_test = y_airline.iloc[y_airline.index >= "1960-01"]
# Parametrization cases: one single series and one tuple of two series.
series_to_test = [y_airline, (y_airline_true, y_airline_test)]
# Inputs the tests below expect the plotting functions to reject.
invalid_input_types = [y_airline.values, pd.DataFrame(y_airline), "this_is_a_string"]


# Need to use _plot_series to make it easy for test cases to pass either a
# single series or a tuple of multiple series to be unpacked as args
def _plot_series(series, ax=None, **kwargs):
    """Call ``plot_series`` with either one series or a tuple of series.

    Parameters
    ----------
    series : object or tuple
        A tuple is unpacked into positional arguments; anything else is
        forwarded as the single positional argument.
    ax : optional
        Forwarded unchanged to ``plot_series``.
    **kwargs
        Forwarded unchanged to ``plot_series``.

    Returns
    -------
    Whatever ``plot_series`` returns for the given arguments.
    """
    positional = series if isinstance(series, tuple) else (series,)
    return plot_series(*positional, ax=ax, **kwargs)


@pytest.fixture
def valid_data_types():
    """Return VALID_DATA_TYPES without np.ndarray and pd.DataFrame, as a tuple."""
    excluded = (np.ndarray, pd.DataFrame)
    return tuple(
        dtype
        for dtype in VALID_DATA_TYPES
        if not any(dtype is unwanted for unwanted in excluded)
    )


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_runs_without_error(series_to_plot):
    """Smoke test: plotting valid input and scheduling a draw must not raise."""
    _check_soft_dependencies("matplotlib")
    import matplotlib.pyplot as plt

    try:
        _plot_series(series_to_plot)
        plt.gcf().canvas.draw_idle()
    finally:
        # pyplot keeps every figure registered until it is closed; release
        # them so figures do not accumulate across parametrized cases.
        plt.close("all")


@pytest.mark.parametrize("series_to_plot", invalid_input_types)
def test_plot_series_invalid_input_type_raises_error(series_to_plot, valid_data_types):
    """Check that plot_series rejects each entry of ``invalid_input_types``."""
    # TODO: also assert on the error message. The expected text contains
    # characters that are special in regular expressions (like .), so it
    # would have to be escaped before being passed as ``match``.
    expected_errors = (TypeError, ValueError)
    with pytest.raises(expected_errors):
        _plot_series(series_to_plot)


@pytest.mark.parametrize(
    "series_to_plot", [(y_airline_true, y_airline_test.reset_index(drop=True))]
)
def test_plot_series_with_unequal_index_type_raises_error(
    series_to_plot, valid_data_types
):
    """Check that series whose index types differ raise a TypeError.

    The second series has its index reset, so it no longer shares the index
    type of the first one.
    """
    expected_message = "Found series with different index types."
    with pytest.raises(TypeError, match=expected_message):
        _plot_series(series_to_plot)


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_invalid_marker_kwarg_len_raises_error(series_to_plot):
    """Check that a marker list of the wrong length raises a ValueError."""
    # Only the first line of the message is matched: the full message spans
    # several lines, so matching all of it would tie this test to the exact
    # whitespace used where the error is raised.
    match = "There must be one marker for each time series"

    # Build a list of markers whose length does not match the number of
    # input series: one too many for a single series, one too few for a tuple.
    if isinstance(series_to_plot, tuple):
        markers = ["o"] * (len(series_to_plot) - 1)
    else:
        markers = ["o", "o"]

    # Keep only the call under test inside the ``raises`` context.
    with pytest.raises(ValueError, match=match):
        _plot_series(series_to_plot, markers=markers)


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_invalid_label_kwarg_len_raises_error(series_to_plot):
    """Check that a label list of the wrong length raises a ValueError."""
    # Only the first line of the message is matched: the full message spans
    # several lines, so matching all of it would tie this test to the exact
    # whitespace used where the error is raised.
    match = "There must be one label for each time series"

    # Build a list of labels whose length does not match the number of
    # input series: one too many for a single series, one too few for a tuple.
    if isinstance(series_to_plot, tuple):
        labels = [f"Series {i}" for i in range(len(series_to_plot) - 1)]
    else:
        labels = ["Series 1", "Series 2"]

    # Keep only the call under test inside the ``raises`` context.
    with pytest.raises(ValueError, match=match):
        _plot_series(series_to_plot, labels=labels)


@pytest.mark.parametrize("series_to_plot", series_to_test)
def test_plot_series_output_type(series_to_plot):
    """Check the return types of plot_series with and without the ``ax`` kwarg.

    Without ``ax`` a ``(Figure, Axes)`` pair is expected; when an existing
    Axes is passed in, only an Axes is expected back.
    """
    _check_soft_dependencies("matplotlib")
    import matplotlib.pyplot as plt

    try:
        # Test output case where kwarg ax=None
        fig, ax = _plot_series(series_to_plot)

        assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes), (
            "plot_series with kwarg ax=None should return plt.Figure and "
            f"plt.Axes, but returned: {type(fig)} and {type(ax)}"
        )

        # Test output case where an existing plt.Axes object is passed to kwarg ax
        fig, ax = plt.subplots(1, figsize=plt.figaspect(0.25))
        ax = _plot_series(series_to_plot, ax=ax)

        assert isinstance(ax, plt.Axes), (
            "plot_series with plt.Axes object passed to kwarg ax "
            f"should return plt.Axes, but returned: {type(ax)}"
        )
    finally:
        # Two figures are opened above; close them so they do not accumulate
        # across parametrized cases.
        plt.close("all")


@pytest.mark.parametrize("series_to_plot", [y_airline])
def test_plot_correlations_runs_without_error(series_to_plot):
    """Smoke test: plot_correlations on valid input must not raise."""
    _check_soft_dependencies("matplotlib")
    import matplotlib.pyplot as plt

    try:
        plot_correlations(series_to_plot)
        plt.gcf().canvas.draw_idle()
    finally:
        # pyplot keeps every figure registered until it is closed.
        plt.close("all")


@pytest.mark.parametrize("series_to_plot", invalid_input_types)
def test_plot_correlations_invalid_input_type_raises_error(
    series_to_plot, valid_data_types
):
    """Check that plot_correlations rejects each entry of ``invalid_input_types``."""
    # TODO: also assert on the error message. The expected text contains
    # characters that are special in regular expressions (like .), so it
    # would have to be escaped before being passed as ``match``.
    expected_errors = (TypeError, ValueError)
    with pytest.raises(expected_errors):
        plot_correlations(series_to_plot)


@pytest.mark.parametrize("series_to_plot", [y_airline])
def test_plot_correlations_output_type(series_to_plot):
    """Check that plot_correlations returns a Figure and an array of Axes."""
    _check_soft_dependencies("matplotlib")
    import matplotlib.pyplot as plt

    try:
        fig, axes = plot_correlations(series_to_plot)

        is_fig_figure = isinstance(fig, plt.Figure)
        is_ax_array = isinstance(axes, np.ndarray)
        is_ax_array_axis = all(isinstance(ax_, plt.Axes) for ax_ in axes)

        assert is_fig_figure and is_ax_array and is_ax_array_axis, (
            "plot_correlations should return plt.Figure and array of "
            f"plt.Axes, but returned: {type(fig)} and {type(axes)}"
        )
    finally:
        # pyplot keeps every figure registered until it is closed.
        plt.close("all")