Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions _python_utils_tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from unittest.mock import MagicMock

import pytest

from python_utils.decorators import sample


@pytest.fixture
def random(monkeypatch):
mock = MagicMock()
monkeypatch.setattr("python_utils.decorators.random.random", mock, raising=True)
return mock


def test_sample_called(random):
demo_function = MagicMock()
decorated = sample(0.5)(demo_function)
random.return_value = 0.4
decorated()
random.return_value = 0.0
decorated()
args = [1, 2]
kwargs = {"1": 1, "2": 2}
decorated(*args, **kwargs)
demo_function.assert_called_with(*args, **kwargs)
assert demo_function.call_count == 3


def test_sample_not_called(random):
demo_function = MagicMock()
decorated = sample(0.5)(demo_function)
random.return_value = 0.5
decorated()
random.return_value = 1.0
decorated()
assert demo_function.call_count == 0
29 changes: 29 additions & 0 deletions python_utils/decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import functools
import logging
import random
from . import types


Expand Down Expand Up @@ -92,3 +94,30 @@ def __listify(*args, **kwargs):
return __listify

return _listify


def sample(sample_rate: float):
'''
Limit calls to a function based on given sample rate.
Number of calls to the function will be roughly equal to
sample_rate percentage.

Usage:

>>> @sample(0.5)
... def demo_function(*args, **kwargs):
... return 1

Calls to *demo_function* will be limited to 50% approximatly.

'''
def _sample(function):
@functools.wraps(function)
def __sample(*args, **kwargs):
if random.random() < sample_rate:
return function(*args, **kwargs)
else:
logging.debug('Skipped execution of %r(%r, %r) due to sampling', function, args, kwargs) # noqa: E501

return __sample
return _sample