RFC refactoring of PartialDependenceDisplay #25079

Open
glemaitre opened this issue Nov 30, 2022 · 4 comments


@glemaitre
Member

glemaitre commented Nov 30, 2022

I would like to discuss the possibility of refactoring PartialDependenceDisplay to simplify it and improve its usability. The ideas behind this refactoring are:

  • make it consistent with other displays: a call to PartialDependenceDisplay would be equivalent to a single call to partial_dependence.
  • it would greatly simplify the codebase: the code was already refactored, but the logic where axes interact with each other is still difficult to understand.
  • keep the figure close to standard matplotlib: a user who knows matplotlib can quickly modify and tune the figure without knowing implementation details enforced by scikit-learn.
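To make the first point concrete, a single call to partial_dependence for one feature boils down to a brute-force computation like the one below. This is a minimal pure-Python sketch, not scikit-learn's implementation; the helper `partial_dependence_1d`, the toy `predict` function and the data `X` are hypothetical, and the result keys simply mirror the "values"/"average" keys used in the snippet further down.

```python
from statistics import mean


def partial_dependence_1d(predict, X, feature_idx, grid):
    """Brute-force 1D partial dependence (hypothetical helper).

    For each grid value, overwrite column `feature_idx` in every row of X
    with that value, predict on the modified rows and average the results.
    """
    averages = []
    for value in grid:
        X_mod = [row[:feature_idx] + [value] + row[feature_idx + 1:] for row in X]
        averages.append(mean(predict(row) for row in X_mod))
    return {"values": list(grid), "average": averages}


# Toy additive model f(x0, x1) = 2 * x0 + x1 (hypothetical).
def predict(row):
    return 2 * row[0] + row[1]


X = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]]
pd_result = partial_dependence_1d(predict, X, feature_idx=0, grid=[0.0, 1.0, 2.0])
print(pd_result["average"])  # [3.0, 5.0, 7.0]
```

For this additive model, the curve for x0 is its own effect shifted by the mean contribution of x1, i.e. 2 * v + 3.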

Sharing y-axis

One main reason for having PartialDependenceDisplay draw several plots is to set a common y-axis. However, matplotlib already provides this feature when creating a figure:

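import matplotlib.pyplot as plt
from sklearn.inspection import partial_dependence

# `model` (a fitted estimator), `X` and `categorical_features` are assumed to be defined.
features = ["age", "sex"]
fig, axs = plt.subplots(figsize=(12, 4), ncols=len(features), sharey=True)
for feat, ax in zip(features, axs):
    pd_values = partial_dependence(
        model,
        X,
        features=feat,
        categorical_features=categorical_features,
    )
    # "values" holds the grid of feature values (renamed "grid_values" in later
    # scikit-learn releases); "average" holds the averaged predictions.
    if feat in categorical_features:
        # Categorical feature: one bar per category.
        ax.bar(pd_values["values"][0], pd_values["average"][0])
    else:
        # Numerical feature: a line over the grid.
        ax.plot(pd_values["values"][0], pd_values["average"][0])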


Here, PartialDependenceDisplay could be called instead of partial_dependence followed by manual plotting.

In this case, a user can still manipulate the individual axes through axs and apply any usual matplotlib customization.
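What sharey=True buys here is a single y-range common to all panels. The sketch below computes such a range from several 1D PD curves in plain Python; the helper name `shared_ylim` and the 5% padding are hypothetical choices for illustration, not scikit-learn or matplotlib behaviour.

```python
def shared_ylim(curves, pad_fraction=0.05):
    """Return one (ymin, ymax) pair covering every curve, with relative padding."""
    low = min(min(curve) for curve in curves)
    high = max(max(curve) for curve in curves)
    pad = (high - low) * pad_fraction
    return low - pad, high + pad


# Hypothetical average predictions for two features, e.g. "age" and "sex".
curves = [[1.0, 3.0, 5.0], [2.0, 9.0]]
ylim = shared_ylim(curves)
print(ylim)
```

With matplotlib, setting such limits on one axis of a sharey=True grid (for example with ax.set_ylim(*ylim)) propagates them to the other axes.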

Surprising API

With the current PartialDependenceDisplay, we have some parameters specific to our setting: ncols to set the number of columns, and the fitted attributes bounding_ax_ and axes_. The latter attributes are required if someone passes a single axis, for instance to set a given figure size:

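import matplotlib.pyplot as plt
from sklearn.inspection import PartialDependenceDisplay

fig, ax = plt.subplots(figsize=(12, 4))
disp = PartialDependenceDisplay.from_estimator(
    ...
    ax=ax,
)
# The `ax` passed above becomes disp.bounding_ax_, while the individual
# PD plots are drawn on new axes stored in disp.axes_.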

I personally find it more intuitive for ax to contain all the axes, without having to look up in the scikit-learn documentation how the plot is organized.
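For reference, the organization mentioned above is essentially a mapping from the i-th feature to a cell of a grid with ncols columns. The sketch below assumes row-major filling and n_rows = ceil(n_features / ncols); this is an assumption about the current layout of axes_, not documented behaviour, and the helper name `grid_positions` is hypothetical.

```python
import math


def grid_positions(n_features, ncols):
    """Map each feature index to a (row, col) cell, filling rows first."""
    n_rows = math.ceil(n_features / ncols)
    positions = [divmod(i, ncols) for i in range(n_features)]
    return n_rows, positions


n_rows, positions = grid_positions(n_features=5, ncols=3)
print(n_rows, positions)  # 2 [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
```

Here the cell (1, 2) stays empty. With the proposed API, this mapping would be decided by the user's own plt.subplots call rather than by the display.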

Cons

The only thing that could become more difficult is plotting 1D and 2D PDs at the same time. While one wants to share the y-axis across 1D plots, this is not desirable for a 2D PD. Therefore, users would need to know how to share the y-axis for only some of the axes.

At the same time, it could be argued that having a mixed type of 1D and 2D plots on the same figure should not be a feature :) Those should be on separate figures.
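The extra work described in the Cons amounts to grouping panels: 1D panels get one common y-range while 2D panels keep their own. A plain-Python sketch of that bookkeeping, with hypothetical panel descriptors and helper name; in matplotlib, the equivalent would be to link only the 1D axes, for example with Axes.sharey (matplotlib 3.3+).

```python
def ylims_for_panels(panels):
    """Return a y-limit per panel: shared across 1D panels, None for 2D ones.

    `panels` is a list of (kind, values) pairs where kind is "1d" or "2d";
    for "1d", values holds the averaged predictions of that curve.
    """
    one_way = [values for kind, values in panels if kind == "1d"]
    common = None
    if one_way:
        common = (min(min(v) for v in one_way), max(max(v) for v in one_way))
    # 2D panels get None: their y-axis shows the second feature, not a prediction.
    return [common if kind == "1d" else None for kind, _ in panels]


panels = [("1d", [0.2, 0.4]), ("2d", None), ("1d", [0.1, 0.3])]
print(ylims_for_panels(panels))  # [(0.1, 0.4), None, (0.1, 0.4)]
```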

Proposed API

The new API would require the user to write an explicit for loop to handle each individual plot:

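import matplotlib.pyplot as plt
from sklearn.inspection import PartialDependenceDisplay

features = ["age", "sex"]
fig, axs = plt.subplots(figsize=(12, 4), ncols=len(features), sharey=True)
for feat, ax in zip(features, axs):
    # Proposed API: one PD plot per call; `feature` (singular) is the
    # proposed parameter name and does not exist in the current API.
    PartialDependenceDisplay.from_estimator(
        model,
        X,
        feature=feat,
        categorical_features=categorical_features,
        ax=ax,
    )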
glemaitre added the RFC label Nov 30, 2022
@ogrisel
Member

ogrisel commented Dec 2, 2022

Could you please update the issue description to give more details on the proposed new API, using the same example as the one above with partial_dependence + manual matplotlib?

I am not sure I understand the proposal. Is the goal to make PartialDependenceDisplay compute and display one PD plot at a time?

Would you suggest introducing another class dedicated to 2-way PD plots and reserving PartialDependenceDisplay for one-feature-at-a-time PD plots?

> At the same time, it could be argued that having a mixed type of 1D and 2D plots on the same figure should not be a feature :) Those should be on separate figure.

I agree.

@glemaitre
Member Author

> I am not sure I understand the proposal. Is the goal to make PartialDependenceDisplay compute and display one PD plot at a time?

Yes, this is exactly my proposal.

> Would you suggest introducing another class dedicated to 2-way PD plots and reserving PartialDependenceDisplay for one-feature-at-a-time PD plots?

I think it is fine to have a single class. partial_dependence allows computing interactions, and it should be possible to do so with PartialDependenceDisplay as well. Regarding code maintenance, we already have two branches in the code: one private function to plot 1D PDPs and one to plot 2D PDPs. This part of the code is not much of a hassle to understand.
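The single-class design described above can be pictured as one entry point that dispatches, based on the number of features, to one of two plotting branches. A minimal sketch, where `plot_single_pd`, `_plot_one_way` and `_plot_two_way` are hypothetical stand-ins for the display and its existing private 1D and 2D helpers; they return a description instead of drawing.

```python
def _plot_one_way(feature):
    # Stand-in for the private 1D helper: would draw a line (or bars).
    return f"1D PD of {feature!r}"


def _plot_two_way(feature_pair):
    # Stand-in for the private 2D helper: would draw a contour plot.
    return f"2D PD of {feature_pair[0]!r} x {feature_pair[1]!r}"


def plot_single_pd(feature):
    """Hypothetical single entry point: one feature -> 1D, a pair -> 2D."""
    if isinstance(feature, (tuple, list)):
        if len(feature) != 2:
            raise ValueError("expected a single feature or a pair of features")
        return _plot_two_way(feature)
    return _plot_one_way(feature)


print(plot_single_pd("age"))  # 1D PD of 'age'
print(plot_single_pd(("age", "sex")))  # 2D PD of 'age' x 'sex'
```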

@NicolasHug
Member

Thanks @glemaitre. I remember thinking the PDP codebase was fairly complex and that simplifying it would be a worthwhile investment. I think a lot of the complexity comes from the fact that its design predates our Display API, so any decision made back then is worth revisiting.

> > At the same time, it could be argued that having a mixed type of 1D and 2D plots on the same figure should not be a feature :) Those should be on separate figure.
>
> I agree.

Did you mean ax instead of figure? In our examples, we plot 1D and 2D PDPs on the same figure and they look fine, don't they?

Regarding y-axis sharing: is this really a built-in feature of the current PartialDependenceDisplay? I can't find any instance of sharey=True. From what I understand, y can always be "shared" implicitly by passing centered=True, so y-sharing would still be supported by the new API, on top of mixing 1D and 2D plots on the same figure.
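For reference, centered=True makes the PD (and ICE) lines start at the origin of the y-axis, which is what makes curves directly comparable. A plain-Python sketch of that transformation, assuming centering subtracts the value at the first grid point; the helper name `center_curve` is hypothetical and this is not scikit-learn's code.

```python
def center_curve(average):
    """Subtract the first value so that the curve starts at the origin."""
    start = average[0]
    return [value - start for value in average]


# Hypothetical PD curves for two features with different baselines.
curve_age = [10.0, 12.0, 15.0]
curve_sex = [3.0, 1.0]
print(center_curve(curve_age), center_curve(curve_sex))  # [0.0, 2.0, 5.0] [0.0, -2.0]
```

Centering aligns the starting points; the y-limits of separate axes can still differ unless the axes are also shared.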

@NicolasHug
Member

The only caveat I'd note about a deprecation is that the PDP utilities have gone through at least two major deprecation cycles in recent years: in 0.21 we moved them from ensemble to inspection, and in 1.0 we introduced the Display API and deprecated the plain plot_partial_dependence function.
