Skip to content

Commit

Permalink
entropy normalization arg (#3369)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 11, 2022
1 parent 7166080 commit ca45888
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 18 deletions.
15 changes: 10 additions & 5 deletions polars/polars-core/src/series/ops/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,19 @@ impl Series {
/// Compute the entropy as `-sum(pk * log(pk))`,
/// where `pk` are discrete probabilities.
#[cfg_attr(docsrs, doc(cfg(feature = "log")))]
pub fn entropy(&self, base: f64) -> Option<f64> {
pub fn entropy(&self, base: f64, normalize: bool) -> Option<f64> {
match self.dtype() {
DataType::Float32 | DataType::Float64 => {
let pk = self;
let sum = pk.sum_as_series();

let pk = if sum.get(0).extract::<f64>()? != 1.0 {
pk / &sum
let pk = if normalize {
let sum = pk.sum_as_series();

if sum.get(0).extract::<f64>()? != 1.0 {
pk / &sum
} else {
pk.clone()
}
} else {
pk.clone()
};
Expand All @@ -44,7 +49,7 @@ impl Series {
_ => self
.cast(&DataType::Float64)
.ok()
.and_then(|s| s.entropy(base)),
.and_then(|s| s.entropy(base, normalize)),
}
}
}
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1810,9 +1810,9 @@ impl Expr {
#[cfg_attr(docsrs, doc(cfg(feature = "log")))]
/// Compute the entropy as `-sum(pk * log(pk))`,
/// where `pk` are discrete probabilities.
pub fn entropy(self, base: f64) -> Self {
pub fn entropy(self, base: f64, normalize: bool) -> Self {
self.apply(
move |s| Ok(Series::new(s.name(), [s.entropy(base)])),
move |s| Ok(Series::new(s.name(), [s.entropy(base, normalize)])),
GetOutput::map_dtype(|dt| {
if matches!(dt, DataType::Float32) {
DataType::Float32
Expand Down
9 changes: 5 additions & 4 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2884,19 +2884,20 @@ def log(self, base: float = math.e) -> "Expr":
"""
return wrap_expr(self._pyexpr.log(base))

def entropy(self, base: float = math.e) -> "Expr":
def entropy(self, base: float = math.e, normalize: bool = False) -> "Expr":
"""
Compute the entropy as `-sum(pk * log(pk))`,
where `pk` are discrete probabilities.
This routine will normalize pk if they don’t sum to 1.
Parameters
----------
base
Given base, defaults to `e`
normalize
Normalize pk if it doesn't sum to 1.
"""
return wrap_expr(self._pyexpr.entropy(base))
return wrap_expr(self._pyexpr.entropy(base, normalize))

# Below are the namespaces defined. Keep these at the end of the definition of Expr, as to not confuse mypy with
# the type annotation `str` with the namespace "str"
Expand Down
15 changes: 11 additions & 4 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,25 +947,32 @@ def unique_counts(self) -> "Series":
"""
return pli.select(pli.lit(self).unique_counts()).to_series()

def entropy(self, base: float = math.e) -> Optional[float]:
def entropy(self, base: float = math.e, normalize: bool = False) -> Optional[float]:
"""
Compute the entropy as `-sum(pk * log(pk))`,
where `pk` are discrete probabilities.
This routine will normalize pk if they don’t sum to 1.
Parameters
----------
base
Given base, defaults to `e`
normalize
Normalize pk if it doesn't sum to 1.
Examples
--------
>>> a = pl.Series([0.99, 0.005, 0.005])
>>> a.entropy()
>>> a.entropy(normalize=True)
0.06293300616044681
>>> b = pl.Series([0.65, 0.10, 0.25])
>>> b.entropy()
>>> b.entropy(normalize=True)
0.8568409950394724
"""
return pli.select(pli.lit(self).entropy(base)).to_series()[0]
return pli.select(pli.lit(self).entropy(base, normalize)).to_series()[0]

@property
def name(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1358,8 +1358,8 @@ impl PyExpr {
self.inner.clone().log(base).into()
}

pub fn entropy(&self, base: f64) -> Self {
self.inner.clone().entropy(base).into()
pub fn entropy(&self, base: f64, normalize: bool) -> Self {
self.inner.clone().entropy(base, normalize).into()
}
pub fn hash(&self, seed: usize) -> Self {
self.inner.clone().hash(seed).into()
Expand Down
4 changes: 3 additions & 1 deletion py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def test_entropy() -> None:
)

assert (
df.groupby("group", maintain_order=True).agg(pl.col("id").entropy())
df.groupby("group", maintain_order=True).agg(
pl.col("id").entropy(normalize=True)
)
).frame_equal(
pl.DataFrame(
{"group": ["A", "B"], "id": [1.0397207708399179, 1.371381017771811]}
Expand Down

0 comments on commit ca45888

Please sign in to comment.