Skip to content

Commit

Permalink
feat(plot): improve plotter arguments (#425)
Browse files Browse the repository at this point in the history
This commit adds new arguments to the plotter that allow users to
customize the plots more.
  • Loading branch information
rickstaa committed Mar 8, 2024
1 parent 71efc87 commit c7202a2
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions stable_learning_control/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ def plot_data(
xaxis="Epoch",
value="AverageEpRet",
condition="Condition1",
errorbar="sd",
smooth=1,
font_scale=1.5,
style="darkgrid",
**kwargs,
):
"""Function used to plot data.
Expand All @@ -48,10 +50,14 @@ def plot_data(
off-policy algorithms. The plotter will automatically figure out
which of ``AverageEpRet`` or ``AverageTestEpRet`` to report for
each separate logdir.
condition (str, optional): The condition to search for. By default ``None``.
condition (str, optional): The condition to search for. By default
``Condition1``.
errorbar (str): The error bar you want to use for the plot. Defaults
to ``sd``.
smooth (int): Smooth data by averaging it over a fixed window. This
parameter says how wide the averaging window will be.
font_scale (int): The font scale you want to use for the plot text.
style (str): The style you want to use for the plot.
"""
if smooth > 1:
"""
Expand All @@ -69,8 +75,8 @@ def plot_data(

if isinstance(data, list):
data = pd.concat(data, ignore_index=True)
sns.set(style="darkgrid", font_scale=font_scale)
sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar="sd", **kwargs)
sns.set(style=style, font_scale=font_scale)
sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar=errorbar, **kwargs)
plt.legend(loc="best").set_draggable(True)

xscale = np.max(np.asarray(data[xaxis])) > 5e3
Expand Down Expand Up @@ -209,6 +215,7 @@ def make_plots(
values=None,
count=False,
font_scale=1.5,
style="darkgrid",
smooth=1,
select=None,
exclude=None,
Expand All @@ -233,7 +240,7 @@ def make_plots(
rules (below).)
xaxis (str): Pick what column from data is used for the x-axis.
Defaults to ``TotalEnvInteracts``.
value (str): Pick what columns from data to graph on the y-axis.
values (list): Pick what columns from data to graph on the y-axis.
Submitting multiple values will produce multiple graphs. Defaults
to ``Performance``, which is not an actual output of any algorithm.
Instead, ``Performance`` refers to either ``AverageEpRet``, the
Expand All @@ -247,6 +254,8 @@ def make_plots(
which is typically a set of identical experiments that only vary
in random seed. But if you'd like to see all of those curves
separately, use the ``--count`` flag.
font_scale (int): The font scale you want to use for the plot text.
style (str): The style you want to use for the plot.
smooth (int): Smooth data by averaging it over a fixed window. This
parameter says how wide the averaging window will be.
select (list[str]): Optional selection rule: the plotter will only show
Expand All @@ -271,6 +280,7 @@ def make_plots(
smooth=smooth,
estimator=estimator,
font_scale=font_scale,
style=style,
)
plt.show()

Expand Down

0 comments on commit c7202a2

Please sign in to comment.