BUG: Native support for (XGBoost) Categoricals #3813

Closed
2 of 4 tasks
mattharrison opened this issue Aug 13, 2024 · 3 comments
Labels
bug (indicates an unexpected problem or unintended behaviour), visualization (relating to plotting)

@mattharrison

Issue Description

The examples I can find of scatter plots that show categoricals all label-encode the data and use the display_features argument of shap.dependence_plot to simulate categories.

When I create models with categories (with XGBoost or CatBoost), I use 'category' types for the columns.

This fails if I try to create a scatter plot and view the impact of the category.

Minimal Reproducible Example

import shap
import xgboost

cal_X, cal_y = shap.datasets.adult(n_points=1000, display=True)

xg_cal = xgboost.XGBClassifier(enable_categorical=True)
xg_cal.fit(cal_X, cal_y)

ex_cal = shap.TreeExplainer(xg_cal)
vals_cal = ex_cal(cal_X)    

shap.plots.scatter(vals_cal[:, 'Relationship'])

Traceback

{
	"name": "TypeError",
	"message": "unsupported operand type(s) for -: 'str' and 'str'",
	"stack": "---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[69], line 1
----> 1 shap.plots.scatter(vals_cal[:, 'Relationship'])

File ~/.envs/menv/lib/python3.10/site-packages/shap/plots/_scatter.py:194, in scatter(shap_values, color, hist, axis_color, cmap, dot_size, x_jitter, alpha, title, xmin, xmax, ymin, ymax, overlay, ax, ylabel, show)
    192 min_dist = np.inf
    193 for i in range(1,len(vals)):
--> 194     d = vals[i] - vals[i-1]
    195     if d > 1e-8 and d < min_dist:
    196         min_dist = d

TypeError: unsupported operand type(s) for -: 'str' and 'str'"
}
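
For anyone reading the traceback: the loop at _scatter.py:192-196 subtracts neighbouring feature values, which only works for numbers. Below is a minimal sketch in plain Python that mirrors that loop and shows that mapping labels to integer codes avoids the error. The helper names min_spacing and to_codes are hypothetical, and sorting the values before the loop is an assumption; this is not shap's actual implementation and not necessarily how #3706 fixes it.

```python
import math


def min_spacing(vals):
    """Smallest positive gap between consecutive sorted values.

    Mirrors the loop shown in the traceback; the sort is an assumption.
    """
    vals = sorted(vals)
    min_dist = math.inf
    for i in range(1, len(vals)):
        d = vals[i] - vals[i - 1]  # raises TypeError when the values are str
        if d > 1e-8 and d < min_dist:
            min_dist = d
    return min_dist


def to_codes(labels):
    """Map each distinct label to an integer code, in sorted label order."""
    lookup = {label: code for code, label in enumerate(sorted(set(labels)))}
    return [lookup[label] for label in labels], lookup


# A few levels of the 'Relationship' column, used here as toy input.
labels = ["Husband", "Wife", "Own-child", "Husband"]

try:
    min_spacing(labels)
    string_error = None
except TypeError as exc:
    string_error = str(exc)

codes, lookup = to_codes(labels)
spacing = min_spacing(codes)  # works, because the codes are numeric
```

The lookup dict keeps the mapping from code back to label, which is the role display_features plays in the label-encoded dependence_plot examples.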

Expected Behavior

I would love to see a scatterplot like the dependence_plot examples on this page: https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/Census%20income%20classification%20with%20XGBoost.html

Bug report checklist

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest release of shap.
  • I have confirmed this bug exists on the master branch of shap.
  • I'd be interested in making a PR to fix this bug

Installed Versions

0.46.0

@mattharrison added the bug label Aug 13, 2024
@mattharrison
Author

Posting my workaround for when I, or others, search for this in the future.

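import pandas as pd
import seaborn as sns

# Assumes `vals` (a shap Explanation) and `X_reg` (the feature DataFrame,
# with 'make' and 'year' columns) already exist.
makes = ['Ford', 'Toyota', 'Honda', 'Tesla']

# Reuse X_reg's index so the feature columns align row by row in assign().
(pd.DataFrame(vals.values, columns=X_reg.columns, index=X_reg.index)
    # Suffix the SHAP columns so they do not clash with the raw features.
    .rename(columns=lambda col: f'{col}_shap')
    .assign(base_value=vals.base_values, **X_reg)
    # One strip of points per make, coloured by year.
    .pipe(lambda df_:
        sns.catplot(x='make', y='make_shap', data=df_, alpha=.5,
                    hue='year', palette='RdBu',
                    order=makes))
)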

@connortann added the visualization label Aug 19, 2024
@hypostulate
Contributor

This should be solved by #3706.

@thatlittleboy
Collaborator

Resolved by #3706
