Skip to content

Commit

Permalink
fix(python): Validate estimated_size parameter (#6018)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jan 3, 2023
1 parent 53d2e4c commit f1e5f6a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2273,7 +2273,7 @@ def estimated_size(self, unit: SizeUnit = "b") -> int | float:
"""
         sz = self._df.estimated_size()
-        return scale_bytes(sz, to=unit)
+        return scale_bytes(sz, unit)

def transpose(
self: DF,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ def estimated_size(self, unit: SizeUnit = "b") -> int | float:
"""
         sz = self._s.estimated_size()
-        return scale_bytes(sz, to=unit)
+        return scale_bytes(sz, unit)

def sqrt(self) -> Series:
"""
Expand Down
26 changes: 15 additions & 11 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,15 +383,19 @@ def __get__(self, instance: Any, cls: type[NS]) -> NS: # type: ignore[override]
)


-def scale_bytes(sz: int, to: SizeUnit) -> int | float:
+def scale_bytes(sz: int, unit: SizeUnit) -> int | float:
     """Scale size in bytes to other size units (eg: "kb", "mb", "gb", "tb")."""
-    scaling_factor = {
-        "b": 1,
-        "k": 1024,
-        "m": 1024**2,
-        "g": 1024**3,
-        "t": 1024**4,
-    }[to[0]]
-    if scaling_factor > 1:
-        return sz / scaling_factor
-    return sz
+    if unit in {"b", "bytes"}:
+        return sz
+    elif unit in {"kb", "kilobytes"}:
+        return sz / 1024
+    elif unit in {"mb", "megabytes"}:
+        return sz / 1024**2
+    elif unit in {"gb", "gigabytes"}:
+        return sz / 1024**3
+    elif unit in {"tb", "terabytes"}:
+        return sz / 1024**4
+    else:
+        raise ValueError(
+            f"unit must be one of {{'b', 'kb', 'mb', 'gb', 'tb'}}, got {unit!r}"
+        )
3 changes: 3 additions & 0 deletions py-polars/tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ def test_estimated_size() -> None:
     assert s.estimated_size("mb") == (df.estimated_size("kb") / 1024)
     assert s.estimated_size("gb") == (df.estimated_size("mb") / 1024)
     assert s.estimated_size("tb") == (df.estimated_size("gb") / 1024)
+
+    with pytest.raises(ValueError):
+        s.estimated_size("milkshake") # type: ignore[arg-type]

0 comments on commit f1e5f6a

Please sign in to comment.