In [None]:
def get_kwargs(df_, y, hue="ShortLabel", errs=True):
    """Build keyword arguments for ``sns.lineplot`` with project colours/markers.

    Parameters
    ----------
    df_ : pandas.DataFrame
        Data about to be plotted; must contain the ``hue`` column.
    y : str
        Unused; kept so existing call sites keep working.
    hue : str
        Column used for both ``hue`` and ``style``. ``"ShortLabel"`` selects
        ``maps.SHORT_COLORMAP`` / ``maps.SHORT_MARKERMAP``; any other column
        selects ``maps.DEFAULT_COLORMAP``.
    errs : bool
        If True, draw errors as capped bars with thicker lines.

    Returns
    -------
    dict
        kwargs for ``sns.lineplot``. ``palette`` and ``hue_order`` are only
        set when every label in ``df_[hue]`` has an entry in the colormap.
        ``markers`` is replaced by the marker map only under the same
        condition. Missing labels are printed instead.
    """
    kwargs = dict(
        hue=hue,
        style=hue,
        markers=True,
        dashes=False,
    )
    if errs:
        kwargs.update(
            err_kws=dict(capsize=10),
            linewidth=3,
            err_style="bars",
        )

    # Distinct labels actually present; computed once and reused below.
    present = set(df_[hue])

    def _covers(mapping, name):
        # True when every present label is a key of `mapping`; otherwise
        # report the missing labels so the fallback is visible in the notebook.
        missing = {k for k in present if k not in mapping}
        if missing:
            print(f"Missing - not using {name}")
            print(missing)
        return not missing

    if hue == "ShortLabel":
        colormap, colormap_name = maps.SHORT_COLORMAP, "SHORT_COLORMAP"
    else:
        colormap, colormap_name = maps.DEFAULT_COLORMAP, "DEFAULT_COLORMAP"

    if _covers(colormap, colormap_name):
        kwargs["palette"] = colormap
        # Order legend entries by the colormap's key order, restricted to the
        # labels in this data. Uses the `hue` column itself (previously the
        # non-ShortLabel branch wrongly read a hard-coded "PlotLabel" column).
        kwargs["hue_order"] = [k for k in colormap.keys() if k in present]

    if hue == "ShortLabel" and _covers(maps.SHORT_MARKERMAP, "SHORT_MARKERMAP"):
        kwargs["markers"] = maps.SHORT_MARKERMAP

    return kwargs


def postplot(df_, target=True, target_v=None, targetlabel=True, figlabel=True):
    """Apply common finishing touches to the current axes after plotting.

    Operates on ``plt.gca()``. It:

    - rebuilds an existing legend;
    - anchors both axis lower limits at 0;
    - prettifies the y label via ``nice_ylabel`` and moves it to the top;
    - optionally adds a figure label summarising single-valued
      ``SampleRatio`` / ``RegionLabel`` columns;
    - optionally marks a target x value.

    Parameters
    ----------
    df_ : pandas.DataFrame
        The data that was just plotted; only its ``SampleRatio`` and
        ``RegionLabel`` columns (when present) are inspected.
    target : bool
        Draw a dotted black vertical line at ``target_v``. Skipped when
        ``target_v`` is None, because there is no x position to draw at
        (``axvline(None)`` raises TypeError).
    target_v : float or None
        x position for the target line and target label.
    targetlabel : bool
        Call ``maps.add_target_label`` with ``target_v`` and text "3 DWPD".
    figlabel : bool
        Add the summary label via ``add_fig_label`` when one was built.
    """
    ax = plt.gca()

    if ax.get_legend():
        # Recreate the legend from the axes' labelled artists; presumably
        # meant to discard seaborn's legend title — confirm intent.
        ax.legend()

    ax.set_ylim(0, None)
    ax.set_xlim(0, None)
    ax.set_ylabel(nice_ylabel(ax.get_ylabel()), loc="top")

    fig_labels = []
    if "SampleRatio" in df_.columns:
        # dropna() keeps the count and the label in agreement: nunique()
        # ignores NaN but unique()[0] does not, which could yield "nan%".
        ratios = df_["SampleRatio"].dropna().unique()
        if len(ratios) == 1:
            fig_labels.append(f"{ratios[0]:g}%")
        elif len(ratios) > 1:
            print("Multiple SampleRatio:", ratios)
    if "RegionLabel" in df_.columns:
        regions = df_["RegionLabel"].dropna().unique()
        if len(regions) == 1:
            fig_labels.append(str(regions[0]))
    if figlabel and fig_labels:
        add_fig_label(", ".join(fig_labels))

    if target and target_v is not None:
        ax.axvline(target_v, ls=":", c="black")
    if targetlabel:
        maps.add_target_label(twr=target_v, fmt="3 DWPD")
        
        
