Skip to content
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Other enhancements
- :meth:`DataFrame.applymap` now supports ``na_action`` (:issue:`23803`)
- :class:`Index` with object dtype supports division and multiplication (:issue:`34160`)
- :meth:`DataFrame.explode` and :meth:`Series.explode` now support exploding of sets (:issue:`35614`)
-
- `Styler` now allows direct CSS class name addition to individual data cells (:issue:`36159`)

.. _whatsnew_120.api_breaking.python:

Expand Down
69 changes: 68 additions & 1 deletion pandas/io/formats/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def __init__(
self.cell_ids = cell_ids
self.na_rep = na_rep

self.cell_context: Dict[str, Any] = {}

# display_funcs maps (row, col) -> formatting function

def default_display_func(x):
Expand Down Expand Up @@ -262,7 +264,7 @@ def format_attr(pair):
idx_lengths = _get_level_lengths(self.index)
col_lengths = _get_level_lengths(self.columns, hidden_columns)

cell_context = dict()
cell_context = self.cell_context

n_rlvls = self.data.index.nlevels
n_clvls = self.data.columns.nlevels
Expand Down Expand Up @@ -499,6 +501,70 @@ def format(self, formatter, subset=None, na_rep: Optional[str] = None) -> "Style
self._display_funcs[(i, j)] = formatter
return self

def set_td_classes(self, classes: DataFrame) -> "Styler":
"""
Add string based CSS class names to data cells that will appear within the
`Styler` HTML result. These classes are added within specified `<td>` elements.

Parameters
----------
classes : DataFrame
DataFrame containing strings that will be translated to CSS classes,
mapped by identical column and index values that must exist on the
underlying `Styler` data. None, NaN values, and empty strings will
be ignored and not affect the rendered HTML.

Returns
-------
self : Styler

Examples
--------
>>> df = pd.DataFrame(data=[[1, 2, 3], [4, 5, 6]], columns=["A", "B", "C"])
>>> classes = pd.DataFrame([
... ["min-val red", "", "blue"],
... ["red", None, "blue max-val"]
... ], index=df.index, columns=df.columns)
>>> df.style.set_td_classes(classes)

Using `MultiIndex` columns and a `classes` `DataFrame` as a subset of the
underlying,

>>> df = pd.DataFrame([[1,2],[3,4]], index=["a", "b"],
... columns=[["level0", "level0"], ["level1a", "level1b"]])
>>> classes = pd.DataFrame(["min-val"], index=["a"],
... columns=[["level0"],["level1a"]])
>>> df.style.set_td_classes(classes)

Form of the output with new additional css classes,

>>> df = pd.DataFrame([[1]])
>>> css = pd.DataFrame(["other-class"])
>>> s = Styler(df, uuid="_", cell_ids=False).set_td_classes(css)
>>> s.hide_index().render()
'<style type="text/css" ></style>'
'<table id="T__" >'
' <thead>'
' <tr><th class="col_heading level0 col0" >0</th></tr>'
' </thead>'
' <tbody>'
' <tr><td class="data row0 col0 other-class" >1</td></tr>'
' </tbody>'
'</table>'

"""
classes = classes.reindex_like(self.data)

mask = (classes.isna()) | (classes.eq(""))
self.cell_context["data"] = {
r: {c: [str(classes.iloc[r, c])]}
for r, rn in enumerate(classes.index)
for c, cn in enumerate(classes.columns)
if not mask.iloc[r, c]
}

return self

def render(self, **kwargs) -> str:
"""
Render the built up styles to HTML.
Expand Down Expand Up @@ -609,6 +675,7 @@ def clear(self) -> None:
Returns None.
"""
self.ctx.clear()
self.cell_context = {}
self._todo = []

def _compute(self):
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/io/formats/test_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,6 +1691,27 @@ def test_no_cell_ids(self):
s = styler.render() # render twice to ensure ctx is not updated
assert s.find('<td class="data row0 col0" >') != -1

@pytest.mark.parametrize(
"classes",
[
DataFrame(
data=[["", "test-class"], [np.nan, None]],
columns=["A", "B"],
index=["a", "b"],
),
DataFrame(data=[["test-class"]], columns=["B"], index=["a"]),
DataFrame(data=[["test-class", "unused"]], columns=["B", "C"], index=["a"]),
],
)
def test_set_data_classes(self, classes):
# GH 36159
df = DataFrame(data=[[0, 1], [2, 3]], columns=["A", "B"], index=["a", "b"])
s = Styler(df, uuid="_", cell_ids=False).set_td_classes(classes).render()
assert '<td class="data row0 col0" >0</td>' in s
assert '<td class="data row0 col1 test-class" >1</td>' in s
assert '<td class="data row1 col0" >2</td>' in s
assert '<td class="data row1 col1" >3</td>' in s

def test_colspan_w3(self):
# GH 36223
df = pd.DataFrame(data=[[1, 2]], columns=[["l0", "l0"], ["l1a", "l1b"]])
Expand Down