Skip to content

Commit

Permalink
Merge pull request #325 from jrycw/fix-35
Browse files Browse the repository at this point in the history
Enhance test coverage for column selection
  • Loading branch information
rich-iannone committed May 3, 2024
2 parents aa14ccf + e1256f4 commit 5ea06af
Showing 1 changed file with 84 additions and 46 deletions.
130 changes: 84 additions & 46 deletions tests/test_spanners.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import pandas as pd
import polars as pl
import polars.selectors as cs
import pytest
from great_tables import GT
from great_tables._gt_data import Boxhead, ColInfo, ColInfoTypeEnum, SpannerInfo, Spanners
from great_tables._spanners import (
cols_hide,
cols_move,
cols_move_to_end,
cols_move_to_start,
empty_spanner_matrix,
spanners_print_matrix,
tab_spanner,
)


@pytest.fixture
def spanners():
def spanners() -> Spanners:
return Spanners(
[
SpannerInfo(spanner_id="a", spanner_level=0, vars=["col1"], built="A"),
Expand All @@ -22,7 +26,7 @@ def spanners():


@pytest.fixture
def boxhead():
def boxhead() -> Boxhead:
return Boxhead(
[
ColInfo(var="col1"),
Expand All @@ -33,19 +37,19 @@ def boxhead():
)


def test_spanners_next_level_above_first(spanners):
def test_spanners_next_level_above_first(spanners: Spanners):
assert spanners.next_level(["col1"]) == 1


def test_spanners_next_level_above_second(spanners):
def test_spanners_next_level_above_second(spanners: Spanners):
assert spanners.next_level(["col2"]) == 2


def test_spanners_next_level_unique(spanners):
def test_spanners_next_level_unique(spanners: Spanners):
assert spanners.next_level(["col3"]) == 0


def test_spanners_print_matrix(spanners, boxhead):
def test_spanners_print_matrix(spanners: Spanners, boxhead: Boxhead):
mat, vars = spanners_print_matrix(spanners, boxhead)
assert vars == ["col1", "col2", "col3"]
assert mat == [
Expand All @@ -55,7 +59,7 @@ def test_spanners_print_matrix(spanners, boxhead):
]


def test_spanners_print_matrix_arg_omit_columns_row(spanners, boxhead):
def test_spanners_print_matrix_arg_omit_columns_row(spanners: Spanners, boxhead: Boxhead):
mat, vars = spanners_print_matrix(spanners, boxhead, omit_columns_row=True)
assert vars == ["col1", "col2", "col3"]
assert mat == [
Expand All @@ -64,7 +68,7 @@ def test_spanners_print_matrix_arg_omit_columns_row(spanners, boxhead):
]


def test_spanners_print_matrix_arg_include_hidden(spanners, boxhead):
def test_spanners_print_matrix_arg_include_hidden(spanners: Spanners, boxhead: Boxhead):
mat, vars = spanners_print_matrix(spanners, boxhead, include_hidden=True)
assert vars == ["col1", "col2", "col3", "col4"]
assert mat == [
Expand Down Expand Up @@ -150,44 +154,6 @@ def test_tab_spanners_with_gather():
assert [col.var for col in new_gt._boxhead] == ["a", "c", "b"]


def test_cols_hide():
df = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
src_gt = GT(df)

new_gt = cols_hide(src_gt, columns=["a"])
assert [col.var for col in new_gt._boxhead if col.visible] == ["b", "c"]

new_gt = cols_hide(src_gt, columns=["a", "b"])
assert [col.var for col in new_gt._boxhead if col.visible] == ["c"]

import polars as pl
import polars.selectors as cs

df = pl.DataFrame({"col1": [1, 2], "col2": [3, 4], "abc": [5, 6]})
src_gt = GT(df)
new_gt = cols_hide(src_gt, columns=cs.starts_with("col"))
assert [col.var for col in new_gt._boxhead if col.visible] == ["abc"]


def test_cols_move():
df = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
src_gt = GT(df)

new_gt = cols_move(src_gt, columns=["a"], after="b")
assert [col.var for col in new_gt._boxhead] == ["b", "a", "c"]


def test_cols_move_polars():
import polars as pl
import polars.selectors as cs

df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
src_gt = GT(df)

new_gt = cols_move(src_gt, columns=cs.starts_with("a"), after="b")
assert [col.var for col in new_gt._boxhead] == ["b", "a", "c"]


def test_cols_width_partial_set():
df = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
gt_tbl = GT(df).cols_width({"a": "10px"})
Expand Down Expand Up @@ -231,3 +197,75 @@ def test_cols_width_fully_set_pct_2():
assert gt_tbl._boxhead[0].column_width == "10%"
assert gt_tbl._boxhead[1].column_width == "10%"
assert gt_tbl._boxhead[2].column_width == "40%"


@pytest.mark.parametrize("df_lib, columns", [(pd, "a"), (pl, cs.starts_with("a"))])
def test_cols_move_single_col(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_move(src_gt, columns=columns, after="b")
assert [col.var for col in new_gt._boxhead] == ["b", "a", "c", "d"]


@pytest.mark.parametrize(
"df_lib, columns", [(pd, ["a", "d"]), (pl, cs.starts_with("a") | cs.ends_with("d"))]
)
def test_cols_move_multi_cols(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_move(src_gt, columns=columns, after="b")
assert [col.var for col in new_gt._boxhead] == ["b", "a", "d", "c"]


@pytest.mark.parametrize("df_lib, columns", [(pd, "c"), (pl, cs.starts_with("c"))])
def test_cols_move_to_start_single_col(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_move_to_start(src_gt, columns=columns)
assert [col.var for col in new_gt._boxhead] == ["c", "a", "b", "d"]


@pytest.mark.parametrize(
"df_lib, columns", [(pd, ["c", "d"]), (pl, cs.starts_with("c") | cs.ends_with("d"))]
)
def test_cols_move_to_start_multi_cols(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_move_to_start(src_gt, columns=columns)
assert [col.var for col in new_gt._boxhead] == ["c", "d", "a", "b"]


@pytest.mark.parametrize("df_lib, columns", [(pd, "c"), (pl, cs.starts_with("c"))])
def test_cols_move_to_end_single_col(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_move_to_end(src_gt, columns=columns)
assert [col.var for col in new_gt._boxhead] == ["a", "b", "d", "c"]


@pytest.mark.parametrize(
"df_lib, columns", [(pd, ["a", "c"]), (pl, cs.starts_with("a") | cs.ends_with("c"))]
)
def test_cols_move_to_end_multi_cols(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_move_to_end(src_gt, columns=columns)
assert [col.var for col in new_gt._boxhead] == ["b", "d", "a", "c"]


@pytest.mark.parametrize("df_lib, columns", [(pd, "c"), (pl, cs.starts_with("c"))])
def test_cols_hide_single_col(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_hide(src_gt, columns=columns)
assert [col.var for col in new_gt._boxhead if col.visible] == ["a", "b", "d"]


@pytest.mark.parametrize(
"df_lib, columns", [(pd, ["a", "d"]), (pl, cs.starts_with("a") | cs.ends_with("d"))]
)
def test_cols_hide_multi_cols(df_lib, columns):
df = getattr(df_lib, "DataFrame")({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
src_gt = GT(df)
new_gt = cols_hide(src_gt, columns=columns)
assert [col.var for col in new_gt._boxhead if col.visible] == ["b", "c"]

0 comments on commit 5ea06af

Please sign in to comment.