Skip to content

Commit

Permalink
add statistics or raise
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed Jun 21, 2024
1 parent d7ecc7a commit 28cd208
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool]
[tool.poetry]
name = "slist"
version = "0.3.9"
version = "0.3.10"
homepage = "https://github.com/thejaminator/slist"
description = "A typesafe list with more method chaining!"
authors = ["James Chua <chuajamessh@gmail.com>"]
Expand Down
40 changes: 37 additions & 3 deletions slist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import concurrent.futures
from ctypes import cast
from dataclasses import dataclass
import random
import re
import statistics
Expand Down Expand Up @@ -30,6 +30,7 @@
# Needed for https://github.com/python/typing_extensions/issues/7
from typing_extensions import NamedTuple


A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")
Expand Down Expand Up @@ -66,6 +67,18 @@ def __add__(self: A, other: A, /) -> A:
CanAdd = TypeVar("CanAdd", bound=Addable)


@dataclass(frozen=True)
class AverageStats:
average: float
standard_deviation: float
upper_confidence_interval_95: float
lower_confidence_interval_95: float
count: int

def __str__(self) -> str:
return f"Average: {self.average}, SD: {self.standard_deviation}, 95% CI: ({self.lower_confidence_interval_95}, {self.upper_confidence_interval_95})"


class Comparable(Protocol):
def __lt__(self: CanCompare, other: CanCompare, /) -> bool:
...
Expand Down Expand Up @@ -662,7 +675,7 @@ def par_map(self, func: Callable[[A], B], executor: concurrent.futures.Executor)
async def par_map_async(self, func: Callable[[A], typing.Awaitable[B]]) -> Slist[B]:
"""Applies the async function to each element. Awaits for all results."""
return Slist(await asyncio.gather(*[func(item) for item in self]))

async def gather(self: Sequence[typing.Awaitable[B]]) -> Slist[B]:
"""Awaits for all results"""
return Slist(await asyncio.gather(*self))
Expand Down Expand Up @@ -706,12 +719,33 @@ def average(
def average_or_raise(
self: Sequence[Union[int, float, bool]],
) -> float:
"""Returns None when the list is empty"""
"""Raises when the list is empty"""
this = typing.cast(Slist[Union[int, float, bool]], self)
if this.length == 0:
raise ValueError("Cannot get average of empty list")
return this.sum() / this.length

def statistics_or_raise(
self: Sequence[Union[int, float, bool]],
) -> AverageStats:
"""Raises when the list is empty"""
this = typing.cast(Slist[Union[int, float, bool]], self)
if this.length == 0:
raise ValueError("Cannot get average of empty list")
average = this.average_or_raise()
standard_deviation = this.standard_deviation()
assert standard_deviation is not None
standard_error = standard_deviation / ((this.length) ** 0.5)
upper_ci = average + 1.96 * standard_error
lower_ci = average - 1.96 * standard_error
return AverageStats(
average=average,
standard_deviation=standard_deviation,
upper_confidence_interval_95=upper_ci,
lower_confidence_interval_95=lower_ci,
count=this.length,
)

def standard_deviation(self: Slist[Union[int, float]]) -> Optional[float]:
"""Returns None when the list is empty"""
return statistics.stdev(self) if self.length > 0 else None
Expand Down
16 changes: 16 additions & 0 deletions tests/test_slist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from slist import Slist, identity
import numpy as np

Check failure on line 4 in tests/test_slist.py

View workflow job for this annotation

GitHub Actions / test (3.8, ubuntu-latest)

Import "numpy" could not be resolved (reportMissingImports)

Check failure on line 4 in tests/test_slist.py

View workflow job for this annotation

GitHub Actions / test (3.9, ubuntu-latest)

Import "numpy" could not be resolved (reportMissingImports)

Check failure on line 4 in tests/test_slist.py

View workflow job for this annotation

GitHub Actions / test (3.10, ubuntu-latest)

Import "numpy" could not be resolved (reportMissingImports)

Check failure on line 4 in tests/test_slist.py

View workflow job for this annotation

GitHub Actions / test (3.9, macos-latest)

Import "numpy" could not be resolved (reportMissingImports)


def test_split_proportion():
Expand Down Expand Up @@ -285,3 +286,18 @@ def test_product():
(5, 5),
]
)


def test_statistics_or_raise():
numbers = Slist([1, 2, 3, 4, 5])
results = numbers.statistics_or_raise()
assert results.average == 3
assert results.count == 5

# convert the above to use numpy roughly equal
assert np.isclose(results.upper_confidence_interval_95, 4.38, atol=0.01)
assert np.isclose(results.lower_confidence_interval_95, 1.61, atol=0.01)

empty = Slist([])
with pytest.raises(ValueError):
empty.statistics_or_raise()

0 comments on commit 28cd208

Please sign in to comment.