import matplotlib.ticker as ticker


def add_leg_to_subplot(loc=(2, 4, 4), title="Policy"):
    """Move the legend of subplot 1 into its own axis-less subplot cell.

    Parameters
    ----------
    loc : tuple
        ``(nrows, ncols, index)`` passed to ``plt.subplot`` for the cell that
        will hold the legend. Handles are read from cell 1 of the same
        ``nrows x ncols`` grid.
    title : str
        Legend title (default ``"Policy"``).
    """
    source_ax = plt.subplot(loc[0], loc[1], 1)
    handles, labels = source_ax.get_legend_handles_labels()
    # Subplot 1 may have been drawn without a legend; get_legend() then
    # returns None, so only remove one that exists.
    old_legend = source_ax.get_legend()
    if old_legend is not None:
        old_legend.remove()

    ax = plt.subplot(*loc)
    ax.legend(handles, labels, loc="center", title=title)
    # Hide frame, ticks and labels so the cell shows only the legend.
    ax.set_axis_off()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

def postsubplot_wr(ax, i, y_minor_step=5):
    """Tidy tick locators and labels on one panel of a write-rate grid.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Panel to adjust.
    i : int
        Panel index within the grid. Only panel 0 sets the x-axis locators.
        Presumably this is because the x axis is shared across panels (as in
        ``plot_wrs_grid``, which uses ``sharex=True``), so one assignment
        covers all of them.
    y_minor_step : float
        Spacing of y-axis minor ticks (default 5, the previous hard-coded
        value). A fixed step on a large-valued metric produces very many
        minor ticks, so callers can widen it.
    """
    if i == 0:
        ax.xaxis.set_major_locator(ticker.MaxNLocator(3))
        ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
    ax.yaxis.set_major_locator(ticker.MaxNLocator(3))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_step))
    ax.tick_params(which="major", length=6)
    ax.tick_params(which="minor", length=4)
    # Per-panel axis labels are cleared; presumably the figure uses
    # figure-level sup-labels instead.
    ax.set_ylabel("")
    ax.set_xlabel("")
    
# Re-load the `maps` module (presumably `importlib.reload`) so edits to its
# colormaps / marker maps / label helpers are picked up without restarting
# the kernel.
reload(maps)
# Activate the "single" preset of the `contexts` helper; its exact settings
# are defined outside this cell.
contexts.use("single")

def plot_wrs_grid(df=None, y="P100ServiceTimeUtil@10m", hue="ShortLabel", x="Target DWPD"):
    """Plot ``y`` against ``x`` in a 2x4 grid, one panel per ``RegionLabel``.

    Cell 4 (top right) is reserved for a shared legend, moved there from the
    first panel by ``add_leg_to_subplot``. The remaining 7 cells hold region
    panels in groupby order. Axis labels are drawn once for the whole figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``RegionLabel``, ``x``, ``y`` and ``hue`` columns.
        Required despite the ``None`` default.
    y, hue, x : str
        Column names for the y values, the hue/style grouping and the x values.

    Raises
    ------
    ValueError
        If ``df`` is None, or if it has more regions than free grid cells.
    """
    if df is None:
        raise ValueError("plot_wrs_grid: `df` is required")

    nrows, ncols = 2, 4
    legend_cell = 4  # 1-based subplot index reserved for the shared legend
    max_regions = nrows * ncols - 1
    regions = list(df.groupby("RegionLabel"))
    if len(regions) > max_regions:
        raise ValueError(
            f"plot_wrs_grid: {len(regions)} regions do not fit in the "
            f"{max_regions} free cells of a {nrows}x{ncols} grid"
        )

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=False,
        figsize=(7 * 2, 3 * 2),
        layout="constrained",
    )
    for i, (_region, df_) in enumerate(regions):
        # 1-based cell index, skipping the legend cell.
        cell = i + 1 if i + 1 < legend_cell else i + 2
        ax = axes.flat[cell - 1]
        # postplot() works on plt.gca(), so make this panel current.
        plt.sca(ax)
        sns.lineplot(
            data=df_, x=x, y=y, **get_kwargs(df_, y, hue=hue), ax=ax, legend=i == 0
        )
        postplot(df_, target=False, targetlabel=False)
        postsubplot_wr(ax, i)
    add_leg_to_subplot((nrows, ncols, legend_cell))

    if "Write Rate" in x:
        fig.supxlabel(maps.l_wr)
    elif "DWPD" in x:
        fig.supxlabel("DWPD (Drive Writes Per Day)")
    else:
        # Unknown x column: fall back to its name rather than leaving no label.
        fig.supxlabel(x)
    fig.supylabel(nice_ylabel(y))