Skip to content

Commit

Permalink
allow coloring subset of rows
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherRussell committed Mar 10, 2024
1 parent 984cca2 commit fe95638
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions great_tables/_data_color/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def is_numeric_or_none(x: Any) -> bool:
def data_color_mpl(
self: GTSelf,
columns: Union[str, List[str], None] = None,
rows: Union[int, list[int], None] = None,
cmap: Colormap | str | list[str] | None = None,
norm: Normalize | Callable[[float], float] | None = None,
na_color: Optional[str] = None,
Expand All @@ -53,6 +54,9 @@ def data_color_mpl(
columns
The columns to target. Can either be a single column name or a series of column names
provided in a list.
rows
The rows to target. Can either be a single row index or a series of row indices provided in a
list.
cmap
The name of the colormap to use. This should be a valid matplotlib colormap name (e.g.,
`"viridis"`, `"plasma"`, `"inferno"`, `"magma"`, etc.). Can also be a
Expand Down Expand Up @@ -100,6 +104,8 @@ def data_color_mpl(
else:
columns_resolved = columns

rows_to_color = get_row_to_color_flags(rows, data_table)

# check all columns are numeric
for col in columns_resolved:
column_values = data_table[col].to_list()
Expand All @@ -115,34 +121,50 @@ def data_color_mpl(
for col in columns_resolved:
column_values = data_table[col].to_list()

color_values = []
for value in column_values:
for row_nr, (value, should_color_row) in enumerate(zip(column_values, rows_to_color)):
if not should_color_row:
continue

if is_na(data_table, value):
color_values.append(na_color)
color = na_color
else:
scaled_value = norm(value)
color_no_alpha = colormap(scaled_value)
color = (*color_no_alpha[:3], alpha) # in RGBA format last value is alpha
color = to_hex(color, keep_alpha=True)
color_values.append(color)

for i, _ in enumerate(color_values):
if autocolor_text:
fgnd_color = _ideal_fgnd_color(bgnd_color=color_values[i])
fgnd_color = _ideal_fgnd_color(bgnd_color=color)

self = self.tab_style(
style=[text(color=fgnd_color), fill(color=color_values[i])],
locations=body(columns=col, rows=[i]),
style=[text(color=fgnd_color), fill(color=color)],
locations=body(columns=col, rows=[row_nr]),
)

else:
self = self.tab_style(
style=fill(color=color_values[i]), locations=body(columns=col, rows=[i])
style=fill(color=color), locations=body(columns=col, rows=[row_nr])
)

return self


def get_row_to_color_flags(
rows: list[int] | int | None, data_table: pl.DataFrame | pd.DataFrame
) -> list[bool]:
if rows is None:
rows_to_color = [True] * len(data_table)
elif isinstance(rows, int):
rows_to_color = [i == rows for i in range(len(data_table))]
elif isinstance(rows, list) and all(isinstance(r, int) for r in rows):
rows_to_color = [i in rows for i in range(len(data_table))]
else:
raise ValueError(
f"Invalid rows provided ({rows}). Please provide a single row index, a list of row indices, or None."
)
return rows_to_color


def _get_default_norm(data: pl.DataFrame | pd.DataFrame) -> Normalize:
if isinstance(data, pl.DataFrame):
vmin = data.min().min_horizontal()[0]
Expand Down

0 comments on commit fe95638

Please sign in to comment.