From 1b23d0f71f1edd993d532a2a048909b70fe1a994 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Fri, 12 Apr 2019 08:44:39 -0300 Subject: [PATCH 01/38] full refactor --- setup.cfg | 15 +- setup.py | 14 +- splitio/__init__.py | 55 +- splitio/api.py | 263 --- splitio/api/__init__.py | 21 + splitio/api/client.py | 141 ++ splitio/api/events.py | 73 + splitio/api/impressions.py | 82 + splitio/api/segments.py | 51 + splitio/api/splits.py | 48 + splitio/api/telemetry.py | 132 ++ splitio/bin/synchronizer.py | 133 -- splitio/brokers.py | 848 ---------- splitio/cache.py | 339 ---- splitio/{bin => client}/__init__.py | 0 splitio/{clients.py => client/client.py} | 172 +- splitio/client/config.py | 43 + splitio/client/factory.py | 361 +++++ splitio/{ => client}/input_validator.py | 47 +- splitio/client/key.py | 22 + splitio/client/listener.py | 68 + splitio/client/localhost.py | 207 +++ splitio/client/manager.py | 68 + splitio/client/util.py | 34 + splitio/config.py | 133 -- splitio/{tests => engine}/__init__.py | 0 splitio/{ => engine}/evaluator.py | 62 +- splitio/{ => engine}/hashfns/__init__.py | 16 +- splitio/{ => engine}/hashfns/legacy.py | 13 +- splitio/engine/hashfns/murmur3py.py | 76 + splitio/{ => engine}/splitters.py | 29 +- splitio/events.py | 79 - splitio/exceptions.py | 4 + splitio/factories.py | 121 -- splitio/hashfns/murmur3py.py | 70 - splitio/impressions.py | 516 +----- splitio/key.py | 23 +- splitio/managers.py | 374 ----- splitio/matchers.py | 777 --------- splitio/metrics.py | 522 ------ .../{update_scripts => models}/__init__.py | 0 splitio/models/datatypes.py | 61 + splitio/models/events.py | 17 + splitio/models/grammar/__init__.py | 0 splitio/models/grammar/condition.py | 135 ++ splitio/models/grammar/matchers/__init__.py | 70 + splitio/models/grammar/matchers/base.py | 122 ++ splitio/models/grammar/matchers/keys.py | 98 ++ splitio/models/grammar/matchers/misc.py | 102 ++ splitio/models/grammar/matchers/numeric.py | 256 +++ splitio/models/grammar/matchers/sets.py | 208 +++ splitio/models/grammar/matchers/string.py | 275 ++++ splitio/models/grammar/partitions.py | 58 + splitio/models/impressions.py | 48 + splitio/models/segments.py | 86 + splitio/models/splits.py | 223 +++ splitio/models/telemetry.py | 27 + splitio/prefix_decorator.py | 137 -- splitio/redis_support.py | 937 ----------- splitio/segments.py | 458 ------ splitio/splits.py | 1197 -------------- splitio/storage/__init__.py | 293 ++++ splitio/storage/adapters/__init__.py | 0 splitio/storage/adapters/redis.py | 410 +++++ splitio/storage/adapters/uwsgi_cache.py | 134 ++ splitio/storage/inmemmory.py | 395 +++++ splitio/storage/redis.py | 535 ++++++ splitio/storage/uwsgi.py | 583 +++++++ splitio/tasks.py | 273 ---- splitio/tasks/__init__.py | 30 + splitio/tasks/events_sync.py | 103 ++ splitio/tasks/impressions_sync.py | 104 ++ splitio/tasks/segment_sync.py | 106 ++ splitio/tasks/split_sync.py | 82 + splitio/tasks/telemetry_sync.py | 70 + splitio/tasks/util/__init__.py | 0 splitio/{ => tasks/util}/asynctask.py | 115 +- splitio/tasks/util/workerpool.py | 121 ++ splitio/tasks/uwsgi_wrappers.py | 186 +++ splitio/tests/algoSplits.json | 264 --- splitio/tests/segmentChanges.json | 1 - splitio/tests/splitChanges.json | 1440 ----------------- splitio/tests/splitChangesReadOnly.json | 46 - .../tests/splitCustomImpressionListener.json | 46 - splitio/tests/splitGetTreatments.json | 46 - splitio/tests/test_api.py | 492 ------ splitio/tests/test_cache.py | 147 -- splitio/tests/test_clients.py | 1284 --------------- 
splitio/tests/test_events.py | 31 - splitio/tests/test_factories.py | 61 - splitio/tests/test_get_treatments.py | 105 -- splitio/tests/test_impression_listener.py | 196 --- splitio/tests/test_impressions.py | 691 -------- splitio/tests/test_input_validator.py | 776 --------- splitio/tests/test_matchers.py | 1125 ------------- splitio/tests/test_metrics.py | 721 --------- splitio/tests/test_prefix_decorator.py | 133 -- splitio/tests/test_redis_cache.py | 106 -- splitio/tests/test_redis_support.py | 432 ----- splitio/tests/test_segments.py | 396 ----- splitio/tests/test_splits.py | 1123 ------------- splitio/tests/test_splitters.py | 224 --- splitio/tests/test_tasks.py | 370 ----- splitio/tests/test_transformers.py | 221 --- splitio/tests/test_uwsgi.py | 334 ---- splitio/tests/utils.py | 64 - splitio/transformers.py | 135 -- splitio/treatments.py | 5 - splitio/update_scripts/post_impressions.py | 40 - splitio/update_scripts/post_metrics.py | 40 - splitio/update_scripts/update_segments.py | 42 - splitio/update_scripts/update_splits.py | 44 - splitio/utils.py | 28 - splitio/uwsgi.py | 1053 ------------ splitio/version.py | 2 +- tests/api/test_events.py | 57 + tests/api/test_httpclient.py | 139 ++ tests/api/test_impressions_api.py | 63 + tests/api/test_segments_api.py | 27 + tests/api/test_splits_api.py | 27 + tests/api/test_telemetry.py | 118 ++ tests/client/test_client.py | 189 +++ tests/client/test_factory.py | 324 ++++ tests/client/test_input_validator.py | 597 +++++++ .../engine/files}/murmur3-custom-uuids.csv | 0 ...rmur3-sample-data-non-alpha-numeric-v2.csv | 0 .../engine/files}/murmur3-sample-data-v2.csv | 0 .../sample-data-non-alpha-numeric.jsonl | 0 .../engine/files}/sample-data.jsonl | 0 tests/engine/test_evaluator.py | 122 ++ tests/engine/test_hashfns.py | 107 ++ tests/engine/test_splitter.py | 51 + .../models/grammar/files}/regex.txt | 0 tests/models/grammar/test_conditions.py | 78 + tests/models/grammar/test_matchers.py | 886 ++++++++++ tests/models/grammar/test_partitions.py | 24 + tests/models/test_splits.py | 112 ++ tests/storage/adapters/test_redis_adapter.py | 171 ++ tests/storage/test_inmemory_storage.py | 268 +++ tests/storage/test_redis.py | 300 ++++ tests/storage/test_uwsgi.py | 244 +++ tests/tasks/test_events_sync.py | 40 + tests/tasks/test_impressions_sync.py | 38 + tests/tasks/test_segment_sync.py | 92 ++ tests/tasks/test_split_sync.py | 119 ++ tests/tasks/test_telemetry_sync.py | 57 + tests/tasks/util/test_asynctask.py | 118 ++ tests/tasks/util/test_workerpool.py | 53 + 148 files changed, 11098 insertions(+), 19694 deletions(-) delete mode 100644 splitio/api.py create mode 100644 splitio/api/__init__.py create mode 100644 splitio/api/client.py create mode 100644 splitio/api/events.py create mode 100644 splitio/api/impressions.py create mode 100644 splitio/api/segments.py create mode 100644 splitio/api/splits.py create mode 100644 splitio/api/telemetry.py delete mode 100644 splitio/bin/synchronizer.py delete mode 100644 splitio/brokers.py delete mode 100644 splitio/cache.py rename splitio/{bin => client}/__init__.py (100%) rename splitio/{clients.py => client/client.py} (62%) create mode 100644 splitio/client/config.py create mode 100644 splitio/client/factory.py rename splitio/{ => client}/input_validator.py (93%) create mode 100644 splitio/client/key.py create mode 100644 splitio/client/listener.py create mode 100644 splitio/client/localhost.py create mode 100644 splitio/client/manager.py create mode 100644 splitio/client/util.py delete mode 100644 splitio/config.py 
rename splitio/{tests => engine}/__init__.py (100%) rename splitio/{ => engine}/evaluator.py (66%) rename splitio/{ => engine}/hashfns/__init__.py (67%) rename splitio/{ => engine}/hashfns/legacy.py (57%) create mode 100644 splitio/engine/hashfns/murmur3py.py rename splitio/{ => engine}/splitters.py (68%) delete mode 100644 splitio/events.py delete mode 100644 splitio/factories.py delete mode 100644 splitio/hashfns/murmur3py.py delete mode 100644 splitio/managers.py delete mode 100644 splitio/matchers.py delete mode 100644 splitio/metrics.py rename splitio/{update_scripts => models}/__init__.py (100%) create mode 100644 splitio/models/datatypes.py create mode 100644 splitio/models/events.py create mode 100644 splitio/models/grammar/__init__.py create mode 100644 splitio/models/grammar/condition.py create mode 100644 splitio/models/grammar/matchers/__init__.py create mode 100644 splitio/models/grammar/matchers/base.py create mode 100644 splitio/models/grammar/matchers/keys.py create mode 100644 splitio/models/grammar/matchers/misc.py create mode 100644 splitio/models/grammar/matchers/numeric.py create mode 100644 splitio/models/grammar/matchers/sets.py create mode 100644 splitio/models/grammar/matchers/string.py create mode 100644 splitio/models/grammar/partitions.py create mode 100644 splitio/models/impressions.py create mode 100644 splitio/models/segments.py create mode 100644 splitio/models/splits.py create mode 100644 splitio/models/telemetry.py delete mode 100644 splitio/prefix_decorator.py delete mode 100644 splitio/redis_support.py delete mode 100644 splitio/segments.py delete mode 100644 splitio/splits.py create mode 100644 splitio/storage/__init__.py create mode 100644 splitio/storage/adapters/__init__.py create mode 100644 splitio/storage/adapters/redis.py create mode 100644 splitio/storage/adapters/uwsgi_cache.py create mode 100644 splitio/storage/inmemmory.py create mode 100644 splitio/storage/redis.py create mode 100644 splitio/storage/uwsgi.py delete mode 100644 splitio/tasks.py create mode 100644 splitio/tasks/__init__.py create mode 100644 splitio/tasks/events_sync.py create mode 100644 splitio/tasks/impressions_sync.py create mode 100644 splitio/tasks/segment_sync.py create mode 100644 splitio/tasks/split_sync.py create mode 100644 splitio/tasks/telemetry_sync.py create mode 100644 splitio/tasks/util/__init__.py rename splitio/{ => tasks/util}/asynctask.py (54%) create mode 100644 splitio/tasks/util/workerpool.py create mode 100644 splitio/tasks/uwsgi_wrappers.py delete mode 100644 splitio/tests/algoSplits.json delete mode 100644 splitio/tests/segmentChanges.json delete mode 100644 splitio/tests/splitChanges.json delete mode 100644 splitio/tests/splitChangesReadOnly.json delete mode 100644 splitio/tests/splitCustomImpressionListener.json delete mode 100644 splitio/tests/splitGetTreatments.json delete mode 100644 splitio/tests/test_api.py delete mode 100644 splitio/tests/test_cache.py delete mode 100644 splitio/tests/test_clients.py delete mode 100644 splitio/tests/test_events.py delete mode 100644 splitio/tests/test_factories.py delete mode 100644 splitio/tests/test_get_treatments.py delete mode 100644 splitio/tests/test_impression_listener.py delete mode 100644 splitio/tests/test_impressions.py delete mode 100644 splitio/tests/test_input_validator.py delete mode 100644 splitio/tests/test_matchers.py delete mode 100644 splitio/tests/test_metrics.py delete mode 100644 splitio/tests/test_prefix_decorator.py delete mode 100644 splitio/tests/test_redis_cache.py delete mode 
100644 splitio/tests/test_redis_support.py delete mode 100644 splitio/tests/test_segments.py delete mode 100644 splitio/tests/test_splits.py delete mode 100644 splitio/tests/test_splitters.py delete mode 100644 splitio/tests/test_tasks.py delete mode 100644 splitio/tests/test_transformers.py delete mode 100644 splitio/tests/test_uwsgi.py delete mode 100644 splitio/tests/utils.py delete mode 100644 splitio/transformers.py delete mode 100644 splitio/treatments.py delete mode 100644 splitio/update_scripts/post_impressions.py delete mode 100644 splitio/update_scripts/post_metrics.py delete mode 100644 splitio/update_scripts/update_segments.py delete mode 100644 splitio/update_scripts/update_splits.py delete mode 100644 splitio/utils.py delete mode 100644 splitio/uwsgi.py create mode 100644 tests/api/test_events.py create mode 100644 tests/api/test_httpclient.py create mode 100644 tests/api/test_impressions_api.py create mode 100644 tests/api/test_segments_api.py create mode 100644 tests/api/test_splits_api.py create mode 100644 tests/api/test_telemetry.py create mode 100644 tests/client/test_client.py create mode 100644 tests/client/test_factory.py create mode 100644 tests/client/test_input_validator.py rename {splitio/tests => tests/engine/files}/murmur3-custom-uuids.csv (100%) rename {splitio/tests => tests/engine/files}/murmur3-sample-data-non-alpha-numeric-v2.csv (100%) rename {splitio/tests => tests/engine/files}/murmur3-sample-data-v2.csv (100%) rename {splitio/tests => tests/engine/files}/sample-data-non-alpha-numeric.jsonl (100%) rename {splitio/tests => tests/engine/files}/sample-data.jsonl (100%) rename {splitio/tests => tests/models/grammar/files}/regex.txt (100%) create mode 100644 tests/models/grammar/test_conditions.py create mode 100644 tests/models/grammar/test_matchers.py create mode 100644 tests/models/grammar/test_partitions.py create mode 100644 tests/models/test_splits.py create mode 100644 tests/storage/adapters/test_redis_adapter.py create mode 100644 tests/storage/test_inmemory_storage.py create mode 100644 tests/storage/test_redis.py create mode 100644 tests/storage/test_uwsgi.py create mode 100644 tests/tasks/test_events_sync.py create mode 100644 tests/tasks/test_impressions_sync.py create mode 100644 tests/tasks/test_segment_sync.py create mode 100644 tests/tasks/test_split_sync.py create mode 100644 tests/tasks/test_telemetry_sync.py create mode 100644 tests/tasks/util/test_asynctask.py create mode 100644 tests/tasks/util/test_workerpool.py diff --git a/setup.cfg b/setup.cfg index acc2df52..77aa70d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,12 +8,13 @@ description-file = README.md max-line-length=100 exclude=tests/* -[nosetests] -verbosity=1 -detailed-errors=1 -with-coverage=1 -cover-package=splitio -debug=nose.loader +[aliases] +test=pytest + +[tool:pytest] +ignore_glob=./splitio/_OLD/* +addopts = --cov=splitio +python_classes=*Tests [build_sphinx] source-dir = doc/source build-dir = doc/build all_files = 1 [upload_sphinx] -upload-dir = doc/build/html \ No newline at end of file +upload-dir = doc/build/html diff --git a/setup.py b/setup.py index 39bd0daf..0d497308 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,12 @@ #!/usr/bin/env python +"""Setup module.""" -from setuptools import setup +from setuptools import setup, find_packages from os import path from sys import version_info -tests_require = 
['flake8', 'nose', 'coverage'] +tests_require = ['flake8', 'pytest', 'pytest-mock', 'coverage', 'pytest-cov'] install_requires = [ - 'arrow>=0.7.0', 'requests>=2.9.1', 'future>=0.15.2', 'docopt>=0.6.2', @@ -33,11 +33,11 @@ tests_require=tests_require, extras_require={ 'test': tests_require, - 'redis': ['redis>=2.10.5', 'jsonpickle>=0.9.3'], - 'uwsgi': ['uwsgi>=2.0.0', 'jsonpickle>=0.9.3'], + 'redis': ['redis>=2.10.5'], + 'uwsgi': ['uwsgi>=2.0.0'], 'cpphash': ['mmh3cffi>=0.1.4'] }, - setup_requires=['nose'], + setup_requires=['pytest-runner'], classifiers=[ 'Development Status :: 3 - Alpha', 'Environment :: Console', @@ -47,4 +47,4 @@ 'Programming Language :: Python :: 3', 'Topic :: Software Development :: Libraries' ], - packages=['splitio', 'splitio.update_scripts', 'splitio.bin', 'splitio.hashfns']) + packages=find_packages()) diff --git a/splitio/__init__.py b/splitio/__init__.py index 0dcdeef2..ea266440 100644 --- a/splitio/__init__.py +++ b/splitio/__init__.py @@ -1,31 +1,32 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals -from .factories import get_factory # noqa -from .key import Key # noqa -from .version import __version__ # noqa - -__all__ = ('api', 'brokers', 'cache', 'clients', 'matchers', 'segments', - 'settings', 'splits', 'splitters', 'transformers', 'treatments', - 'version', 'factories', 'manager') - - -# Functions defined to maintain compatibility with previous sdk versions. -# ====================================================================== +from splitio.client.factory import get_factory +#from .factories import get_factory # noqa +#from .key import Key # noqa +#from .version import __version__ # noqa # -# This functions are not supposed to be used directly, factory method should be -# called instead, but since they were previously exposed, they're re-added here -# as helper function so that if someone was using we don't break their code. - -def get_client(apikey, **kwargs): - from .clients import Client - from .brokers import get_self_refreshing_broker - broker = get_self_refreshing_broker(apikey, **kwargs) - return Client(broker) - - -def get_redis_client(apikey, **kwargs): - from .clients import Client - from .brokers import get_redis_broker - broker = get_redis_broker(apikey, **kwargs) - return Client(broker) +#__all__ = ('api', 'brokers', 'cache', 'clients', 'matchers', 'segments', +# 'settings', 'splits', 'splitters', 'transformers', 'treatments', +# 'version', 'factories', 'manager') +# +# +## Functions defined to maintain compatibility with previous sdk versions. +## ====================================================================== +## +## These functions are not supposed to be used directly; the factory method should be +## called instead, but since they were previously exposed, they're re-added here +## as helper functions so that if someone was using them we don't break their code. +# +#def get_client(apikey, **kwargs): +# from .clients import Client +# from .brokers import get_self_refreshing_broker +# broker = get_self_refreshing_broker(apikey, **kwargs) +# return Client(broker) +# +# +#def get_redis_client(apikey, **kwargs): +# from .clients import Client +# from .brokers import get_redis_broker +# broker = get_redis_broker(apikey, **kwargs) +# return Client(broker)
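A minimal usage sketch of the new public entry point imported above. Only the get_factory import is shown in this patch; the config key and the client methods below are assumptions based on the rest of the refactor and may differ:

    from splitio import get_factory

    # 'SOME_API_KEY' is a placeholder; the config keys are assumed, not taken from this hunk.
    factory = get_factory('SOME_API_KEY', config={'impressionsRefreshRate': 60})
    client = factory.client()
    treatment = client.get_treatment('some_user_key', 'some_feature')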
diff --git a/splitio/api.py b/splitio/api.py deleted file mode 100644 index c71c4df9..00000000 --- a/splitio/api.py +++ /dev/null @@ -1,263 +0,0 @@ -"""Provides access to the Split.io SDK API""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging -import requests -import json - -from splitio.config import SDK_API_BASE_URL, EVENTS_API_BASE_URL, SDK_VERSION - -_SEGMENT_CHANGES_URL_TEMPLATE = '{base_url}/segmentChanges/{segment_name}/' -_SPLIT_CHANGES_URL_TEMPLATE = '{base_url}/splitChanges/' -_TEST_IMPRESSIONS_URL_TEMPLATE = '{base_url}/testImpressions/bulk/' -_METRICS_URL_TEMPLATE = '{base_url}/metrics/{endpoint}/' -_EVENTS_URL_TEMPLATE = '{base_url}/events/bulk/' - - -class SdkApi(object): - """ - SDK API Class contains all methods required to interact with - the split backend. - """ - def __init__(self, api_key, sdk_api_base_url=None, events_api_base_url=None, - split_sdk_machine_name=None, split_sdk_machine_ip=None, connect_timeout=1500, - read_timeout=1000): - """Provides access to the Split.io SDK RESTful API - - :param api_key: The API key generated on the admin interface - :type api_key: str - :param sdk_api_base_url: An optional string used to override the default API base url. - Useful for testing or to change the target environment. - :type sdk_api_base_url: str - :param events_api_base_url: An optional string used to override the default events API base - url. Useful for testing or to change the target environment. - :type events_api_base_url: str - :param split_sdk_machine_name: An optional value for the SplitSDKMachineName header. It can - be a function instead of a string if it has to be evaluated - at request time - :type split_sdk_machine_name: str - :param split_sdk_machine_ip: An optional value for the SplitSDKMachineIP header. It can be - a function instead of a string if it has to be evaluated at - request time - :type split_sdk_machine_ip: str - :param connect_timeout: The TCP connection timeout. Default: 1500 (seconds) - :type connect_timeout: float - :param read_timeout: The HTTP read timeout. Default: 1000 (seconds) - :type read_timeout: float - """ - self._logger = logging.getLogger(self.__class__.__name__) - self._api_key = api_key - self._sdk_api_url_base = sdk_api_base_url if sdk_api_base_url is not None \ - else SDK_API_BASE_URL - self._events_api_url_base = events_api_base_url if events_api_base_url is not None \ - else EVENTS_API_BASE_URL - self._split_sdk_machine_name = split_sdk_machine_name - self._split_sdk_machine_ip = split_sdk_machine_ip - self._timeout = (connect_timeout, read_timeout) - - def _build_headers(self): - """Builds a dictionary with the standard headers used in all calls the Split.IO SDK. - - The mandatory headers are: - - * Authorization: [String] the api token bearer for the customer using the sdk. - Example: "Authorization: Bearer <API_KEY>". - * SplitSDKVersion: [String] of the shape <LANGUAGE>-<VERSION>. For example, for python sdk - is it "python-0.0.1". - * Accept-Encoding: gzip - - Optionally, the following headers can be included - - * SplitSDKMachineName: [String] name of the machine. 
- * SplitSDKMachineIP: [String] IP address of the machine - - :return: A dictionary with the headers used in every call to the backend using the values - set during initialization or the defaults set by the settings module. - :rtype: dict - """ - headers = { - 'Authorization': 'Bearer {api_key}'.format(api_key=self._api_key), - 'SplitSDKVersion': SDK_VERSION, - 'Accept-Encoding': 'gzip' - } - - if self._split_sdk_machine_name is not None: - headers['SplitSDKMachineName'] = self._split_sdk_machine_name() \ - if callable(self._split_sdk_machine_name) else self._split_sdk_machine_name - - if self._split_sdk_machine_ip is not None: - headers['SplitSDKMachineIP'] = self._split_sdk_machine_ip() \ - if callable(self._split_sdk_machine_ip) else self._split_sdk_machine_ip - - return headers - - def _logHttpError(self, response): - if response.status_code < requests.codes.ok or response.status_code >= requests.codes.bad: - respJson = response.json() - if 'message' in respJson: - self._logger.error( - "HTTP Error (status: %s) connecting with split servers: %s" - % (response.status_code, respJson['message']) - ) - else: - self._logger.error("HTTP Error connecting with split servers") - - def _get(self, url, params): - headers = self._build_headers() - - response = requests.get(url, params=params, headers=headers, timeout=self._timeout) - self._logHttpError(response) - - return response.json() - - def _post(self, url, data): - headers = self._build_headers() - - response = requests.post(url, json=data, headers=headers, timeout=self._timeout) - self._logHttpError(response) - - return response.status_code - - def split_changes(self, since): - """Makes a request to the splitChanges endpoint. - :param since: An integer that indicates when was the endpoint last called. It is usually - either -1, which returns the value of the split, or the value of the field - "till" of the response of a previous call, which will only return the changes - since that call. - :type since: int - :return: Changes seen on splits - :rtype: dict - """ - url = _SPLIT_CHANGES_URL_TEMPLATE.format(base_url=self._sdk_api_url_base) - params = { - 'since': since - } - - return self._get(url, params) - - def segment_changes(self, segment_name, since): - """Makes a request to the segmentChanges endpoint. - :param segment_name: Name of the segment - :type since: str - :param since: An integer that indicates when was the endpoint last called. It is usually - either -1, which returns the value of the split, or the value of the field - "till" of the response of a previous call, which will only return the changes - since that call. - :type since: int - :return: Changes seen on segments - :rtype: dict - """ - url = _SEGMENT_CHANGES_URL_TEMPLATE.format(base_url=self._sdk_api_url_base, - segment_name=segment_name) - params = { - 'since': since - } - - return self._get(url, params) - - def test_impressions(self, test_impressions_data): - """Makes a request to the testImpressions endpoint. The method takes a dictionary with the - test (feature) name and a list of impressions: - - [ - { - "testName": str, # name of the test (feature), - "impressions": [ - { - "keyName" : str, # name of the key that saw this feature - "treatment" : str, # the treatment e.g. "on" or "off" - "time" : int # the timestamp (in ms) when this happened. - }, - ... 
- ] - } - ] - - :param test_impressions_data: Data of the impressions of a test (feature) - :type test_impressions_data: list - """ - url = _TEST_IMPRESSIONS_URL_TEMPLATE.format(base_url=self._events_api_url_base) - return self._post(url, test_impressions_data) - - def metrics_times(self, times_data): - """Makes a request to the times metrics endpoint. The method takes a list of dictionaries - with the latencies seen for each metric: - - [ - { - "name": str, # name of the metric - "latencies": [int, int, int, ...] # latencies seen - }, - { - "name": str, # name of the metric - "latencies": [int, int, int, ...] # latencies seen - }, - ... - ] - :param times_data: Data for the metrics times - :type times_data: list - """ - url = _METRICS_URL_TEMPLATE.format(base_url=self._events_api_url_base, endpoint='times') - return self._post(url, times_data) - - def metrics_counters(self, counters_data): - """Makes a request to the counters metrics endpoint. The method takes a list of dictionaries - with the deltas for the counts for each metric: - - [ - { - "name": str, # name of the metric - "delta": int # count delta - }, - { - "name": str, # name of the metric - "delta": int # count delta - }, - ... - ] - :param counters_data: Data for the metrics counters - :type counters_data: list - """ - url = _METRICS_URL_TEMPLATE.format(base_url=self._events_api_url_base, endpoint='counters') - return self._post(url, counters_data) - - def metrics_gauge(self, gauge_data): - """Makes a request to the gauge metrics endpoint. The method takes a list of dictionaries - with the values for the gauge of each metric: - - [ - { - "name": str, # name of the metric - "value": float # gauge value - }, - { - "name": str, # name of the metric - "value": float # gauge value - }, - ... - ] - :param gauge_data: Data for the metrics gauge - :type gauge_data: list - """ - url = _METRICS_URL_TEMPLATE.format(base_url=self._events_api_url_base, endpoint='gauge') - return self._post(url, gauge_data) - - def track_events(self, events): - url = _EVENTS_URL_TEMPLATE.format(base_url=self._events_api_url_base) - return self._post(url, events) - - -def api_factory(config): - """Build a split.io SDK API client using a config dictionary. 
- :param config: A config dictionary - :type config: dict - :return: SdkApi client - :rtype: SdkApi - """ - return SdkApi(config.get('apiKey'), - sdk_api_base_url=config['sdkApiBaseUrl'], - events_api_base_url=config['eventsApiBaseUrl'], - split_sdk_machine_name=config['splitSdkMachineName'], - split_sdk_machine_ip=config['splitSdkMachineIp'], - connect_timeout=config['connectionTimeout'], - read_timeout=config['readTimeout']) diff --git a/splitio/api/__init__.py b/splitio/api/__init__.py new file mode 100644 index 00000000..38a7d568 --- /dev/null +++ b/splitio/api/__init__.py @@ -0,0 +1,21 @@ +"""Split API module.""" + +class APIException(Exception): + """Exception to raise when an API call fails.""" + + def __init__(self, custom_message, status_code=None, original_exception=None): + """Constructor.""" + Exception.__init__(self, custom_message) + self._status_code = status_code if status_code else -1 + self._custom_message = custom_message + self._original_exception = original_exception + + @property + def status_code(self): + """Return HTTP status code.""" + return self._status_code + + @property + def custom_message(self): + """Return custom message.""" + return self._custom_message diff --git a/splitio/api/client.py b/splitio/api/client.py new file mode 100644 index 00000000..3a9e7ca8 --- /dev/null +++ b/splitio/api/client.py @@ -0,0 +1,141 @@ +"""Synchronous HTTP Client for split API.""" +from __future__ import division + +from collections import namedtuple +import requests + +HttpResponse = namedtuple('HttpResponse', ['status_code', 'body']) + + +class HttpClientException(Exception): + """HTTP Client exception.""" + + def __init__(self, custom_message, original_exception=None): + """ + Class constructor. + + :param custom_message: Information on why this exception happened. + :type custom_message: str + :param original_exception: Original exception being caught if any. + :type original_exception: Exception. + """ + Exception.__init__(self, custom_message) + self._custom_message = custom_message + self._original_exception = original_exception + + @property + def custom_message(self): + """Return custom message.""" + return self._custom_message + + @property + def original_exception(self): + """Return original exception.""" + return self._original_exception + + +class HttpClient(object): + """HttpClient wrapper.""" + + SDK_URL = 'https://split.io/api' + EVENTS_URL = 'https://split.io/api' + + def __init__(self, timeout=None, sdk_url=None, events_url=None): + """ + Class constructor. + + :param sdk_url: Optional alternative sdk URL. + :type sdk_url: str + :param events_url: Optional alternative events URL. + :type events_url: str + :param timeout: How many milliseconds to wait until the server responds. + :type timeout: int + """ + self._timeout = timeout / 1000 if timeout else None # Convert ms to seconds. + self._urls = { + 'sdk': sdk_url if sdk_url is not None else self.SDK_URL, + 'events': events_url if events_url is not None else self.EVENTS_URL, + } + + def _build_url(self, server, path): + """ + Build URL according to server specified. + + :param server: Server for which the request is being made. + :type server: str + :param path: URL path to be appended to base host. + :type path: str + + :return: A fully qualified URL. + :rtype: str + """ + return self._urls[server] + path + + def get(self, server, path, apikey, query=None, extra_headers=None): #pylint: disable=too-many-arguments + """ + Issue a GET request. + + :param server: Destination server: 'sdk' or 'events'. + :type server: str + :param path: path to append to the host url. + :type path: str + :param apikey: api token. + :type apikey: str + :param query: Query string parameters. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + headers = { + 'Content-Type': 'application/json', + 'Authorization': "Bearer %s" % apikey + } + + if extra_headers is not None: + headers.update(extra_headers) + + try: + response = requests.get( + self._build_url(server, path), + params=query, + headers=headers, + timeout=self._timeout + ) + return HttpResponse(response.status_code, response.text) + except Exception as exc: + raise HttpClientException('requests library is throwing exceptions', exc) + + def post(self, server, path, apikey, body, query=None, extra_headers=None): #pylint: disable=too-many-arguments + """ + Issue a POST request. + + :param server: Destination server: 'sdk' or 'events'. + :type server: str + :param path: path to append to the host url. + :type path: str + :param apikey: api token. + :type apikey: str + :param body: body sent in the request. + :type body: str + :param query: Query string parameters. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. + :type extra_headers: dict + + :return: Tuple of status_code & response text + :rtype: HttpResponse + """ + headers = { + 'Content-Type': 'application/json', + 'Authorization': "Bearer %s" % apikey + } + + if extra_headers is not None: + headers.update(extra_headers) + + try: + response = requests.post( + self._build_url(server, path), + json=body, + params=query, + headers=headers, + timeout=self._timeout + ) + return HttpResponse(response.status_code, response.text) + except Exception as exc: + raise HttpClientException('requests library is throwing exceptions', exc)
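For reference, a minimal sketch of how this client is meant to be driven; the api key, timeout and endpoint are placeholder values, not defaults shipped with the SDK:

    from splitio.api.client import HttpClient, HttpClientException

    client = HttpClient(timeout=1500)  # milliseconds; converted to seconds internally
    try:
        response = client.get('sdk', '/splitChanges', 'SOME_API_KEY', query={'since': -1})
        if 200 <= response.status_code < 300:
            print(response.body)
    except HttpClientException as exc:
        print('request failed:', exc.original_exception)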
diff --git a/splitio/api/events.py b/splitio/api/events.py new file mode 100644 index 00000000..e2af3d00 --- /dev/null +++ b/splitio/api/events.py @@ -0,0 +1,73 @@ +"""Events API module.""" +import logging +from splitio.api import APIException +from splitio.api.client import HttpClientException + + +class EventsAPI(object): #pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the events API.""" + + def __init__(self, http_client, apikey, sdk_metadata): + """ + Class constructor. + + :param http_client: HTTP Client responsible for issuing calls to the backend. + :type http_client: HttpClient + :param apikey: User apikey token. + :type apikey: string + :param sdk_metadata: SDK Version, IP & Machine name + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._logger = logging.getLogger(self.__class__.__name__) + self._client = http_client + self._apikey = apikey + self._metadata = { + 'SplitSDKVersion': sdk_metadata.sdk_version, + 'SplitSDKMachineIP': sdk_metadata.instance_ip, + 'SplitSDKMachineName': sdk_metadata.instance_name + } + + @staticmethod + def _build_bulk(events): + """ + Build event bulk as expected by the API. + + :param events: Events to be bundled. + :type events: list(splitio.models.events.Event) + + :return: Formatted bulk. + :rtype: list + """ + return [ + { + 'key': event.key, + 'trafficTypeName': event.traffic_type_name, + 'eventTypeId': event.event_type_id, + 'value': event.value, + 'timestamp': event.timestamp + } + for event in events + ] + + def flush_events(self, events): + """ + Send events to the backend. + + :param events: Events bulk + :type events: list + + :raises APIException: If the flush operation fails. + """ + bulk = self._build_bulk(events) + try: + response = self._client.post( + 'events', + '/events/bulk', + self._apikey, + body=bulk, + extra_headers=self._metadata + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + self._logger.debug('Error flushing events: ', exc_info=True) + raise APIException(exc.custom_message, original_exception=exc.original_exception)
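A hypothetical wiring of the EventsAPI above. The real SdkMetadata lives in splitio.client.util and the Event model in splitio.models.events; namedtuple stand-ins with the attribute names read by _build_bulk are used here so the sketch stays self-contained:

    from collections import namedtuple
    from splitio.api.client import HttpClient
    from splitio.api.events import EventsAPI

    # Stand-ins for splitio.client.util.SdkMetadata and splitio.models.events.Event.
    SdkMetadata = namedtuple('SdkMetadata', ['sdk_version', 'instance_ip', 'instance_name'])
    Event = namedtuple('Event', ['key', 'traffic_type_name', 'event_type_id', 'value', 'timestamp'])

    metadata = SdkMetadata('python-x.y.z', '10.0.0.1', 'some-machine')
    events_api = EventsAPI(HttpClient(timeout=1500), 'SOME_API_KEY', metadata)
    events_api.flush_events([Event('user_1', 'account', 'checkout', 42.0, 1555000000000)])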
diff --git a/splitio/api/impressions.py b/splitio/api/impressions.py new file mode 100644 index 00000000..95c625c0 --- /dev/null +++ b/splitio/api/impressions.py @@ -0,0 +1,82 @@ +"""Impressions API module.""" + +import logging +from itertools import groupby +from splitio.api import APIException +from splitio.api.client import HttpClientException + + +class ImpressionsAPI(object): # pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the impressions API.""" + + def __init__(self, client, apikey, sdk_metadata): + """ + Class constructor. + + :param client: HTTP Client responsible for issuing calls to the backend. + :type client: HttpClient + :param apikey: User apikey token. + :type apikey: string + :param sdk_metadata: SDK Version, IP & Machine name + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._logger = logging.getLogger(self.__class__.__name__) + self._client = client + self._apikey = apikey + self._metadata = { + 'SplitSDKVersion': sdk_metadata.sdk_version, + 'SplitSDKMachineIP': sdk_metadata.instance_ip, + 'SplitSDKMachineName': sdk_metadata.instance_name + } + + @staticmethod + def _build_bulk(impressions): + """ + Build an impression bulk formatted as the API expects it. + + :param impressions: List of impressions to bundle. + :type impressions: list(splitio.models.impressions.Impression) + + :return: List of impression bulks, one per feature (test). + :rtype: list + """ + return [ + { + 'testName': group[0], + 'keyImpressions': [ + { + 'keyName': impression.matching_key, + 'treatment': impression.treatment, + 'time': impression.time, + 'changeNumber': impression.change_number, + 'label': impression.label, + 'bucketingKey': impression.bucketing_key + } + for impression in group[1] + ] + } + for group in groupby( + sorted(impressions, key=lambda i: i.feature_name), + lambda i: i.feature_name + ) + ] + + def flush_impressions(self, impressions): + """ + Send impressions to the backend. + + :param impressions: Impressions bulk + :type impressions: list + """ + bulk = self._build_bulk(impressions) + try: + response = self._client.post( + 'events', + '/testImpressions/bulk', + self._apikey, + body=bulk, + extra_headers=self._metadata + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + self._logger.debug('Error flushing impressions: ', exc_info=True) + raise APIException(exc.custom_message, original_exception=exc.original_exception)
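To illustrate the grouping performed by _build_bulk, a short sketch with a namedtuple standing in for splitio.models.impressions.Impression; the field names are taken from the attributes read above:

    from collections import namedtuple
    from splitio.api.impressions import ImpressionsAPI

    Impression = namedtuple(
        'Impression',
        ['matching_key', 'feature_name', 'treatment', 'label', 'change_number', 'bucketing_key', 'time']
    )
    imps = [
        Impression('user_1', 'some_feature', 'on', 'in segment all', 123, None, 1555000000000),
        Impression('user_2', 'some_feature', 'off', 'default rule', 123, None, 1555000001000),
    ]
    bulk = ImpressionsAPI._build_bulk(imps)
    # => [{'testName': 'some_feature', 'keyImpressions': [{'keyName': 'user_1', ...}, {'keyName': 'user_2', ...}]}]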
diff --git a/splitio/api/segments.py b/splitio/api/segments.py new file mode 100644 index 00000000..c6533827 --- /dev/null +++ b/splitio/api/segments.py @@ -0,0 +1,51 @@ +"""Segments API module.""" + +import json +import logging +from splitio.api import APIException +from splitio.api.client import HttpClientException + + +class SegmentsAPI(object): #pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the segments API.""" + + def __init__(self, http_client, apikey): + """ + Class constructor. + + :param http_client: HTTP Client responsible for issuing calls to the backend. + :type http_client: client.HttpClient + :param apikey: User apikey token. + :type apikey: string + """ + self._logger = logging.getLogger(self.__class__.__name__) + self._client = http_client + self._apikey = apikey + + def fetch_segment(self, segment_name, change_number): + """ + Fetch segment changes from the backend. + + :param segment_name: Name of the segment to fetch changes for. + :type segment_name: str + :param change_number: Last known timestamp of a segment modification. + :type change_number: int + + :return: Json representation of a segmentChange response. + :rtype: dict + """ + try: + response = self._client.get( + 'sdk', + '/segmentChanges/{segment_name}'.format(segment_name=segment_name), + self._apikey, + {'since': change_number} + ) + + if 200 <= response.status_code < 300: + return json.loads(response.body) + else: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + self._logger.debug('Error fetching segment changes: ', exc_info=True) + raise APIException(exc.custom_message, original_exception=exc.original_exception) diff --git a/splitio/api/splits.py b/splitio/api/splits.py new file mode 100644 index 00000000..aca15b59 --- /dev/null +++ b/splitio/api/splits.py @@ -0,0 +1,48 @@ +"""Splits API module.""" + +import logging +import json +from splitio.api import APIException +from splitio.api.client import HttpClientException + + +class SplitsAPI(object): #pylint: disable=too-few-public-methods + """Class that uses an httpClient to communicate with the splits API.""" + + def __init__(self, client, apikey): + """ + Class constructor. + + :param client: HTTP Client responsible for issuing calls to the backend. + :type client: HttpClient + :param apikey: User apikey token. + :type apikey: string + """ + self._logger = logging.getLogger(self.__class__.__name__) + self._client = client + self._apikey = apikey + + def fetch_splits(self, change_number): + """ + Fetch splits from backend. + + :param change_number: Last known timestamp of a split modification. + :type change_number: int + + :return: Json representation of a splitChanges response. + :rtype: dict + """ + try: + response = self._client.get( + 'sdk', + '/splitChanges', + self._apikey, + {'since': change_number} + ) + if 200 <= response.status_code < 300: + return json.loads(response.body) + else: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + self._logger.debug('Error fetching splits: ', exc_info=True) + raise APIException(exc.custom_message, original_exception=exc.original_exception)
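A short polling sketch for the two fetchers above. As documented in the deleted api.py, since=-1 requests the full definition set and the response's "till" field feeds the next call; the api key is a placeholder:

    from splitio.api.client import HttpClient
    from splitio.api.splits import SplitsAPI
    from splitio.api.segments import SegmentsAPI

    http = HttpClient(timeout=1500)
    splits_api = SplitsAPI(http, 'SOME_API_KEY')
    changes = splits_api.fetch_splits(-1)      # dict with 'splits', 'since' and 'till'
    next_since = changes['till']               # feed back into the next poll
    segments_api = SegmentsAPI(http, 'SOME_API_KEY')
    segment = segments_api.fetch_segment('some_segment', -1)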
diff --git a/splitio/api/telemetry.py b/splitio/api/telemetry.py new file mode 100644 index 00000000..a6f88700 --- /dev/null +++ b/splitio/api/telemetry.py @@ -0,0 +1,132 @@ +"""Telemetry API Module.""" +import logging +import six +from splitio.api import APIException +from splitio.api.client import HttpClientException + + +class TelemetryAPI(object): + """Class to handle telemetry submission to the backend.""" + + def __init__(self, client, apikey, sdk_metadata): + """ + Class constructor. + + :param client: HTTP Client responsible for issuing calls to the backend. + :type client: HttpClient + :param apikey: User apikey token. + :type apikey: string + :param sdk_metadata: SDK Version, IP & Machine name + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._logger = logging.getLogger(self.__class__.__name__) + self._client = client + self._apikey = apikey + self._metadata = { + 'SplitSDKVersion': sdk_metadata.sdk_version, + 'SplitSDKMachineIP': sdk_metadata.instance_ip, + 'SplitSDKMachineName': sdk_metadata.instance_name + } + + @staticmethod + def _build_latencies(latencies): + """ + Build a latencies bulk as expected by the BE. + + :param latencies: Latencies to bundle. + :type latencies: dict + """ + return [ + {'name': name, 'latencies': latencies_list} + for name, latencies_list in six.iteritems(latencies) + ] + + def flush_latencies(self, latencies): + """ + Submit latencies to the backend. + + :param latencies: List of latency buckets with their respective count. + :type latencies: list + """ + bulk = self._build_latencies(latencies) + try: + response = self._client.post( + 'events', + '/metrics/times', + self._apikey, + body=bulk, + extra_headers=self._metadata + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + raise APIException(exc.custom_message, original_exception=exc.original_exception) + + @staticmethod + def _build_gauges(gauges): + """ + Build a gauges bulk as expected by the BE. + + :param gauges: Gauges to bundle. + :type gauges: dict + """ + return [ + {'name': name, 'value': value} + for name, value in six.iteritems(gauges) + ] + + def flush_gauges(self, gauges): + """ + Submit gauges to the backend. + + :param gauges: Gauges measured to be sent to the backend. + :type gauges: List + """ + bulk = self._build_gauges(gauges) + try: + response = self._client.post( + 'events', + '/metrics/gauge', + self._apikey, + body=bulk, + extra_headers=self._metadata + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + raise APIException(exc.custom_message, original_exception=exc.original_exception) + + @staticmethod + def _build_counters(counters): + """ + Build a counters bulk as expected by the BE. + + :param counters: Counters to bundle. + :type counters: dict + """ + return [ + {'name': name, 'delta': value} + for name, value in six.iteritems(counters) + ] + + def flush_counters(self, counters): + """ + Submit counters to the backend. + + :param counters: Counters measured to be sent to the backend. + :type counters: List + """ + bulk = self._build_counters(counters) + try: + response = self._client.post( + 'events', + '/metrics/counters', + self._apikey, + body=bulk, + extra_headers=self._metadata + ) + if not 200 <= response.status_code < 300: + raise APIException(response.body, response.status_code) + except HttpClientException as exc: + self._logger.debug('Error flushing counters: ', exc_info=True) + raise APIException(exc.custom_message, original_exception=exc.original_exception) diff --git a/splitio/bin/synchronizer.py b/splitio/bin/synchronizer.py deleted file mode 100644 index 4e0add2c..00000000 --- a/splitio/bin/synchronizer.py +++ /dev/null @@ -1,133 +0,0 @@ -""" - __ ____ _ _ _ - / /__ / ___| _ __ | (_) |_ - / / \ \ \___ \| '_ \| | | __| - \ \ \ \ ___) | |_) | | | |_ - \_\ / / |____/| .__/|_|_|\__| - /_/ |_| - -Split.io Synchronizer Service. 
- -Usage: - synchronizer [options] <config_file> - synchronizer -h | --help - synchronizer --version - -Options: - --splits-refresh-rate=SECONDS The SECONDS rate to fetch Splits definitions [default: 30] - --segments-refresh-rate=SECONDS The SECONDS rate to fetch the Segments keys [default: 30] - --impression-refresh-rate=SECONDS The SECONDS rate to send key impressions [default: 60] - --metrics-refresh-rate=SECONDS The SECONDS rate to send SDK metrics [default: 60] - -h --help Show this screen. - --version Show version. - -Configuration file: - The configuration file is a JSON file with the following fields: - - { - "apiKey": "YOUR_API_KEY", - "redisHost": "REDIS_DNS_OR_IP", - "redisPort": 6379, - "redisDb": 0 - } - - -Examples: - python -m splitio.bin.synchronizer splitio-config.json - python -m splitio.bin.synchronizer --splits-refresh-rate=10 splitio-config.json - - -""" - -from splitio.api import api_factory -from splitio.config import SDK_VERSION, parse_config_file -from splitio.redis_support import get_redis, RedisSplitCache, RedisSegmentCache, RedisSplitParser, RedisImpressionsCache, RedisMetricsCache -from splitio.splits import ApiSplitChangeFetcher -from splitio.tasks import update_splits, update_segments, report_metrics, report_impressions -from splitio.segments import ApiSegmentChangeFetcher - -import time -import threading -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('splitio.bin.synchronizer') - - -def _update_splits(seconds, config): - try: - while True: - redis = get_redis(config) - split_cache = RedisSplitCache(redis) - sdk_api = api_factory(config) - split_change_fetcher = ApiSplitChangeFetcher(sdk_api) - segment_cache = RedisSegmentCache(redis) - split_parser = RedisSplitParser(segment_cache) - update_splits(split_cache, split_change_fetcher, split_parser) - - time.sleep(seconds) - except: - logger.exception('Exception caught updating splits') - -def _update_segments(seconds, config): - try: - while True: - redis = get_redis(config) - segment_cache = RedisSegmentCache(redis) - sdk_api = api_factory(config) - segment_change_fetcher = ApiSegmentChangeFetcher(sdk_api) - update_segments(segment_cache, segment_change_fetcher) - - time.sleep(seconds) - except: - logger.exception('Exception caught updating segments') - -def _report_impressions(seconds, config): - try: - while True: - redis = get_redis(config) - impressions_cache = RedisImpressionsCache(redis) - sdk_api = api_factory(config) - report_impressions(impressions_cache, sdk_api) - - time.sleep(seconds) - except: - logger.exception('Exception caught posting impressions') - -def _report_metrics(seconds, config): - try: - while True: - redis = get_redis(config) - metrics_cache = RedisMetricsCache(redis) - sdk_api = api_factory(config) - report_metrics(metrics_cache, sdk_api) - - time.sleep(seconds) - except: - logger.exception('Exception caught posting metrics') - -def run(arguments): - config = parse_config_file(arguments['<config_file>']) - - update_splits_thread = threading.Thread(target=_update_splits, - args=(int(arguments['--splits-refresh-rate']), config)) - - update_segments_thread = threading.Thread(target=_update_segments, - args=(int(arguments['--segments-refresh-rate']), config)) - - report_impressions_thread = threading.Thread(target=_report_impressions, - args=(int(arguments['--impression-refresh-rate']), config)) - - report_metrics_thread = threading.Thread(target=_report_metrics, - args=(int(arguments['--metrics-refresh-rate']), config)) - - update_splits_thread.start() - 
update_segments_thread.start() - report_impressions_thread.start() - report_metrics_thread.start() - - -if __name__ == '__main__': - from docopt import docopt - arguments = docopt(__doc__, version=SDK_VERSION) - run(arguments) \ No newline at end of file diff --git a/splitio/brokers.py b/splitio/brokers.py deleted file mode 100644 index 2e7c95d7..00000000 --- a/splitio/brokers.py +++ /dev/null @@ -1,848 +0,0 @@ -"""A module for Split.io SDK Brokers""" -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -import abc -import logging -import os -import os.path -import random -import re -import threading -import time -from future.utils import raise_from - - -from splitio.api import SdkApi -from splitio.exceptions import TimeoutException -from splitio.metrics import Metrics, AsyncMetrics, ApiMetrics, \ - CacheBasedMetrics -from splitio.impressions import TreatmentLog, AsyncTreatmentLog, \ - SelfUpdatingTreatmentLog, CacheBasedTreatmentLog -from splitio.redis_support import RedisSplitCache, RedisImpressionsCache, \ - RedisMetricsCache, get_redis, RedisEventsCache -from splitio.splits import SelfRefreshingSplitFetcher, SplitParser, \ - ApiSplitChangeFetcher, JSONFileSplitFetcher, InMemorySplitFetcher, \ - AllKeysSplit, CacheBasedSplitFetcher -from splitio.segments import ApiSegmentChangeFetcher, \ - SelfRefreshingSegmentFetcher, JSONFileSegmentFetcher -from splitio.config import DEFAULT_CONFIG, MAX_INTERVAL, parse_config_file, \ - set_machine_ip, set_machine_name -from splitio.uwsgi import UWSGISplitCache, UWSGIImpressionsCache, \ - UWSGIMetricsCache, UWSGIEventsCache, get_uwsgi -from splitio.tasks import EventsSyncTask -from splitio.events import InMemoryEventStorage - - -def randomize_interval(value): - """ - Generates a function that return a random integer in the [value/2,value) - interval. The minimum generated value is 5. - :param value: The maximum value for the random interval - :type value: int - :return: A function that returns a random integer in the interval. - :rtype: function - """ - def random_interval(): - return max(5, random.randint(value // 2, value)) - - return random_interval - - -class BaseBroker(object): - """ - Abstract class defining the interface that concrete brokers must implement, - and including methods that use that interface to retrieve splits, log - impressions and submit metrics. - """ - - __metaclass__ = abc.ABCMeta - - def __init__(self, config=None): - """ - Class constructor, only sets up the logger - """ - self._logger = logging.getLogger(self.__class__.__name__) - self._destroyed = False - self._config = config - - def fetch_feature(self, name): - """ - Fetch a feature - :return: The split associated with that feature - :rtype: Split - """ - return self.get_split_fetcher().fetch(name) - - def get_change_number(self): - """ - Returns the latest change number - """ - return self.get_split_fetcher().change_number - - def log_impressions(self, impressions): - """ - Logs impressions after a get_treatments call - :return: The treatment log implementation. - :rtype: TreatmentLog - """ - return self.get_impression_log().log_impressions(impressions) - - def log_operation_time(self, operation, elapsed): - """Get the metrics implementation. - :return: The metrics implementation. 
- :rtype: Metrics - """ - return self.get_metrics_handler().time(operation, elapsed) - - def log_event(self, event): - """ - Logs an event after a .track() call - """ - return self.get_events_log().log(event) - - @abc.abstractmethod - def get_split_fetcher(self): - pass - - @abc.abstractmethod - def get_metrics_handler(self): - pass - - @abc.abstractmethod - def get_impression_log(self): - pass - - @abc.abstractmethod - def get_events_log(self): - pass - - @abc.abstractmethod - def destroy(self): - pass - - -class JSONFileBroker(BaseBroker): - def __init__(self, config, segment_changes_file_name, split_changes_file_name): - """ - A Broker implementation that uses responses from the segmentChanges and - splitChanges resources to provide access to splits. It is intended to be - used on integration tests only. - - :param segment_changes_file_name: The name of the file with the - segmentChanges response - :type segment_changes_file_name: str - :param split_changes_file_name: The name of the file with the - splitChanges response - :type split_changes_file_name: str - """ - super(JSONFileBroker, self).__init__(config) - self._segment_changes_file_name = segment_changes_file_name - self._split_changes_file_name = split_changes_file_name - self._split_fetcher = self._build_split_fetcher() - self._treatment_log = TreatmentLog() # Does nothing on ._log() - self._metrics = Metrics() # Does nothing on .count(), .time(), .gauge() - - def _build_split_fetcher(self): - """ - Build the json backed split fetcher - :return: The json backed split fetcher - :rtype: JSONFileSplitFetcher - """ - segment_fetcher = JSONFileSegmentFetcher(self._segment_changes_file_name) - split_parser = SplitParser(segment_fetcher) - split_fetcher = JSONFileSplitFetcher( - self._split_changes_file_name, - split_parser - ) - - return split_fetcher - - def get_split_fetcher(self): - """ - Get the split fetcher implementation for the broker. - :return: The split fetcher - :rtype: SplitFetcher - """ - return self._split_fetcher - - def get_metrics_handler(self): - """ - """ - return self._metrics - - def get_impression_log(self): - """ - """ - return self._treatment_log - - def destroy(self): - self._destroyed = True - self._split_fetcher.destroy() - self._treatment_log.destroy() - self._metrics.destroy() - - def get_events_log(self): - return None - - -class SelfRefreshingBroker(BaseBroker): - def __init__(self, api_key, config=None, sdk_api_base_url=None, - events_api_base_url=None, impression_listener=None): - """ - A Broker implementation that refreshes itself at regular intervals. - The config parameter is a dictionary that allows you to control the - behaviour of the broker. - The following configuration values are supported: - * connectionTimeout: The TCP connection timeout (Default: 1500ms) - * readTimeout: The HTTP read timeout (Default: 1500ms) - * featuresRefreshRate: The refresh rate for features (Default: 30s) - * segmentsRefreshRate: The refresh rate for segments (Default: 60s) - * metricsRefreshRate: The refresh rate for metrics (Default: 60s) - * impressionsRefreshRate: The refresh rate for impressions - (Default: 60s) - * randomizeIntervals: Whether to randomize the refres intervals - (Default: False) - * ready: How long to wait (in seconds) for the broker to be initialized. - 0 to return immediately without waiting. 
(Default: 0s) - - :param api_key: The API key provided by Split.io - :type api_key: str - :param config: The configuration dictionary - :type config: dict - :param sdk_api_base_url: An override for the default API base URL. - :type sdk_api_base_url: str - :param events_api_base_url: An override for the default events base URL. - :type events_api_base_url: str - """ - super(SelfRefreshingBroker, self).__init__() - self._api_key = api_key - self._sdk_api_base_url = sdk_api_base_url - self._events_api_base_url = events_api_base_url - self._impression_listener = impression_listener - - self._init_config(config) - self._sdk_api = self._build_sdk_api() - self._split_fetcher = self._build_split_fetcher() - self._treatment_log = self._build_treatment_log() - self._metrics = self._build_metrics() - self._start() - - self._events_storage = InMemoryEventStorage(self._config['eventsQueueSize']) - self._events_task = EventsSyncTask( - self._sdk_api, - self._events_storage, - self._config['eventsPushRate'], - self._config['eventsQueueSize'], - ) - self._events_storage.set_queue_full_hook(lambda: self._events_task.flush()) - self._events_task.start() - - def _init_config(self, config=None): - self._config = dict(DEFAULT_CONFIG) - if config is not None: - self._config.update(config) - - segment_fetcher_interval = min( - MAX_INTERVAL, - self._config['segmentsRefreshRate'] - ) - split_fetcher_interval = min( - MAX_INTERVAL, - self._config['featuresRefreshRate'] - ) - impressions_interval = min( - MAX_INTERVAL, - self._config['impressionsRefreshRate'] - ) - - if self._config['randomizeIntervals']: - self._segment_fetcher_interval = randomize_interval(segment_fetcher_interval) - self._split_fetcher_interval = randomize_interval(split_fetcher_interval) - self._impressions_interval = randomize_interval(impressions_interval) - else: - self._segment_fetcher_interval = segment_fetcher_interval - self._split_fetcher_interval = split_fetcher_interval - self._impressions_interval = impressions_interval - self._metrics_max_time_between_calls = min( - MAX_INTERVAL, - self._config['metricsRefreshRate'] - ) - self._metrics_max_call_count = self._config['maxMetricsCallsBeforeFlush'] - - self._connection_timeout = self._config['connectionTimeout'] - self._read_timeout = self._config['readTimeout'] - self._max_impressions_log_size = self._config['maxImpressionsLogSize'] - self._ready = self._config['ready'] - - def _build_sdk_api(self): - return SdkApi( - self._api_key, - sdk_api_base_url=self._sdk_api_base_url, - events_api_base_url=self._events_api_base_url, - connect_timeout=self._connection_timeout, - read_timeout=self._read_timeout - ) - - def _build_split_fetcher(self): - """ - Build the self refreshing split fetcher - :return: The self refreshing split fetcher - :rtype: SelfRefreshingSplitFetcher - """ - segment_change_fetcher = ApiSegmentChangeFetcher(self._sdk_api) - segment_fetcher = SelfRefreshingSegmentFetcher( - segment_change_fetcher, - interval=self._segment_fetcher_interval - ) - split_change_fetcher = ApiSplitChangeFetcher(self._sdk_api) - split_parser = SplitParser(segment_fetcher) - split_fetcher = SelfRefreshingSplitFetcher( - split_change_fetcher, - split_parser, - interval=self._split_fetcher_interval - ) - return split_fetcher - - def _build_treatment_log(self): - """Build the treatment log implementation. - :return: The treatment log implementation. 
- :rtype: TreatmentLog - """ - self_updating_treatment_log = SelfUpdatingTreatmentLog( - self._sdk_api, - max_count=self._max_impressions_log_size, - interval=self._impressions_interval, - ) - return AsyncTreatmentLog(self_updating_treatment_log) - - def _build_metrics(self): - """Build the metrics implementation. - :return: The metrics implementation. - :rtype: Metrics - """ - api_metrics = ApiMetrics( - self._sdk_api, - max_call_count=self._metrics_max_call_count, - max_time_between_calls=self._metrics_max_time_between_calls - ) - return AsyncMetrics(api_metrics) - - def _start(self): - self._treatment_log.delegate.start() - - if self._ready > 0: - event = threading.Event() - - thread = threading.Thread(target=self._fetch_splits, args=(event,)) - thread.daemon = True - thread.start() - - flag_set = event.wait(self._ready / 1000) - if not flag_set: - self._logger.info( - 'Timeout reached. Returning broker in partial state.' - ) - raise TimeoutException() - else: - self._split_fetcher.start() - - def _fetch_splits(self, event): - """ - Fetches the split and segment information blocking until it is done. - """ - self._split_fetcher.refresh_splits(block_until_ready=True) - self._split_fetcher.start(delayed_update=True) - event.set() - - def get_split_fetcher(self): - """ - Get the split fetcher implementation for the broker. - :return: The split fetcher - :rtype: SplitFetcher - """ - return self._split_fetcher - - def get_metrics_handler(self): - return self._metrics - - def get_impression_log(self): - return self._treatment_log - - def get_events_log(self): - return self._events_storage - - def destroy(self): - self._destroyed = True - self._split_fetcher.destroy() - self._treatment_log.destroy() - self._metrics.destroy() - self._events_task.stop() - - -class LocalhostBroker(BaseBroker): - _COMMENT_LINE_RE = re.compile('^#.*$') - _DEFINITION_LINE_RE = re.compile( - '^(?P<feature>[\w_-]+)\s+(?P<treatment>[\w_-]+)$' - ) - - class LocalhostEventStorage(object): - def log(self, event): - pass - - def __init__(self, config, split_definition_file_name=None, auto_refresh_period=2): - """ - A broker implementation that builds its configuration from a split - definition file. By default the definition is taken from $HOME/.split - but the file name can be supplied as argument as well.
- :param split_definition_file_name: Name of definition file (Optional) - :type split_definition_file_name: str - :param auto_refresh_period: Number of seconds between split refresh calls - :type auto_refresh_period: int - """ - super(LocalhostBroker, self).__init__(config) - - if split_definition_file_name is None: - self._split_definition_file_name = os.path.join( - os.path.expanduser('~'), '.split' - ) - else: - self._split_definition_file_name = split_definition_file_name - - self._split_refresh_period = auto_refresh_period - - self._split_fetcher = self._build_split_fetcher() - self._refresh_thread = threading.Thread(target=self.refresh_splits) - self._refresh_thread.daemon = True - self._refresh_thread.start() - - self._treatment_log = TreatmentLog() - self._metrics = Metrics() - self._event_storage = LocalhostBroker.LocalhostEventStorage() - - def _build_split_fetcher(self): - """ - Build the in memory split fetcher using the local environment split - definition file - :return: The in memory split fetcher - :rtype: InMemorySplitFetcher - """ - splits = self._parse_split_file(self._split_definition_file_name) - split_fetcher = InMemorySplitFetcher(splits=splits) - - return split_fetcher - - def _parse_split_file(self, file_name): - splits = dict() - - try: - with open(file_name) as f: - for line in f: - if line.strip() == '': - continue - - comment_match = LocalhostBroker._COMMENT_LINE_RE.match(line) - if comment_match: - continue - - definition_match = LocalhostBroker._DEFINITION_LINE_RE.match(line) - if definition_match: - splits[definition_match.group('feature')] = AllKeysSplit( - definition_match.group('feature'), - definition_match.group('treatment') - ) - continue - - self._logger.warning( - 'Invalid line on localhost environment split ' - 'definition. Line = %s', - line - ) - return splits - except IOError as e: - raise_from(ValueError( - 'There was a problem with ' - 'the splits definition file "{}"'.format(file_name)), - e - ) - - def get_split_fetcher(self): - """ - Get the split fetcher implementation for the broker. - :return: The split fetcher - :rtype: SplitFetcher - """ - return self._split_fetcher - - def get_metrics_handler(self): - """ - """ - return self._metrics - - def get_impression_log(self): - """ - """ - return self._treatment_log - - def get_events_log(self): - return self._event_storage - - def refresh_splits(self): - while not self._destroyed: - time.sleep(self._split_refresh_period) - if not self._destroyed: # DO NOT REMOVE - # This check is used in case the client was - # destroyed while the thread was sleeping - # and the file was closed, in order to - # prevent an exception. - self._split_fetcher = self._build_split_fetcher() - - def destroy(self): - self._destroyed = True - self._split_fetcher.destroy() - self._treatment_log.destroy() - self._metrics.destroy() - - -class RedisBroker(BaseBroker): - def __init__(self, redis, config): - """A Broker implementation that uses Redis as its backend. 
-        :param redis: A redis client
-        :type redis: StrictRedis"""
-        super(RedisBroker, self).__init__(config)
-
-        split_cache = RedisSplitCache(redis)
-        split_fetcher = CacheBasedSplitFetcher(split_cache)
-
-        impressions_cache = RedisImpressionsCache(redis)
-        treatment_log = CacheBasedTreatmentLog(impressions_cache)
-
-        metrics_cache = RedisMetricsCache(redis)
-        delegate_metrics = CacheBasedMetrics(metrics_cache)
-        metrics = AsyncMetrics(delegate_metrics)
-
-        self._split_fetcher = split_fetcher
-        self._treatment_log = treatment_log
-        self._metrics = metrics
-
-        self._event_storage = RedisEventsCache(redis)
-
-    def get_split_fetcher(self):
-        """
-        Get the split fetcher implementation for the broker.
-        :return: The split fetcher
-        :rtype: SplitFetcher
-        """
-        return self._split_fetcher
-
-    def get_metrics_handler(self):
-        """
-        """
-        return self._metrics
-
-    def get_impression_log(self):
-        """
-        """
-        return self._treatment_log
-
-    def get_events_log(self):
-        return self._event_storage
-
-    def get_metrics(self):
-        """
-        Get the metrics implementation for the broker.
-        :return: The metrics
-        :rtype: Metrics
-        """
-        return self._metrics
-
-    def destroy(self):
-        self._destroyed = True
-        self._split_fetcher.destroy()
-        self._treatment_log.destroy()
-        self._metrics.destroy()
-
-
-class UWSGIBroker(BaseBroker):
-    def __init__(self, uwsgi, config=None):
-        """
-        A Broker implementation that consumes data from the uWSGI cache framework.
-        The config parameter is a dictionary that allows you to control the
-        behaviour of the broker.
-
-        :param config: The configuration dictionary
-        :type config: dict
-        """
-        super(UWSGIBroker, self).__init__(config)
-
-        split_cache = UWSGISplitCache(uwsgi)
-        split_fetcher = CacheBasedSplitFetcher(split_cache)
-
-        impressions_cache = UWSGIImpressionsCache(uwsgi)
-        delegate_treatment_log = CacheBasedTreatmentLog(impressions_cache)
-        treatment_log = AsyncTreatmentLog(delegate_treatment_log)
-
-        metrics_cache = UWSGIMetricsCache(uwsgi)
-        delegate_metrics = CacheBasedMetrics(metrics_cache)
-        metrics = AsyncMetrics(delegate_metrics)
-
-        self._event_log = UWSGIEventsCache(uwsgi, events_queue_size=config['eventsQueueSize'])
-
-        self._split_fetcher = split_fetcher
-        self._treatment_log = treatment_log
-        self._metrics = metrics
-
-    def get_split_fetcher(self):
-        """
-        Get the split fetcher implementation for the broker.
-        :return: The split fetcher
-        :rtype: SplitFetcher
-        """
-        return self._split_fetcher
-
-    def get_metrics_handler(self):
-        """
-        """
-        return self._metrics
-
-    def get_impression_log(self):
-        """
-        """
-        return self._treatment_log
-
-    def get_events_log(self):
-        return self._event_log
-
-    def destroy(self):
-        self._destroyed = True
-        self._split_fetcher.destroy()
-        self._treatment_log.destroy()
-        self._metrics.destroy()
-
-
-def _init_config(api_key, **kwargs):
-    config = dict(DEFAULT_CONFIG)
-    user_cfg = kwargs.pop('config', dict())
-    config.update(user_cfg)
-    sdk_api_base_url = kwargs.pop('sdk_api_base_url', None)
-    events_api_base_url = kwargs.pop('events_api_base_url', None)
-
-    if 'config_file' in kwargs:
-        file_config = parse_config_file(kwargs['config_file'])
-
-        file_api_key = file_config.pop('apiKey', None)
-        file_sdk_api_base_url = file_config.pop('sdkApiBaseUrl', None)
-        file_events_api_base_url = file_config.pop('eventsApiBaseUrl', None)
-
-        api_key = api_key or file_api_key
-        sdk_api_base_url = sdk_api_base_url or file_sdk_api_base_url
-        events_api_base_url = events_api_base_url or file_events_api_base_url
-
-        file_config.update(config)
-        config = file_config
-
-    set_machine_ip(config.get('splitSdkMachineIp'))
-    set_machine_name(config.get('splitSdkMachineName'))
-
-    return api_key, config, sdk_api_base_url, events_api_base_url
-
-
-def get_self_refreshing_broker(api_key, **kwargs):
-    """
-    Builds a Split Broker that refreshes itself at regular intervals.
-
-    The config_file parameter is the name of a file that contains the broker
-    configuration. Here's an example of a config file:
-
-    {
-        "apiKey": "some-api-key",
-        "sdkApiBaseUrl": "https://sdk.split.io/api",
-        "eventsApiBaseUrl": "https://events.split.io/api",
-        "connectionTimeout": 1500,
-        "readTimeout": 1500,
-        "featuresRefreshRate": 5,
-        "segmentsRefreshRate": 60,
-        "metricsRefreshRate": 60,
-        "impressionsRefreshRate": 60,
-        "randomizeIntervals": false,
-        "maxImpressionsLogSize": -1,
-        "maxMetricsCallsBeforeFlush": 1000,
-        "ready": 0
-    }
-
-    The config parameter is a dictionary that allows you to control the
-    behaviour of the broker. The following configuration values are supported:
-
-    * connectionTimeout: The TCP connection timeout (Default: 1500ms)
-    * readTimeout: The HTTP read timeout (Default: 1500ms)
-    * featuresRefreshRate: The refresh rate for features (Default: 5s)
-    * segmentsRefreshRate: The refresh rate for segments (Default: 60s)
-    * metricsRefreshRate: The refresh rate for metrics (Default: 60s)
-    * impressionsRefreshRate: The refresh rate for impressions (Default: 60s)
-    * randomizeIntervals: Whether to randomize the refresh intervals
-      (Default: False)
-    * ready: How long to wait (in seconds) for the broker to be initialized.
-      0 to return immediately without waiting. (Default: 0s)
-
-    If the api_key argument is 'localhost' a localhost environment broker is
-    built based on the contents of a .split file in the user's home directory.
-    The definition file has the following syntax:
-
-        file: (comment | split_line)+
-        comment : '#' string*\n
-        split_line : feature_name ' ' treatment\n
-        feature_name : string
-        treatment : string
-
-    It is possible to change the location of the split file by using the
-    split_definition_file_name argument.
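-
-    For example, a minimal .split file following this grammar could look like
-    this (feature and treatment names here are hypothetical):
-
-        # lines starting with '#' are comments
-        my_feature on
-        other_feature some_treatment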
-
-    :param api_key: The API key provided by Split.io
-    :type api_key: str
-    :param config_file: Filename of the config file
-    :type config_file: str
-    :param config: The configuration dictionary
-    :type config: dict
-    :param sdk_api_base_url: An override for the default API base URL.
-    :type sdk_api_base_url: str
-    :param events_api_base_url: An override for the default events base URL.
-    :type events_api_base_url: str
-    :param split_definition_file_name: Name of the definition file (Optional)
-    :type split_definition_file_name: str
-    """
-    api_key, config, sdk_api_base_url, events_api_base_url = _init_config(
-        api_key,
-        **kwargs
-    )
-
-    if api_key == 'localhost':
-        return LocalhostBroker(config, **kwargs)
-
-    return SelfRefreshingBroker(
-        api_key,
-        config=config,
-        sdk_api_base_url=sdk_api_base_url,
-        events_api_base_url=events_api_base_url,
-        impression_listener=kwargs.get('impression_listener')
-    )
-
-
-def get_redis_broker(api_key, **kwargs):
-    """
-    Builds a Split Broker that gets its information from a Redis instance.
-    It also writes impressions and metrics to the same instance.
-
-    In order for this to work properly, you need to periodically call the
-    update_splits and update_segments scripts.
-    You also need to run the send_impressions and send_metrics scripts in order
-    to push the impressions and metrics onto the Split.io backend.
-
-    The config_file parameter is the name of a file that contains the broker
-    configuration. Here's an example of a config file:
-
-    {
-        "apiKey": "some-api-key",
-        "sdkApiBaseUrl": "https://sdk.split.io/api",
-        "eventsApiBaseUrl": "https://events.split.io/api",
-        "redisFactory": "some.redis.factory",
-        "redisHost": "localhost",
-        "redisPort": 6879,
-        "redisDb": 0
-    }
-
-    If the redisFactory entry is present, it is used to build the redis broker
-    instance, otherwise the values of redisHost, redisPort and redisDb are used.
-
-    If the api_key argument is 'localhost' a localhost environment broker is
-    built based on the contents of a .split file in the user's home directory.
-    The definition file has the following syntax:
-
-        file: (comment | split_line)+
-        comment : '#' string*\n
-        split_line : feature_name ' ' treatment\n
-        feature_name : string
-        treatment : string
-
-    It is possible to change the location of the split file by using the
-    split_definition_file_name argument.
-
-    :param api_key: The API key provided by Split.io
-    :type api_key: str
-    :param config_file: Filename of the config file
-    :type config_file: str
-    :param sdk_api_base_url: An override for the default API base URL.
-    :type sdk_api_base_url: str
-    :param events_api_base_url: An override for the default events base URL.
-    :type events_api_base_url: str
-    :param split_definition_file_name: Name of the definition file (Optional)
-    :type split_definition_file_name: str
-    """
-    api_key, config, _, _ = _init_config(api_key, **kwargs)
-
-    if api_key == 'localhost':
-        return LocalhostBroker(config, **kwargs)
-
-    redis = get_redis(config)
-
-    redis_broker = RedisBroker(redis, config)
-
-    return redis_broker
-
-
-def get_uwsgi_broker(api_key, **kwargs):
-    """
-    Builds a Split Broker that gets its information from a uWSGI cache
-    instance. It also writes impressions and metrics to the same instance.
-
-    In order for this to work properly, you need to periodically call the spooler
-    uwsgi_update_splits and uwsgi_update_segments scripts.
-    You also need to run the uwsgi_report_impressions and uwsgi_report_metrics
-    scripts in order to push the impressions and metrics onto the Split.io
-    backend.
-
-    The config_file parameter is the name of a file that contains the broker
-    configuration. Here's an example of a config file:
-
-    {
-        "apiKey": "some-api-key",
-        "sdkApiBaseUrl": "https://sdk.split.io/api",
-        "eventsApiBaseUrl": "https://events.split.io/api",
-        "featuresRefreshRate": 5,
-        "segmentsRefreshRate": 60,
-        "metricsRefreshRate": 60,
-        "impressionsRefreshRate": 60
-    }
-
-    If the api_key argument is 'localhost' a localhost environment broker is
-    built based on the contents of a .split file in the user's home directory.
-    The definition file has the following syntax:
-
-        file: (comment | split_line)+
-        comment : '#' string*\n
-        split_line : feature_name ' ' treatment\n
-        feature_name : string
-        treatment : string
-
-    It is possible to change the location of the split file by using the
-    split_definition_file_name argument.
-
-    :param api_key: The API key provided by Split.io
-    :type api_key: str
-    :param config_file: Filename of the config file
-    :type config_file: str
-    :param sdk_api_base_url: An override for the default API base URL.
-    :type sdk_api_base_url: str
-    :param events_api_base_url: An override for the default events base URL.
-    :type events_api_base_url: str
-    :param split_definition_file_name: Name of the definition file (Optional)
-    :type split_definition_file_name: str
-    """
-    api_key, config, _, _ = _init_config(api_key, **kwargs)
-
-    if api_key == 'localhost':
-        return LocalhostBroker(config, **kwargs)
-
-    uwsgi = get_uwsgi()
-    uwsgi_broker = UWSGIBroker(uwsgi, config)
-
-    return uwsgi_broker
diff --git a/splitio/cache.py b/splitio/cache.py
deleted file mode 100644
index 29e6f49b..00000000
--- a/splitio/cache.py
+++ /dev/null
@@ -1,339 +0,0 @@
-"""This module contains everything related to split and segment caches"""
-from __future__ import absolute_import, division, print_function, unicode_literals
-
-
-from collections import defaultdict
-from copy import deepcopy
-from threading import RLock
-
-
-class SplitCache(object):  # pragma: no cover
-    """
-    The basic interface for a Split cache. It should be able to store and retrieve Split
-    instances, as well as keeping track of the change number.
-    """
-    def add_split(self, split_name, split):
-        """
-        Stores a Split under a name.
-        :param split_name: Name of the split (feature)
-        :type split_name: str
-        :param split: The split to store
-        :type split: Split
-        """
-        pass  # Do nothing
-
-    def remove_split(self, split_name):
-        """
-        Evicts a Split from the cache.
-        :param split_name: Name of the split (feature)
-        :type split_name: str
-        """
-        pass  # Do nothing
-
-    def get_split(self, split_name):
-        """
-        Retrieves a Split from the cache.
-        :param split_name: Name of the split (feature)
-        :type split_name: str
-        :return: The split under the name if it exists, None otherwise
-        :rtype: Split
-        """
-        return None
-
-    def set_change_number(self, change_number):
-        """
-        Sets the value for the change number
-        :param change_number: The change number
-        :type change_number: int
-        """
-        pass  # Do nothing
-
-    def get_change_number(self):
-        """
-        Retrieves the value of the change number
-        :return: The current change number value, -1 otherwise
-        :rtype: int
-        """
-        return -1
-
-
-class SegmentCache(object):  # pragma: no cover
-    """
-    The basic interface for a Segment cache.
It should be able to store and retrieve Segment - information, as well as keeping track of the change number. - """ - def add_keys_to_segment(self, segment_name, segment_keys): - """ - Adds a set of keys to a segment - :param segment_name: Name of the segment - :type segment_name: str - :param segment_keys: Keys to add to the segment - :type segment_keys: list - """ - pass # Do nothing - - def remove_keys_from_segment(self, segment_name, segment_keys): - """ - Removes a set of keys from a segment - :param segment_name: Name of the segment - :type segment_name: str - :param segment_keys: Keys to remove from the segment - :type segment_keys: list - """ - pass # Do nothing - - def is_in_segment(self, segment_name, key): - """ - Checks if a key is in a segment - :param segment_name: Name of the segment - :type segment_name: str - :param key: Key to check - :type key: str - :return: True if the key is in the segment, False otherwise - :rtype: bool - """ - return False - - def set_change_number(self, segment_name, change_number): - """ - Sets the value for the change number - :param segment_name: Name of the segment - :type segment_name: str - :param change_number: The change number - :type change_number: int - """ - pass # Do nothing - - def get_change_number(self, segment_name): - """ - Retrieves the value of the change number of a segment - :param segment_name: Name of the segment - :type segment_name: str - :return: The current change number value, -1 otherwise - :rtype: int - """ - return -1 - - -class InMemorySplitCache(SplitCache): - def __init__(self, change_number=-1, entries=None): - """ - A SplitCache that stores splits in a dictionary. - :param change_number: Initial value for the change number. - :type change_number: int - :param entries: Initial set of dictionary entries - :type entries: dict - """ - self._change_number = change_number - self._entries = entries if entries is not None else dict() - - def add_split(self, split_name, split): - self._entries[split_name] = split - - def remove_split(self, split_name): - self._entries.pop(split_name, None) - - def get_split(self, split_name): - return self._entries.get(split_name, None) - - def set_change_number(self, change_number): - self._change_number = change_number - - def get_change_number(self): - return self._change_number - - -class InMemorySegmentCache(SegmentCache): - def __init__(self): - """A SegmentCache implementation that stores segments in a dictionary""" - self._entries = defaultdict(lambda: {'change_number': -1, 'key_set': frozenset()}) - - def add_keys_to_segment(self, segment_name, segment_keys): - segment = self._entries[segment_name] - segment['key_set'] = segment['key_set'] | frozenset(segment_keys) - - def remove_keys_from_segment(self, segment_name, segment_keys): - segment = self._entries[segment_name] - segment['key_set'] = segment['key_set'] - frozenset(segment_keys) - - def is_in_segment(self, segment_name, key): - return key in self._entries[segment_name]['key_set'] - - def set_change_number(self, segment_name, change_number): - self._entries[segment_name]['change_number'] = change_number - - def get_change_number(self, segment_name): - return self._entries[segment_name]['change_number'] - - -class ImpressionsCache(object): # pragma: no cover - """The basic interface for an Impressions cache.""" - def add_impression(self, impression): - """Add an impression to a feature - :param impression: An impression - :type impression: Impression - :return: How many impressions have been added so far - :rtype: int - 
""" - pass # Do nothing - - def add_impressions(self, impressions): - """ - Adds impression to the queue if it is enabled, otherwise the impression - is dropped. - :param impressions: The impression bulk - :type impressions: list - """ - for impression in impressions: - self.add_impression(impression) - - def fetch_all(self): - """ List all impressions. - :return: A list of Impression tuples - :rtype: list - """ - return [] - - def clear(self): - """Clears all impressions.""" - pass # Do nothing - - def fetch_all_and_clear(self): - """ List all impressions and clear the cache. - :return: A list of Impression tuples - :rtype: list - """ - return [] - - -class InMemoryImpressionsCache(ImpressionsCache): # pragma: no cover - def __init__(self, impressions=None): - """An in memory implementation of an Impressions cache. - :param impressions: Initial set of impressions - :type impressions: dict - """ - self._impressions = defaultdict(list) - if impressions is not None: - self._impressions.update(impressions) - self._rlock = RLock() - - def add_impression(self, impression): - """Add an impression to a feature - :param impression: An impression - :type impression: Impression - """ - with self._rlock: - self._impressions[impression.feature].append(impression) - - def fetch_all(self): - """ List all impressions. - :return: A list of Impression tuples - :rtype: dict - """ - return deepcopy(self._impressions) - - def clear(self): - """Clears all impressions.""" - with self._rlock: - self._impressions = defaultdict(list) - - def fetch_all_and_clear(self): - """ List all impressions and clear the cache. - :return: A list of Impression tuples - :rtype: list - """ - with self._rlock: - impressions = self.fetch_all() - self.clear() - - return impressions - - -class MetricsCache(object): # pragma: no cover - """A default implementation of a Metrics cache.""" - def set_count(self, counter, value): - """Sets a counter value. - :param counter: Name of the counter - :type counter: str - :param value: Value for the counter - :type value: 1 - """ - pass # Do nothing - - def increment_count(self, counter, delta=1): - """Increments the value of a counter by a given value. - :param counter: Name of the counter - :type counter: str - :param delta: The value to be added to the counter - :type delta: int - """ - pass # Do nothing - - def get_count(self, counter): - """ - :param counter: Name of the counter - :type counter: str - :return: The current value of the counter - :rtype: int - """ - return 0 - - def set_gauge(self, gauge, value): - """Sets the value of a gauge. - :param gauge: The name of the gauge - :type gauge: str - :param value: The value of the gauge - :type value: float - """ - pass # Do nothing - - def get_gauge(self, gauge): - """ - :param gauge: The name of the gauge - :type gauge: str - :return: The current value of the gauge - :rtype: float - """ - return 0 - - def set_latency_bucket_counter(self, operation, bucket_index, value): - """Sets the value of a bucket of a latency tracker for an operation. 
- :param operation: The name of the operation - :type operation: str - :param bucket_index: The index for the latency bucket - :type bucket_index: int - :param value: The new value for the bucket - :type value: int - """ - pass # Do nothing - - def increment_latency_bucket_counter(self, operation, bucket_index, delta=1): - """Increments the value of a bucket of a latency tracker for an operation - :param operation: The name of the operation - :type operation: str - :param bucket_index: The index for the latency bucket - :type bucket_index: int - :param delta: The value to add to the bucket - :type delta: int - """ - pass # Do nothing - - def get_latency_bucket_counter(self, operation, bucket_index): - """ - :param operation: The name of the operation - :type operation: str - :param bucket_index: The index for the latency bucket - :type bucket_index: int - :return: The current value of a bucket of a latency tracker - :rtype: int - """ - return 0 - - def get_latency(self, operation): - """ - :param operation: The name of the operation - :type operation: str - :return: All the buckets of a latency tracker - :rtype: list - """ - return [0] * 23 diff --git a/splitio/bin/__init__.py b/splitio/client/__init__.py similarity index 100% rename from splitio/bin/__init__.py rename to splitio/client/__init__.py diff --git a/splitio/clients.py b/splitio/client/client.py similarity index 62% rename from splitio/clients.py rename to splitio/client/client.py index 2168a113..ecee50a3 100644 --- a/splitio/clients.py +++ b/splitio/client/client.py @@ -4,24 +4,27 @@ import logging import time -from splitio.treatments import CONTROL -from splitio.splitters import Splitter -from splitio.impressions import Impression, Label, ImpressionListenerException -from splitio.metrics import SDK_GET_TREATMENT, SDK_GET_TREATMENTS -from splitio.events import Event -from . import input_validator -from splitio.evaluator import Evaluator +from splitio.engine.evaluator import Evaluator, CONTROL +from splitio.engine.splitters import Splitter +from splitio.models.impressions import Impression, Label +from splitio.models.events import Event +from splitio.models.telemetry import get_latency_bucket_index +from splitio.client import input_validator +from splitio.client.listener import ImpressionListenerException -class Client(object): - """Client class that uses a broker for storage.""" +class Client(object): #pylint: disable=too-many-instance-attributes + """Entry point for the split sdk.""" - def __init__(self, broker, labels_enabled=True, impression_listener=None): + _METRIC_GET_TREATMENT = 'sdk.getTreatment' + _METRIC_GET_TREATMENTS = 'sdk.getTreatments' + + def __init__(self, factory, labels_enabled=True, impression_listener=None): """ Construct a Client instance. 
- :param broker: Broker that accepts/retrieves splits, segments, events, metrics & impressions - :type broker: BaseBroker + :param factory: Split factory (client & manager container) + :type factory: splitio.client.factory.SplitFactory :param labels_enabled: Whether to store labels on impressions :type labels_enabled: bool @@ -32,37 +35,49 @@ def __init__(self, broker, labels_enabled=True, impression_listener=None): :rtype: Client """ self._logger = logging.getLogger(self.__class__.__name__) - self._splitter = Splitter() - self._broker = broker + self._factory = factory self._labels_enabled = labels_enabled - self._destroyed = False self._impression_listener = impression_listener - self._evaluator = Evaluator(broker) + + self._splitter = Splitter() + self._split_storage = factory._get_storage('splits') #pylint: disable=protected-access + self._segment_storage = factory._get_storage('segments') #pylint: disable=protected-access + self._impressions_storage = factory._get_storage('impressions') #pylint: disable=protected-access + self._events_storage = factory._get_storage('events') #pylint: disable=protected-access + self._telemetry_storage = factory._get_storage('telemetry') #pylint: disable=protected-access + self._evaluator = Evaluator(self._split_storage, self._segment_storage, self._splitter) def destroy(self): """ - Disable the split-client and free all allocated resources. + Destroy the underlying factory. Only applicable when using in-memory operation mode. """ - self._destroyed = True - self._broker.destroy() + self._factory.destroy() + + @property + def destroyed(self): + """Return whether the factory holding this client has been destroyed.""" + return self._factory.destroyed def _send_impression_to_listener(self, impression, attributes): - ''' - Sends impression result to custom listener. + """ + Send impression result to custom listener. 
:param impression: Generated impression :type impression: Impression :param attributes: An optional dictionary of attributes :type attributes: dict - ''' + """ if self._impression_listener is not None: try: self._impression_listener.log_impression(impression, attributes) - except ImpressionListenerException as e: - self._logger.error(e) + except ImpressionListenerException: + self._logger.error( + 'An exception was raised while calling user-custom impression listener' + ) + self._logger.debug('Error', exc_info=True) def get_treatment(self, key, feature, attributes=None): """ @@ -80,20 +95,21 @@ def get_treatment(self, key, feature, attributes=None): :return: The treatment for the key and feature :rtype: str """ - if self._destroyed: - self._logger.error("Client has already been destroyed - no calls possible") - return CONTROL + try: + if self.destroyed: + self._logger.error("Client has already been destroyed - no calls possible") + return CONTROL - start = int(round(time.time() * 1000)) + start = int(round(time.time() * 1000)) - matching_key, bucketing_key = input_validator.validate_key(key, 'get_treatment') - feature = input_validator.validate_feature_name(feature) + matching_key, bucketing_key = input_validator.validate_key(key, 'get_treatment') + feature = input_validator.validate_feature_name(feature) - if (matching_key is None and bucketing_key is None) or feature is None or\ - input_validator.validate_attributes(attributes, 'get_treatment') is False: - return CONTROL + if (matching_key is None and bucketing_key is None) \ + or feature is None \ + or not input_validator.validate_attributes(attributes, 'get_treatment'): + return CONTROL - try: result = self._evaluator.evaluate_treatment( feature, matching_key, @@ -101,39 +117,43 @@ def get_treatment(self, key, feature, attributes=None): attributes ) - impression = self._build_impression(matching_key, - feature, - result['treatment'], - result['impression']['label'], - result['impression']['change_number'], - bucketing_key, - start) - - self._record_stats(impression, start, SDK_GET_TREATMENT) + impression = self._build_impression( + matching_key, + feature, + result['treatment'], + result['impression']['label'], + result['impression']['change_number'], + bucketing_key, + start + ) + self._record_stats(impression, start, self._METRIC_GET_TREATMENT) self._send_impression_to_listener(impression, attributes) - return result['treatment'] - except Exception: + except Exception: #pylint: disable=broad-except self._logger.error('Error getting treatment for feature') - + self._logger.debug('Error: ', exc_info=True) try: impression = self._build_impression( matching_key, feature, CONTROL, Label.EXCEPTION, - self._broker.get_change_number(), bucketing_key, start + self._split_storage.get_change_number(), + bucketing_key, + start ) - self._record_stats(impression, start, SDK_GET_TREATMENT) - + self._record_stats(impression, start, self._METRIC_GET_TREATMENT) self._send_impression_to_listener(impression, attributes) except Exception: # pylint: disable=broad-except self._logger.error('Error reporting impression into get_treatment exception block') + self._logger.debug('Error: ', exc_info=True) return CONTROL def get_treatments(self, key, features, attributes=None): """ + Evaluate multiple features and return a dictionary with all the feature/treatments. + Get the treatments for a list of features considering a key, with an optional dictionary of attributes. This method never raises an exception. 
If there's a problem, the appropriate log message will be
 generated and the method will return the CONTROL treatment.
@@ -146,7 +166,7 @@ def get_treatments(self, key, features, attributes=None):
         :return: Dictionary with the result of all the features provided
         :rtype: dict
         """
-        if self._destroyed:
+        if self.destroyed:
             self._logger.error("Client has already been destroyed - no calls possible")
             return input_validator.generate_control_treatments(features)
@@ -156,7 +176,7 @@ def get_treatments(self, key, features, attributes=None):
         if matching_key is None and bucketing_key is None:
             return input_validator.generate_control_treatments(features)

-        if input_validator.validate_attributes(attributes, 'get_treatment') is False:
+        if input_validator.validate_attributes(attributes, 'get_treatments') is False:
             return input_validator.generate_control_treatments(features)

         features = input_validator.validate_features_get_treatments(features)
@@ -186,32 +206,37 @@ def get_treatments(self, key, features, attributes=None):
                     bulk_impressions.append(impression)
                     treatments[feature] = treatment['treatment']

-                except Exception:
+                except Exception:  #pylint: disable=broad-except
                     self._logger.error('get_treatments: An exception occurred when evaluating '
                                        'feature ' + feature + ' returning CONTROL.')
                     treatments[feature] = CONTROL
+                    self._logger.debug('Error: ', exc_info=True)
                     continue

            # Register impressions
            try:
-                if len(bulk_impressions) > 0:
-                    self._record_stats(bulk_impressions, start, SDK_GET_TREATMENTS)
-
+                if bulk_impressions:
+                    self._record_stats(bulk_impressions, start, self._METRIC_GET_TREATMENTS)
                    for impression in bulk_impressions:
                        self._send_impression_to_listener(impression, attributes)
-            except Exception:
+            except Exception:  #pylint: disable=broad-except
                self._logger.error('get_treatments: An exception occurred when trying to store '
                                   'impressions.')
+                self._logger.debug('Error: ', exc_info=True)

        return treatments

-    def _build_impression(
-            self, matching_key, feature_name, treatment, label,
-            change_number, bucketing_key, imp_time
+    def _build_impression(  #pylint: disable=too-many-arguments
+            self,
+            matching_key,
+            feature_name,
+            treatment,
+            label,
+            change_number,
+            bucketing_key,
+            imp_time
     ):
-        """
-        Build an impression.
-        """
+        """Build an impression."""
         if not self._labels_enabled:
             label = None
@@ -236,13 +261,14 @@ def _record_stats(self, impressions, start, operation):
         """
         try:
             end = int(round(time.time() * 1000))
-            if operation == SDK_GET_TREATMENT:
-                self._broker.log_impressions([impressions])
+            if operation == self._METRIC_GET_TREATMENT:
+                self._impressions_storage.put([impressions])
             else:
-                self._broker.log_impressions(impressions)
-            self._broker.log_operation_time(operation, end - start)
-        except Exception:
+                self._impressions_storage.put(impressions)
+            self._telemetry_storage.inc_latency(operation, get_latency_bucket_index(end - start))
+        except Exception:  #pylint: disable=broad-except
             self._logger.error('Error recording impressions and metrics')
+            self._logger.debug('Error: ', exc_info=True)

     def track(self, key, traffic_type, event_type, value=None):
         """
@@ -250,19 +276,17 @@ def track(self, key, traffic_type, event_type, value=None):
         :param key: user key associated to the event
         :type key: str
-
         :param traffic_type: traffic type name
         :type traffic_type: str
-
         :param event_type: event type name
         :type event_type: str
-
         :param value: (Optional) value associated to the event
         :type value: Number
+
+        :return: Whether the event was created or not.
:rtype: bool """ - if self._destroyed: + if self.destroyed: self._logger.error("Client has already been destroyed - no calls possible") return False @@ -276,9 +300,9 @@ def track(self, key, traffic_type, event_type, value=None): event = Event( key=key, - trafficTypeName=traffic_type, - eventTypeId=event_type, + traffic_type_name=traffic_type, + event_type_id=event_type, value=value, timestamp=int(time.time()*1000) ) - return self._broker.get_events_log().log_event(event) + return self._events_storage.put([event]) diff --git a/splitio/client/config.py b/splitio/client/config.py new file mode 100644 index 00000000..d91b4f39 --- /dev/null +++ b/splitio/client/config.py @@ -0,0 +1,43 @@ +"""Default settings for the Split.IO SDK Python client.""" +from __future__ import absolute_import, division, print_function, unicode_literals +import os.path + +DEFAULT_CONFIG = { + 'connectionTimeout': 1500, + 'splitSdkMachineName': None, + 'splitSdkMachineIp': None, + 'featuresRefreshRate': 5, + 'segmentsRefreshRate': 60, + 'metricsRefreshRate': 60, + 'impressionsRefreshRate': 10, + 'impressionsBulkSize': 5000, + 'impressionsQueueSize': 10000, + 'eventsPushRate': 10, + 'eventsBulkSize': 5000, + 'eventsQueueSize': 10000, + 'labelsEnabled': True, + 'impressionListener': None, + 'redisHost': 'localhost', + 'redisPort': 6379, + 'redisDb': 0, + 'redisPassword': None, + 'redisSocketTimeout': None, + 'redisSocketConnectTimeout': None, + 'redisSocketKeepalive': None, + 'redisSocketKeepaliveOptions': None, + 'redisConnectionPool': None, + 'redisUnixSocketPath': None, + 'redisEncoding': 'utf-8', + 'redisEncodingErrors': 'strict', + 'redisCharset': None, + 'redisErrors': None, + 'redisDecodeResponses': False, + 'redisRetryOnTimeout': False, + 'redisSsl': False, + 'redisSslKeyfile': None, + 'redisSslCertfile': None, + 'redisSslCertReqs': None, + 'redisSslCaCerts': None, + 'redisMaxConnections': None, + 'splitFile': os.path.join(os.path.expanduser('~'), '.split') +} diff --git a/splitio/client/factory.py b/splitio/client/factory.py new file mode 100644 index 00000000..4283866c --- /dev/null +++ b/splitio/client/factory.py @@ -0,0 +1,361 @@ +"""A module for Split.io Factories.""" +from __future__ import absolute_import, division, print_function, unicode_literals + + +import logging +import threading +from enum import Enum + +import six + +from splitio.client.client import Client +from splitio.client.manager import SplitManager +from splitio.client.config import DEFAULT_CONFIG +from splitio.client import util + +#Storage +from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \ + InMemoryImpressionStorage, InMemoryEventStorage, InMemoryTelemetryStorage +from splitio.storage.adapters import redis +from splitio.storage.redis import RedisSplitStorage, RedisSegmentStorage, RedisImpressionsStorage, \ + RedisEventsStorage, RedisTelemetryStorage +from splitio.storage.adapters.uwsgi_cache import get_uwsgi +from splitio.storage.uwsgi import UWSGIEventStorage, UWSGIImpressionStorage, UWSGISegmentStorage, \ + UWSGISplitStorage, UWSGITelemetryStorage + +# APIs +from splitio.api.client import HttpClient +from splitio.api.splits import SplitsAPI +from splitio.api.segments import SegmentsAPI +from splitio.api.impressions import ImpressionsAPI +from splitio.api.events import EventsAPI +from splitio.api.telemetry import TelemetryAPI + +# Tasks +from splitio.tasks.split_sync import SplitSynchronizationTask +from splitio.tasks.segment_sync import SegmentSynchronizationTask +from splitio.tasks.impressions_sync 
import ImpressionsSyncTask
+from splitio.tasks.events_sync import EventsSyncTask
+from splitio.tasks.telemetry_sync import TelemetrySynchronizationTask
+
+# Localhost stuff
+from splitio.client.localhost import LocalhostEventsStorage, LocalhostImpressionsStorage, \
+    LocalhostSplitSynchronizationTask, LocalhostTelemetryStorage
+
+
+class Status(Enum):
+    """Factory Status."""
+
+    NOT_INITIALIZED = 'NOT_INITIALIZED'
+    READY = 'READY'
+    DESTROYED = 'DESTROYED'
+
+
+class TimeoutException(Exception):
+    """Exception to be raised upon a block_until_ready call when a timeout expires."""
+
+    pass
+
+
+class SplitFactory(object):  #pylint: disable=too-many-instance-attributes
+    """Split Factory/Container class."""
+
+    def __init__(  #pylint: disable=too-many-arguments
+        self,
+        storages,
+        labels_enabled,
+        apis=None,
+        tasks=None,
+        sdk_ready_flag=None,
+        impression_listener=None
+    ):
+        """
+        Class constructor.
+
+        :param storages: Dictionary of storages for all split models.
+        :type storages: dict
+        :param tasks: Dictionary of synchronization tasks.
+        :type tasks: dict
+        """
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._storages = storages
+        self._labels_enabled = labels_enabled
+        self._apis = apis if apis else {}
+        self._tasks = tasks if tasks else {}
+        self._status = Status.NOT_INITIALIZED
+        self._sdk_ready_flag = sdk_ready_flag
+        self._impression_listener = impression_listener
+
+        # If we have a ready flag, add a listener that updates the status
+        # to READY once the flag is set.
+        if self._sdk_ready_flag is not None:
+            ready_updater = threading.Thread(target=self._update_status_when_ready)
+            ready_updater.setDaemon(True)
+            ready_updater.start()
+
+    def _update_status_when_ready(self):
+        """Wait until the sdk is ready and update the status."""
+        self._sdk_ready_flag.wait()
+        self._status = Status.READY
+
+    def _get_storage(self, name):
+        """
+        Return a reference to the specified storage.
+
+        :param name: Name of the requested storage.
+        :type name: str
+
+        :return: requested storage.
+        :rtype: object
+        """
+        return self._storages[name]
+
+    def client(self):
+        """
+        Return a new client.
+
+        This client is only a set of references to structures held by the factory.
+        Creating one is a fast operation and safe to be used anywhere.
+        """
+        return Client(self, self._labels_enabled, self._impression_listener)
+
+    def manager(self):
+        """
+        Return a new manager.
+
+        This manager is only a set of references to structures held by the factory.
+        Creating one is a fast operation and safe to be used anywhere.
+        """
+        return SplitManager(self)
+
+    def block_until_ready(self, timeout=None):
+        """
+        Blocks until the sdk is ready or the timeout specified by the user expires.
+
+        :param timeout: Number of seconds to wait (fractions allowed)
+        :type timeout: int
+        """
+        if self._sdk_ready_flag is not None:
+            ready = self._sdk_ready_flag.wait(timeout)
+
+            if not ready:
+                raise TimeoutException('Waited %d seconds, and sdk was not ready' % timeout)
+
+    def destroy(self, destroyed_event=None):
+        """
+        Destroy the factory and render clients unusable.
+
+        Destroy frees up storage taken by split data, flushes impressions & events,
+        and invalidates the clients, making them return CONTROL.
+
+        :param destroyed_event: Event to signal when destroy process has finished.
+ :type destroyed_event: threading.Event + """ + if self.destroyed: + self._logger.info('Factory already destroyed.') + return + + if destroyed_event is not None: + stop_events = {name: threading.Event() for name in self._tasks.keys()} + for name, task in six.iteritems(self._tasks): + task.stop(stop_events[name]) + + def _wait_for_tasks_to_stop(): + for event in stop_events.values(): + event.wait() + destroyed_event.set() + + wait_thread = threading.Thread(target=_wait_for_tasks_to_stop) + wait_thread.setDaemon(True) + wait_thread.start() + else: + for task in self._tasks.values(): + task.stop() + + self._status = Status.DESTROYED + + @property + def destroyed(self): + """ + Return whether the factory has been destroyed or not. + + :return: True if the factory has been destroyed. False otherwise. + :rtype: bool + """ + return self._status == Status.DESTROYED + + +def _build_in_memory_factory(api_key, config, sdk_url=None, events_url=None): #pylint: disable=too-many-locals + """Build and return a split factory tailored to the supplied config.""" + cfg = DEFAULT_CONFIG.copy() + cfg.update(config) + http_client = HttpClient( + sdk_url=sdk_url, + events_url=events_url, + timeout=cfg.get('connectionTimeout') + ) + + sdk_metadata = util.get_metadata(config) + apis = { + 'splits': SplitsAPI(http_client, api_key), + 'segments': SegmentsAPI(http_client, api_key), + 'impressions': ImpressionsAPI(http_client, api_key, sdk_metadata), + 'events': EventsAPI(http_client, api_key, sdk_metadata), + 'telemetry': TelemetryAPI(http_client, api_key, sdk_metadata) + } + + storages = { + 'splits': InMemorySplitStorage(), + 'segments': InMemorySegmentStorage(), + 'impressions': InMemoryImpressionStorage(cfg['impressionsQueueSize']), + 'events': InMemoryEventStorage(cfg['eventsQueueSize']), + 'telemetry': InMemoryTelemetryStorage() + } + + # Synchronization flags + splits_ready_flag = threading.Event() + segments_ready_flag = threading.Event() + sdk_ready_flag = threading.Event() + + tasks = { + 'splits': SplitSynchronizationTask( + apis['splits'], + storages['splits'], + cfg['featuresRefreshRate'], + splits_ready_flag + ), + + 'segments': SegmentSynchronizationTask( + apis['segments'], + storages['segments'], + storages['splits'], + cfg['segmentsRefreshRate'], + segments_ready_flag + ), + + 'impressions': ImpressionsSyncTask( + apis['impressions'], + storages['impressions'], + cfg['impressionsRefreshRate'], + cfg['impressionsBulkSize'] + ), + + 'events': EventsSyncTask( + apis['events'], + storages['events'], + cfg['eventsPushRate'], + cfg['eventsBulkSize'], + ), + + 'telemetry': TelemetrySynchronizationTask( + apis['telemetry'], + storages['telemetry'], + cfg['metricsRefreshRate'] + ) + } + + # Start tasks that have no dependencies + tasks['impressions'].start() + tasks['events'].start() + tasks['splits'].start() + tasks['telemetry'].start() + + def split_ready_task(): + """Wait for splits to be ready and start fetching segments.""" + splits_ready_flag.wait() + tasks['segments'].start() + + def segment_ready_task(): + """Wait for segments to be ready and set the main ready flag.""" + segments_ready_flag.wait() + sdk_ready_flag.set() + + split_completion_thread = threading.Thread(target=split_ready_task) + split_completion_thread.setDaemon(True) + split_completion_thread.start() + segment_completion_thread = threading.Thread(target=segment_ready_task) + segment_completion_thread.setDaemon(True) + segment_completion_thread.start() + return SplitFactory(storages, cfg['labelsEnabled'], apis, tasks, sdk_ready_flag) + 
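+# A minimal usage sketch of the factory built above (the api key, feature name
+# and user key here are hypothetical, not part of this module):
+#
+#     factory = get_factory('some-api-key', config={'featuresRefreshRate': 10})
+#     factory.block_until_ready(5)  # wait up to 5 seconds for the first sync
+#     client = factory.client()
+#     treatment = client.get_treatment('user-123', 'some_feature')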
+
+def _build_redis_factory(config):
+    """Build and return a split factory with redis-based storage."""
+    cfg = DEFAULT_CONFIG.copy()
+    cfg.update(config)
+    sdk_metadata = util.get_metadata()
+    redis_adapter = redis.build(config)
+    storages = {
+        'splits': RedisSplitStorage(redis_adapter),
+        'segments': RedisSegmentStorage(redis_adapter),
+        'impressions': RedisImpressionsStorage(redis_adapter, sdk_metadata),
+        'events': RedisEventsStorage(redis_adapter, sdk_metadata),
+        'telemetry': RedisTelemetryStorage(redis_adapter, sdk_metadata)
+    }
+    return SplitFactory(
+        storages,
+        cfg['labelsEnabled'],
+        impression_listener=cfg['impressionListener']
+    )
+
+
+def _build_uwsgi_factory(config):
+    """Build and return a split factory with uWSGI-cache-based storage."""
+    cfg = DEFAULT_CONFIG.copy()
+    cfg.update(config)
+    uwsgi_adapter = get_uwsgi()
+    storages = {
+        'splits': UWSGISplitStorage(uwsgi_adapter),
+        'segments': UWSGISegmentStorage(uwsgi_adapter),
+        'impressions': UWSGIImpressionStorage(uwsgi_adapter),
+        'events': UWSGIEventStorage(uwsgi_adapter),
+        'telemetry': UWSGITelemetryStorage(uwsgi_adapter)
+    }
+    return SplitFactory(
+        storages,
+        cfg['labelsEnabled'],
+        impression_listener=cfg['impressionListener']
+    )
+
+
+def _build_localhost_factory(config):
+    """Build and return a localhost factory for testing/development purposes."""
+    cfg = DEFAULT_CONFIG.copy()
+    cfg.update(config)
+    storages = {
+        'splits': InMemorySplitStorage(),
+        'segments': InMemorySegmentStorage(),  # not used, just to avoid possible future errors.
+        'impressions': LocalhostImpressionsStorage(),
+        'events': LocalhostEventsStorage(),
+        'telemetry': LocalhostTelemetryStorage()
+    }
+
+    ready_event = threading.Event()
+    tasks = {'splits': LocalhostSplitSynchronizationTask(
+        cfg['splitFile'],
+        storages['splits'],
+        ready_event
+    )}
+    tasks['splits'].start()
+    return SplitFactory(storages, False, None, tasks, ready_event)
+
+
+def get_factory(api_key, **kwargs):
+    """Build and return the appropriate factory."""
+    config = kwargs.get('config', {})
+
+    if api_key == 'localhost':
+        return _build_localhost_factory(config)
+
+    if 'redisHost' in config:
+        return _build_redis_factory(config)
+
+    if 'uwsgiCache' in config:
+        return _build_uwsgi_factory(config)
+
+    return _build_in_memory_factory(
+        api_key,
+        config,
+        kwargs.get('sdk_api_base_url'),
+        kwargs.get('events_api_base_url')
+    )
diff --git a/splitio/input_validator.py b/splitio/client/input_validator.py
similarity index 93%
rename from splitio/input_validator.py
rename to splitio/client/input_validator.py
index 6f923de4..b692c011 100644
--- a/splitio/input_validator.py
+++ b/splitio/client/input_validator.py
@@ -7,10 +7,10 @@
 import re
 import math
 import requests
-from splitio.key import Key
-from splitio.treatments import CONTROL
-from splitio.api import SdkApi
-from splitio.exceptions import NetworkingException
+from splitio.client.key import Key
+from splitio.engine.evaluator import CONTROL
+# from splitio.api import SdkApi
+# from splitio.exceptions import NetworkingException

 _LOGGER = logging.getLogger(__name__)

@@ -393,25 +393,26 @@ def validate_attributes(attributes, operation):
     return True

-def _valid_apikey_type(api_key, sdk_api_base_url):
-    sdk_api = SdkApi(
-        api_key,
-        sdk_api_base_url=sdk_api_base_url,
-    )
-    _SEGMENT_CHANGES_URL_TEMPLATE = '{base_url}/segmentChanges/{segment_name}/'
-    url = _SEGMENT_CHANGES_URL_TEMPLATE.format(base_url=sdk_api_base_url,
-                                               segment_name='___TEST___')
-    params = {
-        'since': -1
-    }
-    headers = sdk_api._build_headers()
-    try:
-        response
= requests.get(url, params=params, headers=headers, timeout=sdk_api._timeout) - if response.status_code == requests.codes.forbidden: - return False - return True - except requests.exceptions.RequestException: - raise NetworkingException() +# TODO: Fix this! +# def _valid_apikey_type(api_key, sdk_api_base_url): +# sdk_api = SdkApi( +# api_key, +# sdk_api_base_url=sdk_api_base_url, +# ) +# _SEGMENT_CHANGES_URL_TEMPLATE = '{base_url}/segmentChanges/{segment_name}/' +# url = _SEGMENT_CHANGES_URL_TEMPLATE.format(base_url=sdk_api_base_url, +# segment_name='___TEST___') +# params = { +# 'since': -1 +# } +# headers = sdk_api._build_headers() +# try: +# response = requests.get(url, params=params, headers=headers, timeout=sdk_api._timeout) +# if response.status_code == requests.codes.forbidden: +# return False +# return True +# except requests.exceptions.RequestException: +# raise NetworkingException() def validate_factory_instantiation(apikey, config, sdk_api_base_url): diff --git a/splitio/client/key.py b/splitio/client/key.py new file mode 100644 index 00000000..e50d43ba --- /dev/null +++ b/splitio/client/key.py @@ -0,0 +1,22 @@ +"""A module for Split.io SDK API clients.""" +from __future__ import absolute_import, division, print_function, \ + unicode_literals + + +class Key(object): + """Key class includes a matching key and bucketing key.""" + + def __init__(self, matching_key, bucketing_key): + """Construct a key object.""" + self._matching_key = matching_key + self._bucketing_key = bucketing_key + + @property + def matching_key(self): + """Return matching key.""" + return self._matching_key + + @property + def bucketing_key(self): + """Return bucketing key.""" + return self._bucketing_key diff --git a/splitio/client/listener.py b/splitio/client/listener.py new file mode 100644 index 00000000..1a1fd57a --- /dev/null +++ b/splitio/client/listener.py @@ -0,0 +1,68 @@ +"""Impression listener module.""" + +import abc + + +class ImpressionListenerException(Exception): + """Custom Exception for Impression Listener.""" + + pass + + +class ImpressionListenerWrapper(object): #pylint: disable=too-few-public-methods + """ + Impression listener safe-execution wrapper. + + Wrapper in charge of building all the data that client would require in case + of adding some logic with the treatment and impression results. + """ + + impression_listener = None + + def __init__(self, impression_listener, sdk_metadata): + """ + Class Constructor. + + :param impression_listener: User provided impression listener. + :type impression_listener: ImpressionListener + :param sdk_metadata: SDK version, instance name & IP + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self.impression_listener = impression_listener + self._metadata = sdk_metadata + + def log_impression(self, impression, attributes=None): + """ + Send an impression to the user-provided listener. 
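+
+        The dictionary handed to the user listener is built by this method and
+        has this shape (values shown are illustrative):
+
+            {
+                'impression': impression,
+                'attributes': attributes,
+                'sdk-language-version': 'python-x.y.z',
+                'instance-id': 'ip-10-0-0-1'
+            }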
+
+        :param impression: Impression data
+        :type impression: dict
+        :param attributes: User provided attributes when calling get_treatment(s)
+        :type attributes: dict
+        """
+        data = {}
+        data['impression'] = impression
+        data['attributes'] = attributes
+        data['sdk-language-version'] = self._metadata.sdk_version
+        data['instance-id'] = self._metadata.instance_name
+        try:
+            self.impression_listener.log_impression(data)
+        except Exception:
+            raise ImpressionListenerException('Error in log_impression: user\'s '
+                                              'method is throwing exceptions')
+
+
+class ImpressionListener(object):  #pylint: disable=too-few-public-methods
+    """Impression listener interface."""
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def log_impression(self, data):
+        """
+        Accept an impression generated after an evaluation for custom user handling.
+
+        :param data: Impression data in a dictionary format.
+        :type data: dict
+        """
+        pass
diff --git a/splitio/client/localhost.py b/splitio/client/localhost.py
new file mode 100644
index 00000000..9058777e
--- /dev/null
+++ b/splitio/client/localhost.py
@@ -0,0 +1,207 @@
+"""Localhost client mocked components."""
+
+import logging
+import re
+from splitio.models.splits import from_raw
+from splitio.storage import ImpressionStorage, EventStorage, TelemetryStorage
+from splitio.tasks import BaseSynchronizationTask
+from splitio.tasks.util import asynctask
+
+_COMMENT_LINE_RE = re.compile('^#.*$')
+_DEFINITION_LINE_RE = re.compile('^(?P<feature>[\w_-]+)\s+(?P<treatment>[\w_-]+)$')
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class LocalhostImpressionsStorage(ImpressionStorage):
+    """Impression storage that doesn't cache anything."""
+
+    def put(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+    def pop_many(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+
+class LocalhostEventsStorage(EventStorage):
+    """Event storage that doesn't cache anything."""
+
+    def put(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+    def pop_many(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+
+class LocalhostTelemetryStorage(TelemetryStorage):
+    """Telemetry storage that doesn't cache anything."""
+
+    def inc_latency(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+    def inc_counter(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+    def put_gauge(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+    def pop_latencies(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+    def pop_counters(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+    def pop_gauges(self, *_, **__):  #pylint: disable=arguments-differ
+        """Accept any arguments and do nothing."""
+        pass
+
+
+class LocalhostSplitSynchronizationTask(BaseSynchronizationTask):
+    """Split synchronization task that periodically checks the file and updates the splits."""
+
+    def __init__(self, filename, storage, ready_event):
+        """
+        Class constructor.
+
+        :param filename: File to parse splits from.
+        :type filename: str
+        :param storage: Split storage
+        :type storage: splitio.storage.SplitStorage
+        :param ready_event: Event to set when sync is done.
+        :type ready_event: threading.Event
+        """
+        self._filename = filename
+        self._ready_event = ready_event
+        self._storage = storage
+        self._task = asynctask.AsyncTask(self._update_splits, 5, self._on_start)
+
+    def _on_start(self):
+        """Sync splits and set event if successful."""
+        self._update_splits()
+        self._ready_event.set()
+
+    @staticmethod
+    def _make_all_keys_based_split(split_name, treatment):
+        """
+        Make a split with a single all_keys matcher.
+
+        :param split_name: Name of the split.
+        :type split_name: str.
+        """
+        return from_raw({
+            'changeNumber': 123,
+            'trafficTypeName': 'user',
+            'name': split_name,
+            'trafficAllocation': 100,
+            'trafficAllocationSeed': 123456,
+            'seed': 321654,
+            'status': 'ACTIVE',
+            'killed': False,
+            'defaultTreatment': treatment,
+            'algo': 2,
+            'conditions': [
+                {
+                    'partitions': [
+                        {'treatment': treatment, 'size': 100}
+                    ],
+                    'conditionType': 'WHITELIST',
+                    'label': 'some_other_label',
+                    'matcherGroup': {
+                        'matchers': [
+                            {
+                                'matcherType': 'ALL_KEYS',
+                                'negate': False,
+                            }
+                        ],
+                        'combiner': 'AND'
+                    }
+                }
+            ]
+        })
+
+    @classmethod
+    def _read_splits_from_file(cls, filename):
+        """
+        Parse a splits file and return a populated storage.
+
+        :param filename: Path of the file containing mocked splits & treatments.
+        :type filename: str.
+
+        :return: Storage populated with splits ready to be evaluated.
+        :rtype: InMemorySplitStorage
+        """
+        splits = {}
+        try:
+            with open(filename, 'r') as flo:
+                for line in flo:
+                    if line.strip() == '':
+                        continue
+
+                    comment_match = _COMMENT_LINE_RE.match(line)
+                    if comment_match:
+                        continue
+
+                    definition_match = _DEFINITION_LINE_RE.match(line)
+                    if definition_match:
+                        splits[definition_match.group('feature')] = cls._make_all_keys_based_split(
+                            definition_match.group('feature'),
+                            definition_match.group('treatment')
+                        )
+                        continue
+
+                    _LOGGER.warning(
+                        'Invalid line on localhost environment split '
+                        'definition. Line = %s',
+                        line
+                    )
+            return splits
+        except IOError as e:
+            raise ValueError("Error parsing split file")
+            # TODO: check raise_from!
+#            raise_from(ValueError(
+#                'There was a problem with '
+#                'the splits definition file "{}"'.format(filename)),
+#                e
+#            )
+
+
+    def _update_splits(self):
+        """Update splits in storage."""
+        _LOGGER.info('Synchronizing splits now.')
+        splits = self._read_splits_from_file(self._filename)
+        to_delete = [name for name in self._storage.get_split_names() if name not in splits.keys()]
+        for split in splits.values():
+            self._storage.put(split)
+
+        for split in to_delete:
+            self._storage.remove(split)
+
+
+    def is_running(self):
+        """Return whether the task is running."""
+        return self._task.running
+
+    def start(self):
+        """Start split synchronization."""
+        self._task.start()
+
+    def stop(self, stop_event):
+        """
+        Stop task.
+
+        :param stop_event: Event to set when the task finishes.
+        :type stop_event: threading.Event.
+        """
+        self._task.stop(stop_event)
+
+
diff --git a/splitio/client/manager.py b/splitio/client/manager.py
new file mode 100644
index 00000000..038cef44
--- /dev/null
+++ b/splitio/client/manager.py
@@ -0,0 +1,68 @@
+"""A module for Split.io Managers."""
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import logging
+
+from . import input_validator
+
+
+class SplitManager(object):
+    """Split Manager. Gives insights on data cached by splits."""
+
+    def __init__(self, factory):
+        """
+        Class constructor.
+
+        :param factory: Factory containing all storage references.
diff --git a/splitio/client/manager.py b/splitio/client/manager.py
new file mode 100644
index 00000000..038cef44
--- /dev/null
+++ b/splitio/client/manager.py
@@ -0,0 +1,68 @@
+"""A module for Split.io Managers."""
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import logging
+
+from . import input_validator
+
+
+class SplitManager(object):
+    """Split Manager. Gives insights on data cached by splits."""
+
+    def __init__(self, factory):
+        """
+        Class constructor.
+
+        :param factory: Factory containing all storage references.
+        :type factory: splitio.client.factory.SplitFactory
+        """
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._factory = factory
+        self._storage = factory._get_storage('splits')
+
+    def split_names(self):
+        """
+        Get the name of fetched splits.
+
+        :return: A list of str
+        :rtype: list
+        """
+        if self._factory.destroyed:
+            self._logger.error("Client has already been destroyed - no calls possible.")
+            return []
+
+        return self._storage.get_split_names()
+
+    def splits(self):
+        """
+        Get the fetched splits.
+
+        :return: A List of SplitView.
+        :rtype: list
+        """
+        if self._factory.destroyed:
+            self._logger.error("Client has already been destroyed - no calls possible.")
+            return []
+
+        return [split.to_split_view() for split in self._storage.get_all_splits()]
+
+    def split(self, feature_name):
+        """
+        Get the splitView of feature_name.
+
+        :param feature_name: Name of the feature to retrieve.
+        :type feature_name: str
+
+        :return: The SplitView instance.
+        :rtype: splitio.models.splits.SplitView
+        """
+        if self._factory.destroyed:
+            self._logger.error("Client has already been destroyed - no calls possible.")
+            return None
+
+        feature_name = input_validator.validate_manager_feature_name(feature_name)
+        if feature_name is None:
+            return None
+
+        split = self._storage.get(feature_name)
+        return split.to_split_view() if split is not None else None
diff --git a/splitio/client/util.py b/splitio/client/util.py
new file mode 100644
index 00000000..d92bdb20
--- /dev/null
+++ b/splitio/client/util.py
@@ -0,0 +1,34 @@
+"""General purpose SDK utilities."""
+
+import socket
+from collections import namedtuple
+from splitio.version import __version__
+
+SdkMetadata = namedtuple(
+    'SdkMetadata',
+    ['sdk_version', 'instance_name', 'instance_ip']
+)
+
+
+def _get_ip():
+    """Return this host's outbound IP address, or 'unknown' on failure."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        # the address doesn't even have to be reachable
+        sock.connect(('10.255.255.255', 1))
+        ip_address = sock.getsockname()[0]
+    except Exception: #pylint: disable=broad-except
+        ip_address = 'unknown'
+    finally:
+        sock.close()
+    return ip_address
+
+
+def _get_hostname(ip_address):
+    """Build a synthetic hostname from the detected IP address."""
+    return 'unknown' if ip_address == 'unknown' else 'ip-' + ip_address.replace('.', '-')
+
+
+def get_metadata(*args, **kwargs):
+    """Gather SDK metadata and return a tuple with such info."""
+    version = 'python-%s' % __version__
+    ip_address = _get_ip()
+    hostname = _get_hostname(ip_address)
+    return SdkMetadata(version, hostname, ip_address)
diff --git a/splitio/config.py b/splitio/config.py
deleted file mode 100644
index 595b3737..00000000
--- a/splitio/config.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""Default settings for the Split.IO SDK Python client"""
-from __future__ import absolute_import, division, print_function, unicode_literals
-
-import importlib
-import logging
-import json
-
-from .version import __version__
-from splitio.utils import get_hostname, get_ip
-
-logger = logging.getLogger(__name__)
-
-SDK_API_BASE_URL = 'https://sdk.split.io/api'
-EVENTS_API_BASE_URL = 'https://events.split.io/api'
-
-SDK_VERSION = 'python-{package_version}'.format(package_version=__version__)
-
-DEFAULT_CONFIG = {
-    'connectionTimeout': 1500,
-    'sdkApiBaseUrl': SDK_API_BASE_URL,
-    'eventsApiBaseUrl': EVENTS_API_BASE_URL,
-    'splitSdkMachineName': None,
-    'splitSdkMachineIp': None,
-    'readTimeout': 1500,
-    'featuresRefreshRate': 5,
-    'segmentsRefreshRate': 60,
-    'metricsRefreshRate': 60,
-    'impressionsRefreshRate': 60,
-
'randomizeIntervals': False, - 'maxImpressionsLogSize': -1, - 'maxMetricsCallsBeforeFlush': 1000, - 'ready': 0, - 'redisHost': 'localhost', - 'redisPort': 6379, - 'redisDb': 0, - 'redisPassword': None, - 'redisSocketTimeout': None, - 'redisSocketConnectTimeout': None, - 'redisSocketKeepalive': None, - 'redisSocketKeepaliveOptions': None, - 'redisConnectionPool': None, - 'redisUnixSocketPath': None, - 'redisEncoding': 'utf-8', - 'redisEncodingErrors': 'strict', - 'redisCharset': None, - 'redisErrors': None, - 'redisDecodeResponses': False, - 'redisRetryOnTimeout': False, - 'redisSsl': False, - 'redisSslKeyfile': None, - 'redisSslCertfile': None, - 'redisSslCertReqs': None, - 'redisSslCaCerts': None, - 'redisMaxConnections': None, - 'eventsPushRate': 60, - 'eventsQueueSize': 500, -} - -MAX_INTERVAL = 180 - - -GLOBAL_KEY_PARAMETERS = { - 'sdk-language-version': SDK_VERSION, - 'instance-id': get_hostname(), - 'ip-address': get_ip(), -} - - -def set_machine_ip(machine_ip): - if machine_ip: - GLOBAL_KEY_PARAMETERS['ip-address'] = machine_ip - - -def set_machine_name(machine_name): - if machine_name: - GLOBAL_KEY_PARAMETERS['instance-id'] = machine_name - - -def parse_config_file(filename): - """Reads a Splitio JSON config file, like the following: - - { - "apiKey": "some-api-key", - "sdkApiBaseUrl": "https://sdk-loadtesting.split.io/api", - "eventsApiBaseUrl": "https://events-loadtesting.split.io/api", - "connectionTimeout": 1500, - "readTimeout": 1500, - "featuresRefreshRate": 5, - "segmentsRefreshRate": 60, - "metricsRefreshRate": 60, - "impressionsRefreshRate": 60, - "randomizeIntervals": False, - "maxImpressionsLogSize": -1, - "maxMetricsCallsBeforeFlush": 1000, - "ready": 0, - "redisFactory": "some.python.function", - "redisHost": "locahost", - "redisPort": 6379, - "redisDb": 0 - } - - :param filename: Name of the config file - :type filename: str - :return: A config dictionary - :rtype: dict - """ - config = DEFAULT_CONFIG.copy() - - try: - with open(filename) as fp: - json_config = json.load(fp) - config.update(json_config) - if 'splitSdkMachineName' in config: - set_machine_name(config['splitSdkMachineName']) - if 'splitSdkMachineIp' in config: - set_machine_ip(config['splitSdkMachineIp']) - except Exception: - logger.error('There was a problem reading the config file: %s', filename) - return DEFAULT_CONFIG.copy() - - return config - - -def import_from_string(val, setting_name): - try: - parts = val.split('.') - module_path, class_name = '.'.join(parts[:-1]), parts[-1] - module = importlib.import_module(module_path) - return getattr(module, class_name) - except (ImportError, AttributeError) as e: - raise ImportError( - "Could not import '%s' for SPLITIO setting '%s'. %s: %s." 
% (val, setting_name,
-               e.__class__.__name__, e))
diff --git a/splitio/tests/__init__.py b/splitio/engine/__init__.py
similarity index 100%
rename from splitio/tests/__init__.py
rename to splitio/engine/__init__.py
diff --git a/splitio/evaluator.py b/splitio/engine/evaluator.py
similarity index 66%
rename from splitio/evaluator.py
rename to splitio/engine/evaluator.py
index ef4fb6c5..881ffcac 100644
--- a/splitio/evaluator.py
+++ b/splitio/engine/evaluator.py
@@ -1,29 +1,34 @@
+"""Split evaluator module."""
 import logging
 
-from splitio.splits import ConditionType
-from splitio.impressions import Label
-from splitio.splitters import Splitter
-from splitio.key import Key
-from splitio.treatments import CONTROL
+from splitio.models.grammar.condition import ConditionType
+from splitio.models.impressions import Label
 
 
-class Evaluator(object):
+CONTROL = 'control'
 
-    def __init__(self, broker):
+
+class Evaluator(object): #pylint: disable=too-few-public-methods
+    """Split Evaluator class."""
+
+    def __init__(self, split_storage, segment_storage, splitter):
         """
         Construct a Evaluator instance.
 
-        :param broker: Broker that accepts/retrieves splits, segments, events, metrics & impressions
-        :type broker: BaseBroker
+        :param split_storage: Split storage.
+        :type split_storage: splitio.storage.SplitStorage
 
-        :rtype: Evaluator
+        :param segment_storage: Segment storage.
+        :type segment_storage: splitio.storage.SegmentStorage
         """
         self._logger = logging.getLogger(self.__class__.__name__)
-        self._splitter = Splitter()
-        self._broker = broker
+        self._split_storage = split_storage
+        self._segment_storage = segment_storage
+        self._splitter = splitter
 
-    def evaluate_treatment(self, feature, matching_key, bucketing_key, attributes=None):
+    def evaluate_treatment(self, feature, matching_key,
+                           bucketing_key, attributes=None):
         """
-        Evaluates the user submitted data against a feature and return the resulting treatment.
+        Evaluate the user submitted data against a feature and return the resulting treatment.
 
         :param feature: The feature for which to get the treatment
         :type feature: str
@@ -45,7 +50,7 @@ def evaluate_treatment(self, feature, matching_key, bucketing_key, attributes=No
         _change_number = -1
 
         # Fetching Split definition
-        split = self._broker.fetch_feature(feature)
+        split = self._split_storage.get(feature)
 
         if split is None:
             self._logger.warning('Unknown or invalid feature: %s', feature)
@@ -57,7 +62,7 @@ def evaluate_treatment(self, feature, matching_key, bucketing_key, attributes=No
             label = Label.KILLED
             _treatment = split.default_treatment
         else:
-            treatment, label = self.get_treatment_for_split(
+            treatment, label = self._get_treatment_for_split(
                 split,
                 matching_key,
                 bucketing_key,
@@ -77,10 +82,12 @@ def evaluate_treatment(self, feature, matching_key, bucketing_key, attributes=No
             }
         }
 
-    def get_treatment_for_split(self, split, matching_key, bucketing_key, attributes=None):
+    def _get_treatment_for_split(self, split, matching_key, bucketing_key, attributes=None):
         """
-        Evaluates the feature considering the conditions. If there is a match, it will return
-        the condition and the label. Otherwise, it will return None, None
+        Evaluate the feature considering the conditions.
+
+        If there is a match, it will return the condition and the label.
+        Otherwise, it will return (None, None)
 
         :param split: The split for which to get the treatment
         :type split: Split
@@ -94,13 +101,20 @@ def get_treatment_for_split(self, split, matching_key, bucketing_key, attributes
         :param attributes: An optional dictionary of attributes
         :type attributes: dict
 
-        :return: The treatment for the key and split
-        :rtype: object
+        :return: The resulting treatment and label
+        :rtype: tuple
         """
         if bucketing_key is None:
            bucketing_key = matching_key
 
         roll_out = False
+
+        context = {
+            'segment_storage': self._segment_storage,
+            'evaluator': self,
+            'bucketing_key': bucketing_key
+        }
+
         for condition in split.conditions:
             if (not roll_out and
                     condition.condition_type == ConditionType.ROLLOUT):
@@ -114,10 +128,10 @@ def get_treatment_for_split(self, split, matching_key, bucketing_key, attributes
                 return split.default_treatment, Label.NOT_IN_SPLIT
 
             roll_out = True
 
-            condition_matches = condition.matcher.match(
-                Key(matching_key, bucketing_key),
+            condition_matches = condition.matches(
+                matching_key,
                 attributes=attributes,
-                client=self
+                context=context
             )
 
             if condition_matches:
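
To make the injection change concrete: the evaluator no longer reaches
through a broker, it receives its collaborators directly. A rough usage
sketch (the in-memory storage class names are assumptions based on the new
storage modules, not verbatim from this patch):

    from splitio.engine.evaluator import Evaluator
    from splitio.engine.splitters import Splitter
    from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage

    # Storages are now injected instead of being reached through a broker.
    evaluator = Evaluator(InMemorySplitStorage(), InMemorySegmentStorage(), Splitter())
    result = evaluator.evaluate_treatment('my_feature', 'user_key', None)
    # result['treatment'] -> the treatment string (CONTROL if the split is missing)
    # result['impression'] -> {'label': ..., 'change_number': ...}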
diff --git a/splitio/hashfns/__init__.py b/splitio/engine/hashfns/__init__.py
similarity index 67%
rename from splitio/hashfns/__init__.py
rename to splitio/engine/hashfns/__init__.py
index 6575a4bd..a04f51da 100644
--- a/splitio/hashfns/__init__.py
+++ b/splitio/engine/hashfns/__init__.py
@@ -1,4 +1,6 @@
 """
+Hash functions module.
+
 This module contains hash functions implemented in pure python
 as well as the optional import (if installed) of a C compiled murmur hash
 function with python bindings.
@@ -6,8 +8,8 @@
 from __future__ import absolute_import, division, print_function, \
     unicode_literals
 
-from splitio.splits import HashAlgorithm
-from splitio.hashfns import legacy
+from splitio.models.splits import HashAlgorithm
+from splitio.engine.hashfns import legacy
 
 try:
     # First attempt to import module with C++ core (faster)
@@ -15,10 +17,10 @@
     def _murmur_hash(key, seed):
         return mmh3cffi.hash_str(key, seed)
 
-except:
+except ImportError:
     # Fallback to interpreted python hash algoritm (slower)
-    from splitio.hashfns import murmur3py
-    _murmur_hash = murmur3py.murmur32_py
+    from splitio.engine.hashfns import murmur3py #pylint: disable=ungrouped-imports
+    _murmur_hash = murmur3py.murmur32_py #pylint: disable=invalid-name
 
 
 _HASH_ALGORITHMS = {
@@ -29,8 +31,10 @@ def _murmur_hash(key, seed):
 
 def get_hash_fn(algo):
     """
-    Return appropriate hash function for requested algorithm
+    Return appropriate hash function for requested algorithm.
+
+    :param algo: Algorithm to use
+    :type algo: int
     :return: Hash function
     :rtype: function
     """
diff --git a/splitio/hashfns/legacy.py b/splitio/engine/hashfns/legacy.py
similarity index 57%
rename from splitio/hashfns/legacy.py
rename to splitio/engine/hashfns/legacy.py
index 533ed3d4..1eb4397c 100644
--- a/splitio/hashfns/legacy.py
+++ b/splitio/engine/hashfns/legacy.py
@@ -1,8 +1,10 @@
+"""Legacy hash function module."""
 from __future__ import absolute_import, division, print_function, \
     unicode_literals
 
 
 def as_int32(value):
+    """Handle overflow when working with 32 lower bits of 64 bit ints."""
     if not -2147483649 <= value <= 2147483648:
         return (value + 2147483648) % 4294967296 - 2147483648
     return value
@@ -10,7 +12,8 @@ def as_int32(value):
 
 def legacy_hash(key, seed):
     """
-    Generates a hash for a key and a feature seed.
+    Generate a hash for a key and a feature seed.
+
     :param key: The key for which to get the hash
     :type key: str
     :param seed: The feature seed
@@ -18,9 +21,9 @@
     :return: The hash for the key and seed
     :rtype: int
     """
-    h = 0
+    current_hash = 0
 
-    for c in map(ord, key):
-        h = as_int32(as_int32(31 * as_int32(h)) + c)
+    for char in map(ord, key):
+        current_hash = as_int32(as_int32(31 * as_int32(current_hash)) + char)
 
-    return int(as_int32(h ^ as_int32(seed)))
+    return int(as_int32(current_hash ^ as_int32(seed)))
diff --git a/splitio/engine/hashfns/murmur3py.py b/splitio/engine/hashfns/murmur3py.py
new file mode 100644
index 00000000..346a0ace
--- /dev/null
+++ b/splitio/engine/hashfns/murmur3py.py
@@ -0,0 +1,76 @@
+"""MurmurHash3 hash module."""
+
+from __future__ import absolute_import, division, print_function, \
+    unicode_literals
+
+from six.moves import range
+
+
+def murmur32_py(key, seed=0x0):
+    """
+    Pure python implementation of murmur32 hash.
+
+    :param key: Key to hash
+    :type key: str
+    :param seed: Seed to use when hashing
+    :type seed: int
+
+    :return: hashed value
+    :rtype: int
+
+    """
+    key = bytearray(key, 'utf-8')
+
+    def fmix(current_hash):
+        """Mix hash bytes."""
+        current_hash ^= current_hash >> 16
+        current_hash = (current_hash * 0x85ebca6b) & 0xFFFFFFFF
+        current_hash ^= current_hash >> 13
+        current_hash = (current_hash * 0xc2b2ae35) & 0xFFFFFFFF
+        current_hash ^= current_hash >> 16
+        return current_hash
+
+    length = len(key)
+    nblocks = int(length/4)
+
+    hash1 = seed & 0xFFFFFFFF
+
+    calc1 = 0xcc9e2d51
+    calc2 = 0x1b873593
+
+    # body
+    for block_start in range(0, nblocks * 4, 4):
+        # little-endian load of a 4-byte block
+        key1 = key[block_start + 3] << 24 | \
+            key[block_start + 2] << 16 | \
+            key[block_start + 1] << 8 | \
+            key[block_start + 0]
+
+        key1 = (calc1 * key1) & 0xFFFFFFFF
+        key1 = (key1 << 15 | key1 >> 17) & 0xFFFFFFFF  # inlined ROTL32
+        key1 = (calc2 * key1) & 0xFFFFFFFF
+
+        hash1 ^= key1
+        hash1 = (hash1 << 13 | hash1 >> 19) & 0xFFFFFFFF  # inlined ROTL32
+        hash1 = (hash1 * 5 + 0xe6546b64) & 0xFFFFFFFF
+
+    # tail
+    tail_index = nblocks * 4
+    key1 = 0
+    tail_size = length & 3
+
+    if tail_size >= 3:
+        key1 ^= key[tail_index + 2] << 16
+    if tail_size >= 2:
+        key1 ^= key[tail_index + 1] << 8
+    if tail_size >= 1:
+        key1 ^= key[tail_index + 0]
+
+    if tail_size > 0:
+        key1 = (key1 * calc1) & 0xFFFFFFFF
+        key1 = (key1 << 15 | key1 >> 17) & 0xFFFFFFFF  # inlined ROTL32
+        key1 = (key1 * calc2) & 0xFFFFFFFF
+        hash1 ^= key1
+
+    unsigned_val = fmix(hash1 ^ length)
+    return unsigned_val
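
A quick sanity sketch of how these hash functions feed bucketing (the key and
seed values are arbitrary; the bucket formula is the one used by the Splitter
below):

    from splitio.engine.hashfns import murmur3py, legacy

    key_hash = murmur3py.murmur32_py('some_user_key', 123456)  # unsigned 32-bit
    legacy_val = legacy.legacy_hash('some_user_key', 123456)   # signed 32-bit
    bucket = abs(key_hash) % 100 + 1                           # always in 1..100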
diff --git a/splitio/splitters.py b/splitio/engine/splitters.py
similarity index 68%
rename from splitio/splitters.py
rename to splitio/engine/splitters.py
index 2c509556..c7e585bc 100644
--- a/splitio/splitters.py
+++ b/splitio/engine/splitters.py
@@ -1,20 +1,18 @@
-"""A module for implementation of the Splitter engine"""
+"""A module for implementation of the Splitter engine."""
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-from splitio.treatments import CONTROL
-from splitio.hashfns import get_hash_fn
+from splitio.engine.evaluator import CONTROL
+from splitio.engine.hashfns import get_hash_fn
 
 
 class Splitter(object):
-    """
-    The class responsible for selecting a treatment given a key, a feature seed and condition
-    partitions.
-    """
+    """Class responsible for choosing the right partition."""
+
     def get_treatment(self, key, seed, partitions, algo):
         """
-        Returs a treatment for a key, a feature seed and condition partitions. It returns CONTROL
-        if partitions is None or empty.
+        Return the appropriate treatment or CONTROL if no partitions are found.
+
         :param key: The key for which to determine the treatment
         :type key: str
         :param seed: The feature seed
@@ -35,9 +33,11 @@ def get_treatment(self, key, seed, partitions, algo):
             partitions
         )
 
-    def get_bucket(self, key, seed, algo):
+    @staticmethod
+    def get_bucket(key, seed, algo):
         """
-        Get the bucket for a key hash
+        Get the bucket for a key hash.
+
         :param key_hash: The hash for a key
         :type key_hash: int
         :return: The bucked for a hash
@@ -47,10 +47,11 @@ def get_bucket(self, key, seed, algo):
         key_hash = hashfn(key, seed)
         return abs(key_hash) % 100 + 1
 
-    def get_treatment_for_bucket(self, bucket, partitions):
+    @staticmethod
+    def get_treatment_for_bucket(bucket, partitions):
         """
-        Gets the treatment for a given bucket and partitions. It'll return treatment for the first
-        partition that contains the bucket.
+        Get the treatment for a given bucket and partitions.
+
         :param bucket: The bucket number generated by get_bucket
         :type bucket: int
        :param partitions: The condition partitions
diff --git a/splitio/events.py b/splitio/events.py
deleted file mode 100644
index 804e58ff..00000000
--- a/splitio/events.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Event DTO and Storage classes.
-
-The dto is implemented as a namedtuple for performance matters.
-"""
-
-from __future__ import print_function
-from collections import namedtuple
-from six.moves import queue
-from six import callable
-
-
-Event = namedtuple('Event', [
-    'key',
-    'trafficTypeName',
-    'eventTypeId',
-    'value',
-    'timestamp',
-])
-
-
-def build_bulk(event_list):
-    """
-    Return a list of dictionaries with all the events.
-
-    :param event_list: list of event tuples
-    """
-    return [e._asdict() for e in event_list]
-
-
-class InMemoryEventStorage(object):
-    """
-    In memory storage for events.
-
-    Supports adding and popping events.
-    """
-
-    def __init__(self, eventsQueueSize):
-        """
-        Construct an instance.
-
-        :param eventsQueueSize: How many events to queue before forcing a submission
-        """
-        self._events = queue.Queue(maxsize=eventsQueueSize)
-        self._queue_full_hook = None
-
-    def set_queue_full_hook(self, hook):
-        """
-        Set a hook to be called when the queue is full.
-
-        :param h: Hook to be called when the queue is full
-        """
-        if callable(hook):
-            self._queue_full_hook = hook
-
-    def log_event(self, event):
-        """
-        Add an avent to storage.
-
-        :param event: Event to be added in the storage
-        """
-        try:
-            self._events.put(event, False)
-            return True
-        except queue.Full:
-            if self._queue_full_hook is not None and callable(self._queue_full_hook):
-                self._queue_full_hook()
-            return False
-
-    def pop_many(self, count):
-        """
-        Pop multiple items from the storage.
-
-        :param count: number of items to be retrieved and removed from the queue.
- """ - events = [] - while not self._events.empty() and count > 0: - events.append(self._events.get(False)) - return events diff --git a/splitio/exceptions.py b/splitio/exceptions.py index 80995361..0af817de 100644 --- a/splitio/exceptions.py +++ b/splitio/exceptions.py @@ -8,3 +8,7 @@ class TimeoutException(Exception): class NetworkingException(Exception): pass + + +class SentinelConfigurationException(Exception): + pass diff --git a/splitio/factories.py b/splitio/factories.py deleted file mode 100644 index cddc25bb..00000000 --- a/splitio/factories.py +++ /dev/null @@ -1,121 +0,0 @@ -"""A module for Split.io Factories""" -from __future__ import absolute_import, division, print_function, unicode_literals - -from splitio.clients import Client -from splitio.brokers import get_self_refreshing_broker, get_redis_broker, get_uwsgi_broker -from splitio.managers import RedisSplitManager, SelfRefreshingSplitManager, \ - LocalhostSplitManager, UWSGISplitManager -from splitio.impressions import ImpressionListenerWrapper -from . import input_validator - -import logging - - -class SplitFactory(object): - def __init__(self): - """Basic interface of a SplitFactory. Specific implementations need to override the - client and manager method. - """ - self._logger = logging.getLogger(self.__class__.__name__) - - def client(self): # pragma: no cover - """Get the split client implementation. Subclasses need to override this method. - :return: The split client implementation. - :rtype: SplitClient - """ - raise NotImplementedError() - - def manager(self): # pragma: no cover - """Get the split manager implementation. Subclasses need to override this method. - :return: The split manager implementation. - :rtype: SplitManager - """ - raise NotImplementedError() - - -class MainSplitFactory(SplitFactory): - def __init__(self, api_key, **kwargs): - super(MainSplitFactory, self).__init__() - - config = kwargs.get('config', {}) - - labels_enabled = config.get('labelsEnabled', True) - - impression_listener = ImpressionListenerWrapper(config.get('impressionListener')) if 'impressionListener' in config else None # noqa: E501,E261 - - if 'redisHost' in config or 'redisSentinels' in config: - broker = get_redis_broker(api_key, **kwargs) - self._client = Client(broker, labels_enabled, impression_listener) - self._manager = RedisSplitManager(broker) - else: - if 'uwsgiClient' in config and config['uwsgiClient']: - broker = get_uwsgi_broker(api_key, **kwargs) - self._client = Client(broker, labels_enabled, impression_listener) - self._manager = UWSGISplitManager(broker) - else: - broker = get_self_refreshing_broker(api_key, **kwargs) - self._client = Client(broker, labels_enabled, impression_listener) - self._manager = SelfRefreshingSplitManager(broker) - - def client(self): # pragma: no cover - """Get the split client implementation. Subclasses need to override this method. - :return: The split client implementation. - :rtype: SplitClient - """ - return self._client - - def manager(self): # pragma: no cover - """Get the split manager implementation. Subclasses need to override this method. - :return: The split manager implementation. 
- :rtype: SplitManager - """ - return self._manager - - -class LocalhostSplitFactory(SplitFactory): - def __init__(self, **kwargs): - super(LocalhostSplitFactory, self).__init__() - - if 'split_definition_file_name' in kwargs: - broker = get_self_refreshing_broker( - 'localhost', - split_definition_file_name=kwargs['split_definition_file_name'] - ) - else: - broker = get_self_refreshing_broker('localhost') - - self._client = Client(broker) - self._manager = LocalhostSplitManager(broker.get_split_fetcher()) - - def client(self): # pragma: no cover - """Get the split client implementation. - :return: The split client implementation. - :rtype: SplitClient - """ - return self._client - - def manager(self): # pragma: no cover - """Get the split manager implementation. - :return: The split manager implementation. - :rtype: SplitManager - """ - return self._manager - - -def get_factory(api_key, **kwargs): - """ - :param api_key: - :param kwargs: - :return: - """ - config = kwargs.get('config', {}) - sdk_api_base_url = kwargs.get('sdk_api_base_url', None) - if 'redisHost' not in config and 'redisSentinels' not in config \ - and input_validator.validate_factory_instantiation(api_key, config, sdk_api_base_url) \ - is False: - return None - - if api_key == 'localhost': - return LocalhostSplitFactory(**kwargs) - else: - return MainSplitFactory(api_key, **kwargs) diff --git a/splitio/hashfns/murmur3py.py b/splitio/hashfns/murmur3py.py deleted file mode 100644 index df028014..00000000 --- a/splitio/hashfns/murmur3py.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import absolute_import, division, print_function, \ - unicode_literals - - -import sys as _sys -if (_sys.version_info > (3, 0)): - def xrange(a, b, c): - return range(a, b, c) -del _sys - - -def murmur32_py(key, seed=0x0): - """ - Pure python implementation of murmur32 hash - """ - - key = bytearray(key, 'utf-8') - - def fmix(h): - h ^= h >> 16 - h = (h * 0x85ebca6b) & 0xFFFFFFFF - h ^= h >> 13 - h = (h * 0xc2b2ae35) & 0xFFFFFFFF - h ^= h >> 16 - return h - - length = len(key) - nblocks = int(length/4) - - h1 = seed & 0xFFFFFFFF - - c1 = 0xcc9e2d51 - c2 = 0x1b873593 - - # body - for block_start in xrange(0, nblocks * 4, 4): - # ??? big endian? 
- k1 = key[block_start + 3] << 24 | \ - key[block_start + 2] << 16 | \ - key[block_start + 1] << 8 | \ - key[block_start + 0] - - k1 = (c1 * k1) & 0xFFFFFFFF - k1 = (k1 << 15 | k1 >> 17) & 0xFFFFFFFF # inlined ROTL32 - k1 = (c2 * k1) & 0xFFFFFFFF - - h1 ^= k1 - h1 = (h1 << 13 | h1 >> 19) & 0xFFFFFFFF # inlined ROTL32 - h1 = (h1 * 5 + 0xe6546b64) & 0xFFFFFFFF - - # tail - tail_index = nblocks * 4 - k1 = 0 - tail_size = length & 3 - - if tail_size >= 3: - k1 ^= key[tail_index + 2] << 16 - if tail_size >= 2: - k1 ^= key[tail_index + 1] << 8 - if tail_size >= 1: - k1 ^= key[tail_index + 0] - - if tail_size > 0: - k1 = (k1 * c1) & 0xFFFFFFFF - k1 = (k1 << 15 | k1 >> 17) & 0xFFFFFFFF # inlined ROTL32 - k1 = (k1 * c2) & 0xFFFFFFFF - h1 ^= k1 - - unsigned_val = fmix(h1 ^ length) - return unsigned_val diff --git a/splitio/impressions.py b/splitio/impressions.py index d1502698..627171dc 100644 --- a/splitio/impressions.py +++ b/splitio/impressions.py @@ -1,515 +1,3 @@ -"""This module contains everything related to metrics""" -from __future__ import absolute_import, division, print_function, \ - unicode_literals +"""Compatibility module for impressions listener.""" -import logging -import six -import abc - -from collections import namedtuple, defaultdict -from concurrent.futures import ThreadPoolExecutor -from copy import deepcopy -from threading import RLock, Timer -from splitio.config import SDK_VERSION, GLOBAL_KEY_PARAMETERS - - -Impression = namedtuple( - 'Impression', - [ - 'matching_key', - 'feature_name', - 'treatment', - 'label', - 'change_number', - 'bucketing_key', - 'time' - ] -) - - -def build_impressions_data(impressions): - """ - Builds a list of dictionaries that can be used with the test_impressions - API endpoint from a dictionary of lists of impressions grouped by feature - name. - :param impressions: Dict of impression tuples - :type impressions: dict - :return: List of dictionaries with impressions data for each feature - :rtype: list - """ - return [ - { - 'testName': feature_name, - 'keyImpressions': [ - { - 'keyName': impression.matching_key, - 'treatment': impression.treatment, - 'time': impression.time, - 'changeNumber': impression.change_number, - 'label': impression.label, - 'bucketingKey': impression.bucketing_key - } - for impression in feature_impressions - ] - } - for feature_name, feature_impressions in six.iteritems(impressions) - if len(feature_impressions) > 0 - ] - - -class Label(object): - # Condition: Split Was Killed - # Treatment: Default treatment - # Label: killed - KILLED = 'killed' - - # Condition: No condition matched - # Treatment: Default Treatment - # Label: no condition matched - NO_CONDITION_MATCHED = 'default rule' - - # Condition: Split definition was not found - # Treatment: control - # Label: split not found - SPLIT_NOT_FOUND = 'definition not found' - - # Condition: Traffic allocation failed - # Treatment: Default Treatment - # Label: not in split - NOT_IN_SPLIT = 'not in split' - - # Condition: There was an exception - # Treatment: control - # Label: exception - EXCEPTION = 'exception' - - -class TreatmentLog(object): - def __init__(self): - self._logger = logging.getLogger(self.__class__.__name__) - self._destroyed = False - - def _log(self, impression): - """Log an impression. Implementing classes need to override this method. - :param impression: The impression class representation - :type impression: Impression - """ - pass # Do nothing - - def log(self, impression): - """Log an impression. 
- :param impression: The impression - :type impression: Impression - """ - if isinstance(impression, Impression): - if impression.feature_name is not None \ - and impression.matching_key is not None \ - and impression.treatment is not None \ - and impression.time > 0: - self._log(impression) - return - - def _log_impressions(self, impressions): - """Log a bulk of impressions. Implementing classes need to override this method. - :param impressions: The impressions bulk - :type impressions: list - """ - pass # Do nothing - - def log_impressions(self, impressions): - """Log impressions. - :param impressions: The impressions bulk - :type list: Impressions - """ - self._log_impressions(impressions) - return - - def destroy(self): - """ - Prevent future thread scheduling. - """ - self._destroyed = True - - -class LoggerBasedTreatmentLog(TreatmentLog): - def _log(self, impression): - """Log an impression. - :param impression: The impression class representation - :type impression: Impression - """ - if isinstance(impression, Impression): - self._logger.info( - 'feature_name = %s, matching_key = %s, treatment = %s, ' - 'time = %s, label = %s, change_number = %s, bucketing_key = %s', - impression.feature_name, - impression.matching_key, - impression.treatment, - impression.time, - impression.label, - impression.change_number, - impression.bucketing_key - ) - - def _log_impressions(self, impressions): - """Log a bulk of impressions. - :param impressions: The impressions bulk - :type impressions: list - """ - for impression in impressions: - self._log(impression) - - -class InMemoryTreatmentLog(TreatmentLog): - def __init__(self, max_count=-1, ignore_impressions=False): - """ - A thread safe impressions log implementation that stores the impressions - in memory. Access to the impressions storage is synchronized with a - re-entrant lock. - :param max_count: Max number of impressions per feature before eviction - :type max_count: int - :param ignore_impressions: Whether to ignore log requests - :type ignore_impressions: bool - """ - super(InMemoryTreatmentLog, self).__init__() - self._max_count = max_count - self._ignore_impressions = ignore_impressions - self._impressions = defaultdict(list) - self._rlock = RLock() - - @property - def ignore_impressions(self): - """ - :return: Whether to ignore log requests - :rtype: bool - """ - return self._ignore_impressions - - @ignore_impressions.setter - def ignore_impressions(self, ignore_impressions): - """Set ignore_impressions property - :param ignore_impressions: Whether to ignore log requests - :type ignore_impressions: bool - """ - self._ignore_impressions = ignore_impressions - - @property - def max_count(self): - """ - :return: Max number of stored impressions allowed - :rtype: int - """ - return self._max_count - - @max_count.setter - def max_count(self, max_count): - """Sets the max number of stored impressions allowed - :param max_count: Max number of stored impressions allowed - :type max_count: int - """ - self._max_count = max_count - - def fetch_all_and_clear(self): - """Fetch all logged impressions and clear the log. - :return: The logged impressions - :rtype: dict - """ - with self._rlock: - existing_impressions = deepcopy(self._impressions) - self._impressions = defaultdict(list) - - return existing_impressions - - def _notify_eviction(self, feature_name, feature_impressions): - """ - Notifies that the max count was reached for a feature. 
- This gives the opportunity to - subclasses to do something about the eviction - :param feature_name: The name of the feature - :type feature_name: str - :param feature_impressions: The evicted impressions - :type feature_impressions: list - """ - pass # Do nothing - - def _log(self, impression): - """Log an impression. - :param impression: The impression class representation - :type impression: Impression - """ - if isinstance(impression, Impression): - with self._rlock: - feature_impressions = self._impressions[impression.feature_name] - - if self._max_count < 0 or len(feature_impressions) < self._max_count: - feature_impressions.append(impression) - else: - self._logger.warning( - 'Count limit for feature treatment log. ' - 'Clearing impressions for feature.' - ) - self._impressions[impression.feature_name] = [impression] - self._notify_eviction( - impression.feature_name, - feature_impressions - ) - - def _log_impressions(self, impressions): - """Log a bulk of impressions. - :param impressions: The impressions bulk - :type impressions: list - """ - for impression in impressions: - self._log(impression) - - -class CacheBasedTreatmentLog(TreatmentLog): - def __init__(self, impressions_cache): - """A cache based impressions log implementation. - :param impressions_cache: An impressions cache - :type impressions_cache: ImpressionsCache - """ - super(CacheBasedTreatmentLog, self).__init__() - self._impressions_cache = impressions_cache - - def _log(self, impression): - """Log an impression. - :param impression: The impression class representation - :type impression: Impression - """ - self._impressions_cache.add_impressions([impression]) - - def _log_impressions(self, impressions): - """Log a bulk of impressions. - :param impressions: The impressions bulk - :type impressions: list - """ - self._impressions_cache.add_impressions(impressions) - - -class SelfUpdatingTreatmentLog(InMemoryTreatmentLog): - def __init__(self, api, interval=180, max_workers=5, max_count=-1, - ignore_impressions=False): - """ - An impressions implementation that sends the in impressions stored - periodically to the Split.io back-end. - :param api: The SDK api client - :type api: SdkApi - :param interval: Optional update interval (Default: 180s) - :type interval: int - :param max_workers: The max number of workers used to update impressions - :type max_workers: int - :param max_count: Max number of impressions per feature before eviction - :type max_count: int - :param ignore_impressions: Whether to ignore log requests - :type ignore_impressions: bool - """ - super(SelfUpdatingTreatmentLog, self).__init__( - max_count=max_count, - ignore_impressions=ignore_impressions - ) - self._api = api - self._interval = interval - self._stopped = True - self._thread_pool_executor = ThreadPoolExecutor(max_workers=max_workers) - - @property - def stopped(self): - """ - :return: Whether the update process has been stopped - :rtype: bool - """ - return self._stopped - - @stopped.setter - def stopped(self, stopped): - """ - :param stopped: Whether to stop the update process - :type stopped: bool - """ - self._stopped = stopped - - def start(self): - """Starts the update process""" - if not self._stopped: - return - - self._stopped = False - self._timer_refresh() - - def _update_evictions(self, feature_name, feature_impressions): - """ - Sends evicted impressions to the Split.io back-end. 
- :param feature_name: The name of the feature - :type feature_name: str - :param feature_impressions: The evicted impressions - :type feature_impressions: list - """ - try: - test_impressions_data = build_impressions_data( - {feature_name: feature_impressions} - ) - - if len(test_impressions_data) > 0: - self._api.test_impressions(test_impressions_data) - except Exception: - self._logger.error('Error updating evicted impressions') - self._stopped = True - - def _update_impressions(self): - """ - Sends the impressions stored back to the Split.io back-end - """ - try: - impressions_by_feature = self.fetch_all_and_clear() - test_impressions_data = build_impressions_data( - impressions_by_feature - ) - - if len(test_impressions_data) > 0: - self._api.test_impressions(test_impressions_data) - except Exception: - self._logger.error('Error updating impressions') - self._stopped = True - - def _notify_eviction(self, feature_name, feature_impressions): - """ - Notifies that the max count was reached for a feature. The evicted - impressions are going to be sent to the back-end. - :param feature_name: The name of the feature - :type feature_name: str - :param feature_impressions: The evicted impressions - :type feature_impressions: list - """ - if self._destroyed \ - or feature_name is None \ - or feature_impressions is None or len(feature_impressions) == 0: - return - - try: - self._thread_pool_executor.submit( - self._update_evictions, feature_name, feature_impressions - ) - except Exception: - self._logger.error('Error starting evicted impressions update thread') - - def _timer_refresh(self): - """ - Responsible for setting the periodic calls to _update_impressions using - a Timer thread. - """ - if self._destroyed: - return - - try: - self._thread_pool_executor.submit(self._update_impressions) - except Exception: - self._logger.error('Error starting impressions update thread') - - try: - if hasattr(self._interval, '__call__'): - interval = self._interval() - else: - interval = self._interval - - timer = Timer(interval, self._timer_refresh) - timer.daemon = True - timer.start() - except Exception: - self._logger.error('Error refreshing timer') - self._stopped = True - - -class AsyncTreatmentLog(TreatmentLog): - def __init__(self, delegate, max_workers=5): - """ - A treatment log implementation that uses threads to execute the - actual logging onto a delegate log to avoid blocking the caller. - :param delegate: The delegate impression log - :type delegate: ImpressionLog - :param max_workers: How many workers to use for logging - """ - super(AsyncTreatmentLog, self).__init__() - self._delegate = delegate - self._thread_pool_executor = ThreadPoolExecutor(max_workers=max_workers) - self._destroyed = False - - @property - def delegate(self): - return self._delegate - - def destroy(self): - """ - Prevent future "log" calls from scheduling threads. - If delegate has a custom destroy() method, call it. - """ - self._destroyed = True - delegate_destroy = getattr(self._delegate, 'destroy', None) - if six.callable(delegate_destroy): - self._delegate.destroy() - - def log(self, impression): - """Logs an impression asynchronously. 
- :param impression: The impression - :type impression: Impression - :return: int - """ - if self._destroyed: - return - - if isinstance(impression, Impression): - try: - self._thread_pool_executor.submit(self._delegate.log, impression) - except Exception: - self._logger.error('Error logging impression asynchronously') - - def log_impressions(self, impressions): - """Log a bulk of impressions. - :param impressions: The impressions bulk - :type impressions: list - """ - for impression in impressions: - self.log(impression) - - -class ImpressionListenerException(Exception): - ''' - Custom Exception for Impression Listener - ''' - pass - - -class ImpressionListenerWrapper(object): - """ - Wrapper in charge of building all the data that client would require in case - of adding some logic with the treatment and impression results. - """ - - impression_listener = None - - def __init__(self, impression_listener): - self.impression_listener = impression_listener - - def log_impression(self, impression, attributes=None): - data = {} - data['impression'] = impression - data['attributes'] = attributes - data['sdk-language-version'] = SDK_VERSION - data['instance-id'] = GLOBAL_KEY_PARAMETERS['instance-id'] - try: - self.impression_listener.log_impression(data) - except Exception: - raise ImpressionListenerException('Error in log_impression user\'s' - 'method is throwing exceptions') - - -class ImpressionListener(object): - """ - Abstract class defining the interface that concrete client must implement, - and including methods that use that interface to add client's logic for each - impression. - """ - __metaclass__ = abc.ABCMeta - - @abc.abstractmethod - def log_impression(self, data): - pass +from splitio.client.listener import ImpressionListener diff --git a/splitio/key.py b/splitio/key.py index e50d43ba..ef0039e0 100644 --- a/splitio/key.py +++ b/splitio/key.py @@ -1,22 +1,3 @@ -"""A module for Split.io SDK API clients.""" -from __future__ import absolute_import, division, print_function, \ - unicode_literals +"""Compatibility module for key.""" - -class Key(object): - """Key class includes a matching key and bucketing key.""" - - def __init__(self, matching_key, bucketing_key): - """Construct a key object.""" - self._matching_key = matching_key - self._bucketing_key = bucketing_key - - @property - def matching_key(self): - """Return matching key.""" - return self._matching_key - - @property - def bucketing_key(self): - """Return bucketing key.""" - return self._bucketing_key +from splitio.client.key import Key diff --git a/splitio/managers.py b/splitio/managers.py deleted file mode 100644 index 77c6d76d..00000000 --- a/splitio/managers.py +++ /dev/null @@ -1,374 +0,0 @@ -"""A module for Split.io Managers""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -from splitio.redis_support import RedisSplitCache -from splitio.splits import SplitView -from splitio.utils import bytes_to_string -from . import input_validator - - -class SplitManager(object): - def __init__(self): - """Basic interface of a SplitManager. Specific implementations need to override the - splits, split and split_names method. - """ - self._logger = logging.getLogger(self.__class__.__name__) - - def split_names(self): - """Get the name of fetched splits. - :return: A list of str - :rtype: list - """ - raise NotImplementedError() - - def splits(self): # pragma: no cover - """Get the fetched splits. Subclasses need to override this method. - :return: A List of SplitView. 
- :rtype: list - """ - raise NotImplementedError() - - def split(self, feature_name): # pragma: no cover - """Get the splitView of feature_name. Subclasses need to override this method. - :return: The SplitView instance. - :rtype: SplitView - """ - raise NotImplementedError() - - -class RedisSplitManager(SplitManager): - def __init__(self, redis_broker): - """A SplitManager implementation that uses Redis as its backend. - :param redis: A redis client - :type redis: StrctRedis""" - super(RedisSplitManager, self).__init__() - - split_fetcher = redis_broker.get_split_fetcher() - split_cache = split_fetcher._split_cache - - self._split_cache = split_cache - self._split_fetcher = split_fetcher - - def split_names(self): - """Get the name of fetched splits. - :return: A list of str - :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - splits = self._split_cache.get_splits_keys() - split_names = [] - for split_name in splits: - split_name = bytes_to_string(split_name) - split_names.append(split_name.replace - (RedisSplitCache._KEY_TEMPLATE.format(suffix=''), '')) - - return split_names - - def splits(self): # pragma: no cover - """Get the fetched splits. Subclasses need to override this method. - :return: A List of SplitView. - :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - splits = self._split_fetcher.fetch_all() - change_number = self._split_cache.get_change_number() - - split_views = [] - - for split in splits: - treatments = [] - if hasattr(split, 'conditions'): - for condition in split.conditions: - for partition in condition.partitions: - treatments.append(partition.treatment) - split_views.append(SplitView(name=split.name, traffic_type=split.traffic_type_name, - killed=split.killed, treatments=list(set(treatments)), - change_number=change_number)) - - return split_views - - def split(self, feature_name): # pragma: no cover - """Get the splitView of feature_name. Subclasses need to override this method. - :return: The SplitView instance. - :rtype: SplitView - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return None - - feature_name = input_validator.validate_manager_feature_name(feature_name) - - if feature_name is None: - return None - - split = self._split_fetcher.fetch(feature_name) - - if split is None: - return None - - change_number = self._split_cache.get_change_number() - - treatments = [] - - for condition in split.conditions: - for partition in condition.partitions: - treatments.append(partition.treatment) - - # Using sets to avoid duplicate entries - split_view = SplitView(name=split.name, traffic_type=split.traffic_type_name, - killed=split.killed, treatments=list(set(treatments)), - change_number=change_number) - return split_view - - -class UWSGISplitManager(SplitManager): - def __init__(self, broker): - """A SplitManager implementation that uses uWSGI as its backend. - :param uwsgi: A uwsgi module - :type uwsgi: module""" - super(UWSGISplitManager, self).__init__() - - split_fetcher = broker.get_split_fetcher() - split_cache = split_fetcher._split_cache - - self._split_cache = split_cache - self._split_fetcher = split_fetcher - - def split_names(self): - """Get the name of fetched splits. 
- :return: A list of str - :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - splits = self._split_cache.get_splits_keys() - split_names = [] - for split_name in splits: - split_name = bytes_to_string(split_name) - split_names.append(split_name) - - return split_names - - def splits(self): # pragma: no cover - """Get the fetched splits. Subclasses need to override this method. - :return: A List of SplitView. - :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - splits = self._split_fetcher.fetch_all() - - split_views = [] - - for split in splits: - treatments = [] - if hasattr(split, 'conditions'): - for condition in split.conditions: - for partition in condition.partitions: - treatments.append(partition.treatment) - split_views.append(SplitView(name=split.name, traffic_type=split.traffic_type_name, - killed=split.killed, treatments=list(set(treatments)), - change_number=split.change_number)) - - return split_views - - def split(self, feature_name): # pragma: no cover - """Get the splitView of feature_name. Subclasses need to override this method. - :return: The SplitView instance. - :rtype: SplitView - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return None - - feature_name = input_validator.validate_manager_feature_name(feature_name) - - if feature_name is None: - return None - - split = self._split_fetcher.fetch(feature_name) - - if split is None: - return None - - treatments = [] - - for condition in split.conditions: - for partition in condition.partitions: - treatments.append(partition.treatment) - - # Using sets on treatments to avoid duplicate entries - split_view = SplitView(name=split.name, traffic_type=split.traffic_type_name, - killed=split.killed, treatments=list(set(treatments)), - change_number=split.change_number) - return split_view - - -class SelfRefreshingSplitManager(SplitManager): - - def __init__(self, broker): - """A SplitManager implementation that uses in-memory as its backend. - :param redis: A SplitFetcher instance - :type redis: SelfRefreshingSplitFetcher""" - super(SelfRefreshingSplitManager, self).__init__() - - self._split_fetcher = broker.get_split_fetcher() - - def split_names(self): - """Get the name of fetched splits. - :return: A list of str - :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - splits = self._split_fetcher.fetch_all() - split_names = [] - for split in splits: - split_names.append(split.name) - - return split_names - - def splits(self): # pragma: no cover - """Get the fetched splits. - :return: A List of SplitView. 
- :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - change_number = self._split_fetcher.change_number - splits = self._split_fetcher.fetch_all() - - split_views = [] - - for split in splits: - treatments = [] - for condition in split.conditions: - for partition in condition.partitions: - treatments.append(partition.treatment) - split_views.append(SplitView(name=split.name, traffic_type=split.traffic_type_name, - killed=split.killed, treatments=list(set(treatments)), - change_number=change_number)) - - return split_views - - def split(self, feature_name): - """Get the splitView of feature_name. Subclasses need to override this method. - :return: The SplitView instance. - :rtype: SplitView - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return None - - feature_name = input_validator.validate_manager_feature_name(feature_name) - - if feature_name is None: - return None - - split = self._split_fetcher.fetch(feature_name) - - if split is None: - return None - - change_number = self._split_fetcher.change_number - - treatments = [] - - for condition in split.conditions: - for partition in condition.partitions: - treatments.append(partition.treatment) - - # Using sets to avoid duplicate entries - split_view = SplitView(name=split.name, traffic_type=split.traffic_type_name, - killed=split.killed, treatments=list(set(treatments)), - change_number=change_number) - return split_view - - -class LocalhostSplitManager(SplitManager): - def __init__(self, split_fetcher): - """ - Basic interface of a SplitManager. Specific implementations need to - override the splits, split and split_names method. - """ - super(LocalhostSplitManager, self).__init__() - self._split_fetcher = split_fetcher - - def split_names(self): - """ - Get the name of fetched splits. - :return: A list of str - :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - splits = self._split_fetcher.fetch_all() - split_names = [] - for split in splits: - split_names.append(split.name) - - return split_names - - def splits(self): # pragma: no cover - """Get the fetched splits. - :return: A List of SplitView. - :rtype: list - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return [] - - change_number = -1 - splits = self._split_fetcher.fetch_all() - - split_views = [] - - for split in splits: - treatments = [split.default_treatment] - split_views.append(SplitView(name=split.name, traffic_type=None, killed=split.killed, - treatments=list(set(treatments)), change_number=change_number)) - - return split_views - - def split(self, feature_name): - """ - Get the splitView of feature_name. Subclasses need to override this - method. - :return: The SplitView instance. 
- :rtype: SplitView - """ - if self._split_fetcher._destroyed: - self._logger.error("Client has already been destroyed - no calls possible.") - return None - - feature_name = input_validator.validate_manager_feature_name(feature_name) - - split = self._split_fetcher.fetch(feature_name) - if split is None: - return None - - change_number = -1 - treatments = [split.default_treatment] - - # Using sets to avoid duplicate entries - split_view = SplitView(name=split.name, traffic_type=None, killed=split.killed, - treatments=list(set(treatments)), change_number=change_number) - return split_view diff --git a/splitio/matchers.py b/splitio/matchers.py deleted file mode 100644 index b6392e9d..00000000 --- a/splitio/matchers.py +++ /dev/null @@ -1,777 +0,0 @@ -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -import json -import re -from enum import Enum -from sys import modules - -from future.utils import python_2_unicode_compatible -from six import string_types -from splitio.transformers import AsDateHourMinuteTimestampTransformMixin, \ - AsNumberTransformMixin, AsDateTimestampTransformMixin, TransformMixin -from splitio.key import Key - -DataType = Enum('DataType', 'DATETIME NUMBER') - - -class AndCombiner(object): - """Combines the calls to all delegates match() method with a conjunction""" - def combine(self, matchers, key, attributes, client=None): - """ - Combines the calls to the delegates match() methods to produce a single - boolean response - :param matchers: List of delegate matchers - :type matchers: list - :param key: Key to match - :type key: str - :param attributes: Attributes to match - :type attributes: dict - :return: Conjunction of all matchers match() results - :rtype: bool - """ - if not matchers: - return False - - return all( - matcher.match(key, attributes, client) for matcher in matchers - ) - - @python_2_unicode_compatible - def __str__(self): - return 'and' - - -class CombiningMatcher(object): - def __init__(self, combiner, delegates): - """ - Combines the results of multiple delegate matchers using a specific - combiner to produce a single boolean result - :param combiner: The combiner to use to generate a single result from - the individual ones - :type combiner: AndCombiner - :param delegates: Delegate matchers - :type delegates: list - """ - self._combiner = combiner - self._delegates = tuple(delegates) - - def match(self, key, attributes=None, client=None): - """ - Tests whether there is a match for the given key and attributes - :param key: Key to match - :type key: str - :param attributes: Attributes to match - :type attributes: dict - :return: Whether there is a match for the given key and attributes - :rtype: bool - """ - return self._combiner.combine(self._delegates, key, attributes, client) - - @python_2_unicode_compatible - def __str__(self): - return 'if {delegates}'.format( - delegates=' '.join( - '{combiner} {matcher}'.format(combiner=self._combiner, matcher=matcher) - for matcher in self._delegates)) - - -class AllKeysMatcher(object): - """A matcher that always returns True""" - def match(self, key, attributes=None, client=None): - """ - Returns True except when the key is None - :param key: The key to match - :type key: str - :return: True except when the key is None - :rtype: bool - """ - return key is not None - - @python_2_unicode_compatible - def __str__(self): - return 'in segment all' - - -class NegatableMatcher(object): - def __init__(self, negate, delegate): - """ - A matcher that negates the result of a 
delegate matcher based on the - negate flag - :param negate: Whether to negate the result of the delegate matcher - :type negate: bool - :param delegate: The delegate matcher - :type delegate: Matcher - """ - self._negate = negate - self._delegate = delegate - - def match(self, key, attributes=None, client=None): - """ - Check of a match for the given key - :param key: The key to match - :type key: str - :return: True if there is a match, False otherwise - :rtype: bool - """ - result = self._delegate.match(key, attributes, client) - return result if not self._negate else not result - - @property - def negate(self): - return self._negate - - @property - def delegate(self): - return self._delegate - - @python_2_unicode_compatible - def __str__(self): - return '{negate}{delegate}'.format( - negate='not ' if self._negate else '', - delegate=self._delegate - ) - - -class AttributeMatcher(object): - def __init__(self, attribute, matcher, negate): - """ - A matcher that looks for the value of a specific attribute and passes it - to the delegate matcher to provide a result. - :param attribute: Name of the attribute - :type attribute: str - :param matcher: The delegate matcher - :type matcher: Matcher - :param negate: Whether to negate the result - :type negate: bool - """ - self._attribute = attribute - self._matcher = NegatableMatcher(negate, matcher) - - def match(self, key, attributes=None, client=None): - """ - Matches against the value of an attribute associated with the provided - key - :param key: The key to match - :type key: str - :param attributes: Dictionary of attributes to match - :type attributes: dict - :return: If negate is False, it returns the result of calling the - delegate match method on the attribute value associated with - the key. If negate is True, it returns the opposite. - :rtype: bool - """ - if self._attribute is None: - return self._matcher.match(key, attributes, client) - - if attributes is None or \ - self._attribute not in attributes or \ - attributes[self._attribute] is None: - return False - - return self._matcher.match(attributes[self._attribute]) - - @python_2_unicode_compatible - def __str__(self): - return 'key{attribute} is {matcher}'.format( - attribute='.{}'.format(self._attribute) if self._attribute is not None else '', - matcher=self._matcher) - - -class ForDataTypeMixin(object): - """ - A mixin to provide a class method called for_data_type to build the - appropriate matcher for the given data type. 
The class needs to define a - dictionary member named MATCHER_FOR_DATA_TYPE that matches constructors with - data types like so: - - MATCHER_FOR_DATA_TYPE = { - DataType.DATETIME: 'DateTimeBetweenMatcher', - DataType.NUMBER: 'NumberBetweenMatcher' - } - - Then, you can use the following syntax to build the appropriate constructor: - - matcher = BetweenMatcher.for_data_type(DataType.NUMBER, 5, 10) - """ - @staticmethod - def get_class(class_name): - return getattr(modules[__name__], class_name) - - @classmethod - def for_data_type(cls, data_type, *args, **kwargs): - """ - Build a matcher appropriate for the supplied data type - :param data_type: The data type for which to build a matcher - :type data_type: DataType - :param args: arguments to be passed to the actual matcher contructor - :type args: iterable - :param kwargs: keyword arguments to be passed to the actual matcher - contructor - :type kwargs: dict - :return: A matcher appropriate for the given data type - :rtype: Matcher - """ - if data_type is None: - raise ValueError('Invalid data type') - - return cls.get_class(cls.MATCHER_FOR_DATA_TYPE[data_type])(*args, **kwargs) - - -def get_matching_key(key): - """ - """ - from splitio.key import Key - if isinstance(key, Key): - return key.matching_key - else: - return key - - -class BetweenMatcher(TransformMixin, ForDataTypeMixin): - MATCHER_FOR_DATA_TYPE = { - DataType.DATETIME: 'DateTimeBetweenMatcher', - DataType.NUMBER: 'NumberBetweenMatcher' - } - - def __init__(self, start, end, data_type): - """ - A matcher that checks if a (transformed) value is between two other - values. - :param start: The start of the interval - :type start: any - :param end: The end of the interval - :type end: any - :param data_type: The data type for the values - :type data_type: DataType - """ - self._data_type = data_type - self._original_start = start - self._original_end = end - self._start = self.transform_condition_parameter(start) - self._end = self.transform_condition_parameter(end) - - @property - def start(self): - return self._start - - @property - def end(self): - return self._end - - def match(self, key, attributes=None, client=None): - """ - Returns True if the key (after being transformed by the transform_key() - method) is between start and end - :param key: The key to match - :type key: any - :return: Whether the transformed key is between start and end - :rtype: bool - """ - key = get_matching_key(key) - transformed_key = self.transform_key(key) - - if transformed_key is None: - return None - - return self._start <= transformed_key <= self._end - - @python_2_unicode_compatible - def __str__(self): - return 'between {start} and {end}'.format( - start=self._start, end=self._end - ) - - -class DateTimeBetweenMatcher(BetweenMatcher, - AsDateHourMinuteTimestampTransformMixin): - def __init__(self, start, end): - super(DateTimeBetweenMatcher, self).__init__( - start, end, DataType.DATETIME - ) - - -class NumberBetweenMatcher(BetweenMatcher, AsNumberTransformMixin): - def __init__(self, start, end): - super(NumberBetweenMatcher, self).__init__(start, end, DataType.NUMBER) - - -class CompareMixin(object): - def compare(self, key, value): - raise NotImplementedError() - - -class EqualToCompareMixin(CompareMixin): - def compare(self, key, value): - return key == value - - -class GreaterOrEqualToCompareMixin(CompareMixin): - def compare(self, key, value): - return key >= value - - -class LessThanOrEqualToCompareMixin(CompareMixin): - def compare(self, key, value): - return key <= value - - -class 
CompareMatcher(TransformMixin, CompareMixin): - def __init__(self, compare_to, data_type): - """ - A matcher that compares a (transformed) key to a specific value - :param compare_to: The value to match - :type compare_to: any - :param data_type: The data type to use for comparison - :type data_type: DataType - """ - self._data_type = data_type - self._original_compare_to = compare_to - self._compare_to = self.transform_condition_parameter(compare_to) - - def match(self, key, attributes=None, client=None): - """ - Compares the supplied key with the matcher's value using the compare() - method - :param key: The key to match - :type key: str - :return: The resulf of calling compare() with the key and the value - :rtype: bool - """ - key = get_matching_key(key) - transformed_key = self.transform_key(key) - - if transformed_key is None: - return None - - return self.compare(transformed_key, self._compare_to) - - -class EqualToMatcher(CompareMatcher, EqualToCompareMixin, ForDataTypeMixin): - MATCHER_FOR_DATA_TYPE = { - DataType.DATETIME: 'DateEqualToMatcher', - DataType.NUMBER: 'NumberEqualToMatcher' - } - - @python_2_unicode_compatible - def __str__(self): - return '== {compare_to}'.format(compare_to=self._compare_to) - - -class GreaterThanOrEqualToMatcher(CompareMatcher, GreaterOrEqualToCompareMixin, - ForDataTypeMixin): - MATCHER_FOR_DATA_TYPE = { - DataType.DATETIME: 'DateTimeGreaterThanOrEqualToMatcher', - DataType.NUMBER: 'NumberGreaterThanOrEqualToMatcher' - } - - @python_2_unicode_compatible - def __str__(self): - return '>= {compare_to}'.format(compare_to=self._compare_to) - - -class LessThanOrEqualToMatcher(CompareMatcher, LessThanOrEqualToCompareMixin, - ForDataTypeMixin): - MATCHER_FOR_DATA_TYPE = { - DataType.DATETIME: 'DateTimeLessThanOrEqualToMatcher', - DataType.NUMBER: 'NumberLessThanOrEqualToMatcher' - } - - @python_2_unicode_compatible - def __str__(self): - return '<= {compare_to}'.format(compare_to=self._compare_to) - - -class DateEqualToMatcher(EqualToMatcher, AsDateTimestampTransformMixin): - def __init__(self, compare_to): - super(DateEqualToMatcher, self).__init__(compare_to, DataType.DATETIME) - - -class NumberEqualToMatcher(EqualToMatcher, AsNumberTransformMixin): - def __init__(self, compare_to): - super(NumberEqualToMatcher, self).__init__(compare_to, DataType.NUMBER) - - -class DateTimeGreaterThanOrEqualToMatcher(GreaterThanOrEqualToMatcher, - AsDateHourMinuteTimestampTransformMixin): - def __init__(self, compare_to): - super(DateTimeGreaterThanOrEqualToMatcher, self).__init__( - compare_to, DataType.DATETIME - ) - - -class NumberGreaterThanOrEqualToMatcher(GreaterThanOrEqualToMatcher, - AsNumberTransformMixin): - def __init__(self, compare_to): - super(NumberGreaterThanOrEqualToMatcher, self).__init__( - compare_to, - DataType.NUMBER - ) - - -class DateTimeLessThanOrEqualToMatcher(LessThanOrEqualToMatcher, - AsDateHourMinuteTimestampTransformMixin): - def __init__(self, compare_to): - super(DateTimeLessThanOrEqualToMatcher, self).__init__( - compare_to, - DataType.DATETIME - ) - - -class NumberLessThanOrEqualToMatcher(LessThanOrEqualToMatcher, - AsNumberTransformMixin): - def __init__(self, compare_to): - super(NumberLessThanOrEqualToMatcher, self).__init__( - compare_to, - DataType.NUMBER - ) - - -class UserDefinedSegmentMatcher(object): - def __init__(self, segment): - """ - A matcher that looks if a key is contained in a segment - :param segment: The segment to match - :type segment: Segment - """ - self._segment = segment - - @property - def segment(self): - 
return self._segment - - def match(self, key, attributes=None, client=None): - """ - Checks if key is contained within the segment by calling contains() - :param key: The key to match - :type key: str - :return: The result of calling contains() on the segment - :rtype: bool - """ - key = get_matching_key(key) - return self._segment.contains(key) - - @python_2_unicode_compatible - def __str__(self): - return 'in segment {segment_name}'.format( - segment_name=self._segment.name - ) - - -class WhitelistMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if a key is in a whitelist - :param whitelist: A list of strings of whitelisted keys - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - Checks if a key is in the whitelist - :param key: The key to match - :type key: str - :return: True if the key is in the whitelist, False otherwise - :rtype: bool - """ - key = get_matching_key(key) - return key in self._whitelist - - @python_2_unicode_compatible - def __str__(self): - return 'in whitelist [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class StartsWithMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if a any of the strings in whitelist is a prefix - of key - :param whitelist: A list of strings that will be treated as prefixes - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - Checks if any of the strings in whitelist is a prefix of key - :param key: The key to match - :type key: str - :return: True under the conditiones described above - :rtype: bool - """ - key = get_matching_key(key) - return (isinstance(key, string_types) and - any(key.startswith(s) for s in self._whitelist)) - - @python_2_unicode_compatible - def __str__(self): - return 'has one of the following prefixes [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class EndsWithMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if a any of the strings in whitelist is a suffix - of key - :param whitelist: A list of strings that will be treated as suffixes - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - Checks if any of the strings in whitelist is a suffix of key - :param key: The key to match - :type key: str - :return: True under the conditiones described above - :rtype: bool - """ - key = get_matching_key(key) - return (isinstance(key, string_types) and - any(key.endswith(s) for s in self._whitelist)) - - @python_2_unicode_compatible - def __str__(self): - return 'has one of the following suffixes [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class ContainsStringMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if a any of the strings in whitelist is a is - contained in key - :param whitelist: A list of strings that will be treated as suffixes - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - Checks if any of the strings in whitelist is a suffix of key - :param key: The key to match - :type key: str - :return: True under the conditiones described above - :rtype: bool - """ - key = get_matching_key(key) - return 
(isinstance(key, string_types) and - any(s in key for s in self._whitelist)) - - @python_2_unicode_compatible - def __str__(self): - return 'contains one of the following string: [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class ContainsAllOfSetMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if the key, treated as a set, contains all - the elements in whitelist - :param whitelist: A list of strings that will be treated as a set - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - Checks if all the strings in whitelist are in the key when treated as - a set - :param key: The key to match - :type key: str - :return: True under the conditiones described above - :rtype: bool - """ - key = get_matching_key(key) - try: - setkey = set(key) - return set(self._whitelist).issubset(setkey) - except TypeError: - return False - - @python_2_unicode_compatible - def __str__(self): - return 'contains all of the following set: [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class ContainsAnyOfSetMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if the key, treated as a set, contains any - the elements in whitelist - :param whitelist: A list of strings that will be treated as a set - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - Checks if any of the strings in whitelist are in the key when treated as - a set - :param key: The key to match - :type key: str - :return: True under the conditiones described above - :rtype: bool - """ - key = get_matching_key(key) - try: - setkey = set(key) - return set(self._whitelist).intersection(setkey) - except TypeError: - return False - - @python_2_unicode_compatible - def __str__(self): - return 'contains on of the following se: [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class EqualToSetMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if the key, treated as a set, is equal to the set - formed by the elements in whitelist - :param whitelist: A list of strings that will be treated as a set - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - checks if the key, treated as a set, is equal to the set formed by the - elements in whitelist - :param key: The key to match - :type key: str - :return: True under the conditiones described above - :rtype: bool - """ - key = get_matching_key(key) - try: - setkey = set(key) - return set(self._whitelist) == setkey - except TypeError: - return False - - @python_2_unicode_compatible - def __str__(self): - return 'equals the following set: [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class PartOfSetMatcher(object): - def __init__(self, whitelist): - """ - A matcher that checks if the key, treated as a set, is part of the - whitelist set - :param whitelist: A list of strings that will be treated as a set - :type whitelist: list - """ - self._whitelist = frozenset(whitelist) - - def match(self, key, attributes=None, client=None): - """ - Checks if the whitelist set contains the 'key' set - :param key: The key to match - :type key: str - :return: True under the conditiones described 
above - :rtype: bool - """ - key = get_matching_key(key) - try: - setkey = set(key) - return len(setkey) > 0 and setkey.issubset(set(self._whitelist)) - except TypeError: - return False - - @python_2_unicode_compatible - def __str__(self): - return 'is a subset of the following set: [{whitelist}]'.format( - whitelist=','.join('"{}"'.format(item) for item in self._whitelist) - ) - - -class DependencyMatcher(object): - """ - """ - def __init__(self, dependency_matcher_data): - """ - """ - self._data = dependency_matcher_data - - def match(self, key, attributes=None, client=None): - """ - """ - matching, bucketing = (key.matching_key, key.bucketing_key) \ - if isinstance(key, Key) else (key, None) - treatment = client.evaluate_treatment( - self._data.get('split'), - matching, - bucketing, - attributes - ) - - return treatment['treatment'] in self._data.get('treatments', []) - - -class BooleanMatcher(object): - """ - """ - def __init__(self, boolean_matcher_data): - """ - """ - self._data = boolean_matcher_data - - def match(self, key, attributes=None, client=None): - """ - """ - key = get_matching_key(key) - if isinstance(key, bool): - decoded = key - elif isinstance(key, string_types): - try: - decoded = json.loads(key.lower()) - if not isinstance(decoded, bool): - return False - except ValueError: - return False - else: - return False - - return decoded == self._data - - -class RegexMatcher(object): - """ - """ - def __init__(self, regex_matcher_data): - """ - """ - self._data = regex_matcher_data - - def match(self, key, attributes=None, client=None): - """ - """ - key = get_matching_key(key) - try: - regex = re.compile(self._data) - except re.error: - return False - - try: - matches = re.search(regex, key) - return matches is not None - except TypeError: - return False diff --git a/splitio/metrics.py b/splitio/metrics.py deleted file mode 100644 index 31402746..00000000 --- a/splitio/metrics.py +++ /dev/null @@ -1,522 +0,0 @@ -"""This module contains everything related to metrics""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -import arrow - -from bisect import bisect_left -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor -from threading import RLock - -from six import iteritems - - -SDK_GET_TREATMENT = 'sdk.getTreatment' -SDK_GET_TREATMENTS = 'sdk.getTreatments' - -BUCKETS = ( - 1000, 1500, 2250, 3375, 5063, - 7594, 11391, 17086, 25629, 38443, - 57665, 86498, 129746, 194620, 291929, - 437894, 656841, 985261, 1477892, 2216838, - 3325257, 4987885, 7481828 -) -MAX_LATENCY = 7481828 - - -def get_latency_bucket_index(micros): - """Finds the bucket index for a measure latency - :param micros: Measured latency in microseconds - :type micros: int - :return: Bucket index for the given latency - :rtype: int - """ - if micros > MAX_LATENCY: - return len(BUCKETS) - 1 - - return bisect_left(BUCKETS, micros) - - -class LatencyTracker(object): - def __init__(self, latencies=None): - """An object to count latencies that fall within certain buckets. - :param latencies: Existing latency counts - :type latencies: list""" - self._latencies = latencies if latencies is not None else [0] * len(BUCKETS) - - def add_latency_millis(self, millis): - """Increments the count bucket for milliseconds latency. 
- :param millis: The measured latency in milliseconds - :type millis: int - """ - self._latencies[get_latency_bucket_index(millis * 1000)] += 1 - - def add_latency_micros(self, micros): - """Increments the count bucket for microsecond latency. - :param micros: The measured latency in microseconds - :type micros: int - """ - self._latencies[get_latency_bucket_index(micros)] += 1 - - def get_latencies(self): - """ - :return: The current measured latencies - :rtype: list - """ - return list(self._latencies) - - def get_latency(self, index): - """ - :param index: The bucket index - :param index: int - :return: The current measured latency for a given bucket index - :rtype: int - """ - return self._latencies[index] - - def clear(self): - """Clears the latency counts""" - self._latencies = [0] * len(BUCKETS) - - def get_bucket_for_latency_millis(self, latency): - """:param latency: The measured latency in milliseconds - :type latency: int - :return: The bucket count for the measured latency - :rtype: int - """ - return self._latencies[get_latency_bucket_index(latency * 1000)] - - def get_bucket_for_latency_micros(self, latency): - """:param latency: The measured latency in microseconds - :type latency: int - :return: The bucket count for the measured latency - :rtype: int - """ - return self._latencies[get_latency_bucket_index(latency)] - - -class Metrics(object): # pragma: no cover - def __init__(self): - self._logger = logging.getLogger(self.__class__.__name__) - - def count(self, counter, delta): - """ - Adjusts the specified counter by a given delta. This method is is non-blocking and is - guaranteed not to throw an exception - :param counter: The name of the counter to adjust - :type counter: str - :param delta: The amount ot adjust the counter by - :type delta: int""" - pass # Do nothing - - def time(self, operation, time_in_ms): - """ - Records an execution time in milliseconds for the specified named operation. This method - is non-blocking and is guaranteed not to throw an exception. - :param operation: The name of the timed operation - :type operation: str - :param time_in_ms: The time in milliseconds - :type: int - """ - pass # Do nothing - - def gauge(self, gauge, value): - """ - Records the latest fixed value for the specified named gauge. This method is - non-blocking and is guaranteed not to throw an exception. - :param gauge: The name of the gauge - :type gauge: str - :param value: The new reading of the gauge - :type: float - """ - pass - - def destroy(self): - """ - Dummy method for dummy implementation. - """ - pass - - -class InMemoryMetrics(Metrics): - def __init__(self, count_metrics=None, time_metrics=None, gauge_metrics=None, max_call_count=-1, - max_time_between_calls=-1): - """ - A metrics implementation that stores them in memory and keeps track of calls. When too many - calls have been made consecutively or when too much time has passed between calls, the - appropriate update callback is called. Sub-classes implement these callbacks to react - accordingly. 
- :param count_metrics: Optional existing count metrics - :type count_metrics: defaultdict - :param time_metrics: Optional existing time metrics - :type time_metrics: defaultdict - :param gauge_metrics: Optional existing gauge metrics - :type gauge_metrics: defaultdict - :param max_call_count: How many calls before triggering an update - :type max_call_count: int - :param max_time_between_calls: How much time to wait between calls to trigger an update - :type max_time_between_calls: int - """ - super(InMemoryMetrics, self).__init__() - self._count_metrics = count_metrics if count_metrics is not None else defaultdict(int) - self._time_metrics = time_metrics if time_metrics is not None \ - else defaultdict(LatencyTracker) - self._gauge_metrics = gauge_metrics if gauge_metrics is not None else defaultdict(float) - self._max_call_count = max_call_count - self._max_time_between_calls = max_time_between_calls - - utcnow_timestamp = arrow.utcnow().timestamp - - self._count_call_count = 0 - self._count_last_call_time = utcnow_timestamp - self._time_call_count = 0 - self._time_last_call_time = utcnow_timestamp - self._gauge_call_count = 0 - self._gauge_last_call_time = utcnow_timestamp - self._count_rlock = RLock() - self._time_rlock = RLock() - self._gauge_rlock = RLock() - self._ignore_metrics = False - - @property - def ignore_metrics(self): - return self._ignore_metrics - - @ignore_metrics.setter - def ignore_metrics(self, ignore_metrics): - self._ignore_metrics = ignore_metrics - - def _fetch_count_metrics_and_clear(self): - """Returns the existing count metrics and clears the information. - :return: Existing count metrics - :rtype: dict - """ - with self._count_rlock: - count_metrics = self._count_metrics - self._count_metrics = defaultdict(int) - - return count_metrics - - def _fetch_time_metrics_and_clear(self): - """Returns the existing time metrics and clears the information. - :return: Existing time metrics - :rtype: dict - """ - with self._time_rlock: - time_metrics = self._time_metrics - self._time_metrics = defaultdict(LatencyTracker) - - return time_metrics - - def _fetch_gauge_metrics_and_clear(self): - """Returns the existing gauge metrics and clears the information. - :return: Existing gauge metrics - :rtype: dict - """ - with self._gauge_rlock: - gauge_metrics = self._gauge_metrics - self._gauge_metrics = defaultdict(int) - - return gauge_metrics - - def count(self, counter, delta): - """Adjusts the specified counter by a given delta. This method is is non-blocking and is - guaranteed not to throw an exception - :param counter: The name of the counter to adjust - :type counter: str - :param delta: The amount ot adjust the counter by - :type delta: int""" - if self.ignore_metrics: - return - - with self._count_rlock: - self._count_metrics[counter] += delta - self._count_call_count += 1 - - old_call_time = self._count_last_call_time - self._count_last_call_time = arrow.utcnow().timestamp - if (self._count_call_count == self._max_call_count > 0) or \ - self._count_last_call_time - old_call_time > self._max_time_between_calls > 0: - self._count_call_count = 0 - self.update_count() - - def update_count(self): - """Signals that an update on count metrics should be sent to the Split.io back-end""" - pass # Do nothing - - def time(self, operation, time_in_ms): - """Records an execution time in milliseconds for the specified named operation. This method - is non-blocking and is guaranteed not to throw an exception. 
- :param operation: The name of the timed operation - :type operation: str - :param time_in_ms: The time in milliseconds - :type: int - """ - if self.ignore_metrics: - return - - with self._time_rlock: - self._time_metrics[operation].add_latency_millis(time_in_ms) - self._time_call_count += 1 - - old_call_time = self._time_last_call_time - self._time_last_call_time = arrow.utcnow().timestamp - if (self._time_call_count == self._max_call_count > 0) or \ - self._time_last_call_time - old_call_time > self._max_time_between_calls > 0: - self._time_call_count = 0 - self.update_time() - - def update_time(self): - """Signals that an update on time metrics should be sent to the Split.io back-end""" - pass # Do nothing - - def gauge(self, gauge, value): - """Records the latest fixed value for the specified named gauge. This method is - non-blocking and is guaranteed not to throw an exception. - :param gauge: The name of the gauge - :type gauge: str - :param value: The new reading of the gauge - :type: float - """ - if self.ignore_metrics: - return - - with self._gauge_rlock: - self._gauge_metrics[gauge] = value - self._gauge_call_count += 1 - - old_call_time = self._gauge_last_call_time - self._gauge_last_call_time = arrow.utcnow().timestamp - if (self._gauge_call_count == self._max_call_count > 0) or \ - self._gauge_last_call_time - old_call_time > self._max_time_between_calls > 0: - self._gauge_call_count = 0 - self.update_gauge() - - def update_gauge(self): - """Signals that an update on time metrics should be sent to the Split.io back-end""" - pass # Do nothing - - -def build_metrics_counter_data(count_metrics): - """Convert count metrics information to the format used by the API. - :param count_metrics: A dictionary with the count metrics data - :type count_metrics: dict - :return: List of count metrics in the API format - :rtype: list - """ - return [{'name': name, 'delta': delta} for name, delta in iteritems(count_metrics)] - - -def build_metrics_times_data(time_metrics): - """Convert times metrics information to the format used by the API. - :param time_metrics: A dictionary with the times metrics data - :type time_metrics: dict - :return: List of times metrics in the API format - :rtype: list - """ - return [{'name': name, 'latencies': latencies.get_latencies()} - for name, latencies in iteritems(time_metrics)] - - -def build_metrics_gauge_data(gauge_metrics): - """Convert gauge metrics information to the format used by the API. - :param gauge_metrics: A dictionary with the gauge metrics data - :type gauge_metrics: dict - :return: List of gauge metrics in the API format - :rtype: list - """ - return [{'name': name, 'value': value} for name, value in iteritems(gauge_metrics)] - - -class ApiMetrics(InMemoryMetrics): - def __init__(self, api, max_workers=5, count_metrics=None, time_metrics=None, - gauge_metrics=None, max_call_count=-1, max_time_between_calls=-1): - """ - A metrics implementation that stores them in memory and sends them back to the Split.io - back-end when an update is triggered. 
- :param api: The SDK API client - :type api: ApiSdk - :param max_workers: How many workers to use in the update thread pool executor - :type max_workers: int - :param count_metrics: Optional existing count metrics - :type count_metrics: defaultdict - :param time_metrics: Optional existing time metrics - :type time_metrics: defaultdict - :param gauge_metrics: Optional existing gauge metrics - :type gauge_metrics: defaultdict - :param max_call_count: How many calls before triggering an update - :type max_call_count: int - :param max_time_between_calls: How much time to wait between calls to trigger an update - :type max_time_between_calls: int - """ - super(ApiMetrics, self).__init__(count_metrics=count_metrics, time_metrics=time_metrics, - gauge_metrics=gauge_metrics, max_call_count=max_call_count, - max_time_between_calls=max_time_between_calls) - self._api = api - self._thread_pool_executor = ThreadPoolExecutor(max_workers=max_workers) - - def _update_count_fn(self): - count_metrics = self._fetch_count_metrics_and_clear() - - try: - self._api.metrics_counters(build_metrics_counter_data(count_metrics)) - except: - self._logger.error('Error sending count metrics to the back-end. Ignoring metrics.') - self._ignore_metrics = True - - def update_count(self): - """Signals that an update on count metrics should be sent to the Split.io back-end""" - try: - self._thread_pool_executor.submit(self._update_count_fn) - except: - self._logger.error('Error submitting count metrics update task.') - - def _update_time_fn(self): - time_metrics = self._fetch_time_metrics_and_clear() - - try: - self._api.metrics_times(build_metrics_times_data(time_metrics)) - except: - self._logger.error('Error sending time metrics to the back-end. Ignoring metrics.') - self._ignore_metrics = True - - def update_time(self): - """Signals that an update on time metrics should be sent to the Split.io back-end""" - try: - self._thread_pool_executor.submit(self._update_time_fn) - except: - self._logger.error('Error submitting time metrics update task.') - - def _update_gauge_fn(self): - gauge_metrics = self._fetch_gauge_metrics_and_clear() - - try: - self._api.metrics_gauge(build_metrics_gauge_data(gauge_metrics)) - except: - self._logger.error('Error sending gauge metrics to the back-end. ' - 'Ignoring metrics.') - self._ignore_metrics = True - - def update_gauge(self): - """Signals that an update on time metrics should be sent to the Split.io back-end""" - try: - self._thread_pool_executor.submit(self._update_gauge_fn) - except: - self._logger.error('Error submitting gauge metrics update task.') - - -class LoggerMetrics(InMemoryMetrics): - def __init__(self, max_call_count=-1, max_time_between_calls=-1): - """ - A metrics implementation that stores them in memory and logs update attempts. - :param max_call_count: How many calls before triggering an update - :type max_call_count: int - :param max_time_between_calls: How much time to wait between calls to trigger an update - :type max_time_between_calls: int - """ - super(LoggerMetrics, self).__init__(max_call_count=max_call_count, - max_time_between_calls=max_time_between_calls) - - def update_count(self): - """Logs a count update request""" - count_metrics = self._fetch_count_metrics_and_clear() - self._logger.info('update_count. count_metrics = %s', - build_metrics_counter_data(count_metrics)) - - def update_time(self): - """Logs a time update request""" - time_metrics = self._fetch_time_metrics_and_clear() - self._logger.info('update_time. 
time_metrics = %s', build_metrics_times_data(time_metrics)) - - def update_gauge(self): - """Logs a gauge update request""" - gauge_metrics = self._fetch_gauge_metrics_and_clear() - self._logger.info('update_gauge. gauge_metrics = %s', - build_metrics_gauge_data(gauge_metrics)) - - -class AsyncMetrics(Metrics): - def __init__(self, delegate, max_workers=5): - """A non-blocking Metrics implementation that offloads calls to a delegate Metrics object - through a ThreadPoolExecutor - :param delegate: The delegate Metrics object - :type delegate: Metrics - :param max_workers: The max number of workers to use in the thread pool - :type max_workers: int - """ - super(AsyncMetrics, self).__init__() - self._delegate = delegate - self._thread_pool_executor = ThreadPoolExecutor(max_workers=max_workers) - self._destroyed = False - - def destroy(self): - self._destroyed = True - - def count(self, counter, delta): - """Adjusts the specified counter by a given delta. This method is is non-blocking and is - guaranteed not to throw an exception - :param counter: The name of the counter to adjust - :type counter: str - :param delta: The amount ot adjust the counter by - :type delta: int""" - if self._destroyed: - return - - try: - self._thread_pool_executor.submit(self._delegate.count, counter, delta) - except: - self._logger.error('Error submitting count metric') - - def time(self, operation, time_in_ms): - """Records an execution time in milliseconds for the specified named operation. This method - is non-blocking and is guaranteed not to throw an exception. - :param operation: The name of the timed operation - :type operation: str - :param time_in_ms: The time in milliseconds - :type: int - """ - if self._destroyed: - return - - try: - self._thread_pool_executor.submit(self._delegate.time, operation, time_in_ms) - except: - self._logger.error('Error submitting time metric') - - def gauge(self, gauge, value): - """Records the latest fixed value for the specified named gauge. This method is - non-blocking and is guaranteed not to throw an exception. - :param gauge: The name of the gauge - :type gauge: str - :param value: The new reading of the gauge - :type: float - """ - if self._destroyed: - return - - try: - self._thread_pool_executor.submit(self._delegate.gauge, gauge, value) - except: - self._logger.error('Error submitting gauge metric') - - -class CacheBasedMetrics(Metrics): - def __init__(self, metrics_cache): - """A Metrics implementation that uses a MetricsCache to keep track of the metrics - information. - :param metrics_cache: The metrics cache - :type metrics_cache: MetricsCache - """ - self._metrics_cache = metrics_cache - - def gauge(self, gauge, value): - self._metrics_cache.set_gague(gauge, value) - - def time(self, operation, time_in_ms): - self._metrics_cache.increment_latency_bucket_counter(operation, - get_latency_bucket_index(time_in_ms)) - - def count(self, counter, delta): - self._metrics_cache.increment_count(counter, delta) diff --git a/splitio/update_scripts/__init__.py b/splitio/models/__init__.py similarity index 100% rename from splitio/update_scripts/__init__.py rename to splitio/models/__init__.py diff --git a/splitio/models/datatypes.py b/splitio/models/datatypes.py new file mode 100644 index 00000000..7cbe466a --- /dev/null +++ b/splitio/models/datatypes.py @@ -0,0 +1,61 @@ +"""Datatypes converters for matchers.""" + +def ts_truncate_seconds(timestamp): + """ + Set seconds to zero in a timestamp. + + :param ts: Timestamp in seconds. 
+    :type ts: int
+
+    :return: Timestamp in seconds, but without counting them (ie: DD-MM-YY HH:MM:00)
+    :rtype: int
+    """
+    return timestamp - (timestamp % 60)
+
+def ts_truncate_time(timestamp):
+    """
+    Set time to zero in a timestamp.
+
+    :param ts: Timestamp in seconds.
+    :type ts: int
+
+    :return: Timestamp in seconds, without counting time (ie: DD-MM-YYYY 00:00:00)
+    :rtype: int
+    """
+    return timestamp - (timestamp % 86400)
+
+def java_ts_to_secs(java_ts):
+    """
+    Convert a Java timestamp into a Unix timestamp.
+
+    :param java_ts: Java timestamp in milliseconds.
+    :type java_ts: int
+
+    :return: Timestamp in seconds.
+    :rtype: int
+    """
+    return java_ts / 1000
+
+def java_ts_truncate_seconds(java_ts):
+    """
+    Set seconds to zero in a Java timestamp.
+
+    :param java_ts: Java timestamp in milliseconds.
+    :type java_ts: int
+
+    :return: Timestamp in seconds, but without counting them (ie: DD-MM-YY HH:MM:00)
+    :rtype: int
+    """
+    return ts_truncate_seconds(java_ts_to_secs(java_ts))
+
+def java_ts_truncate_time(java_ts):
+    """
+    Set time to zero in a Java timestamp.
+
+    :param java_ts: Java timestamp in milliseconds.
+    :type java_ts: int
+
+    :return: Timestamp in seconds, without counting time (ie: DD-MM-YYYY 00:00:00)
+    :rtype: int
+    """
+    return ts_truncate_time(java_ts_to_secs(java_ts))
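A quick sketch of how these helpers compose, using an arbitrary example timestamp:

```python
from splitio.models.datatypes import java_ts_truncate_seconds, java_ts_truncate_time

# 2019-04-12 08:44:39 UTC expressed as a Java (millisecond) timestamp.
java_ts = 1555058679000

# Seconds stripped: 2019-04-12 08:44:00 UTC.
assert java_ts_truncate_seconds(java_ts) == 1555058640

# Time-of-day stripped: 2019-04-12 00:00:00 UTC.
assert java_ts_truncate_time(java_ts) == 1555027200
```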
diff --git a/splitio/models/events.py b/splitio/models/events.py
new file mode 100644
index 00000000..4519aebe
--- /dev/null
+++ b/splitio/models/events.py
@@ -0,0 +1,17 @@
+"""
+Event DTO definition.
+
+The DTO is implemented as a namedtuple for performance reasons.
+"""
+
+from __future__ import print_function
+from collections import namedtuple
+
+
+Event = namedtuple('Event', [
+    'key',
+    'traffic_type_name',
+    'event_type_id',
+    'value',
+    'timestamp',
+])
diff --git a/splitio/models/grammar/__init__.py b/splitio/models/grammar/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/splitio/models/grammar/condition.py b/splitio/models/grammar/condition.py
new file mode 100644
index 00000000..961d5f54
--- /dev/null
+++ b/splitio/models/grammar/condition.py
@@ -0,0 +1,135 @@
+"""Split conditions module."""
+
+from enum import Enum
+from future.utils import python_2_unicode_compatible
+import six
+
+from splitio.models.grammar import matchers
+from splitio.models.grammar import partitions
+
+_MATCHER_COMBINERS = {
+    'AND': lambda ms, k, a, c: all(m.evaluate(k, a, c) for m in ms)
+}
+
+
+class ConditionType(Enum):
+    """Split possible condition types."""
+
+    WHITELIST = 'WHITELIST'
+    ROLLOUT = 'ROLLOUT'
+
+
+class Condition(object):
+    """Condition object class."""
+
+    def __init__(  #pylint: disable=too-many-arguments
+            self,
+            matcher_list,
+            combiner, parts, label,
+            condition_type=ConditionType.WHITELIST
+    ):
+        """
+        Class constructor.
+
+        :param matcher_list: List of matchers to evaluate.
+        :type matcher_list: list
+        :param combiner: Callable that combines the individual matcher results.
+        :type combiner: callable
+        :param parts: A list of partitions
+        :type parts: list
+        :param label: Condition label.
+        :type label: str
+        :param condition_type: Whether the condition is a whitelist or a rollout.
+        :type condition_type: ConditionType
+        """
+        self._matchers = matcher_list
+        self._combiner = combiner
+        self._partitions = tuple(parts)
+        self._label = label
+        self._condition_type = condition_type
+
+    @property
+    def matchers(self):
+        """Return the list of matchers associated to the condition."""
+        return self._matchers
+
+    @property
+    def partitions(self):
+        """Return the list of partitions associated with the condition."""
+        return self._partitions
+
+    @property
+    def label(self):
+        """Return the label of this condition."""
+        return self._label
+
+    @property
+    def condition_type(self):
+        """Return the condition type."""
+        return self._condition_type
+
+    def matches(self, key, attributes=None, context=None):
+        """
+        Check whether the condition matches against user submitted input.
+
+        :param key: User key
+        :type key: splitio.client.key.Key
+        :param attributes: User custom attributes.
+        :type attributes: dict
+        :param context: Evaluation context
+        :type context: dict
+        """
+        return self._combiner(self._matchers, key, attributes, context)
+
+    def get_segment_names(self):
+        """
+        Fetch segment names for all IN_SEGMENT matchers.
+
+        :return: List of segment names
+        :rtype: list(str)
+        """
+        return [
+            matcher._segment_name for matcher in self.matchers  #pylint: disable=protected-access
+            if isinstance(matcher, matchers.UserDefinedSegmentMatcher)
+        ]
+
+    @python_2_unicode_compatible
+    def __str__(self):
+        """Return the string representation of the condition."""
+        return '{matcher} then split {parts}'.format(
+            matcher=self._matchers, parts=','.join(
+                '{size}:{treatment}'.format(size=partition.size,
+                                            treatment=partition.treatment)
+                for partition in self._partitions))
+
+    def to_json(self):
+        """Return the JSON representation of this condition."""
+        return {
+            'conditionType': self._condition_type.name,
+            'label': self._label,
+            'matcherGroup': {
+                'combiner': next(
+                    (k, v) for k, v in six.iteritems(_MATCHER_COMBINERS) if v == self._combiner
+                )[0],
+                'matchers': [m.to_json() for m in self.matchers]
+            },
+            'partitions': [p.to_json() for p in self.partitions]
+        }
+
+
+def from_raw(raw_condition):
+    """
+    Parse a condition from a JSON portion of splitChanges.
+
+    :param raw_condition: JSON object extracted from a split's conditions array.
+    :type raw_condition: dict
+
+    :return: A condition object.
+    :rtype: Condition
+    """
+    parsed_partitions = [
+        partitions.from_raw(raw_partition)
+        for raw_partition in raw_condition['partitions']
+    ]
+
+    matcher_objects = [matchers.from_raw(x) for x in raw_condition['matcherGroup']['matchers']]
+    combiner = _MATCHER_COMBINERS[raw_condition['matcherGroup']['combiner']]
+    label = raw_condition.get('label')
+
+    condition_type = ConditionType(raw_condition.get('conditionType', ConditionType.WHITELIST))
+
+    return Condition(matcher_objects, combiner, parsed_partitions, label, condition_type)
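A minimal sketch of parsing a condition, assuming the standard splitChanges shapes for matchers and partitions (all field values below are illustrative):

```python
from splitio.models.grammar import condition

# Shaped like one entry of a split's "conditions" array in a splitChanges response.
raw_condition = {
    'conditionType': 'ROLLOUT',
    'label': 'default rule',
    'matcherGroup': {
        'combiner': 'AND',
        'matchers': [
            {'matcherType': 'ALL_KEYS', 'negate': False, 'keySelector': None},
        ],
    },
    'partitions': [
        {'treatment': 'on', 'size': 50},
        {'treatment': 'off', 'size': 50},
    ],
}

cond = condition.from_raw(raw_condition)
assert cond.condition_type == condition.ConditionType.ROLLOUT
assert cond.matches('some-user-key')  # ALL_KEYS matches any non-None key
```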
diff --git a/splitio/models/grammar/matchers/__init__.py b/splitio/models/grammar/matchers/__init__.py
new file mode 100644
index 00000000..f61eb6be
--- /dev/null
+++ b/splitio/models/grammar/matchers/__init__.py
@@ -0,0 +1,70 @@
+"""Matchers entrypoint module."""
+from __future__ import absolute_import, division, print_function, \
+    unicode_literals
+
+from splitio.models.grammar.matchers.keys import AllKeysMatcher, UserDefinedSegmentMatcher
+from splitio.models.grammar.matchers.numeric import BetweenMatcher, EqualToMatcher, \
+    GreaterThanOrEqualMatcher, LessThanOrEqualMatcher
+from splitio.models.grammar.matchers.sets import ContainsAllOfSetMatcher, \
+    ContainsAnyOfSetMatcher, EqualToSetMatcher, PartOfSetMatcher
+from splitio.models.grammar.matchers.string import ContainsStringMatcher, \
+    EndsWithMatcher, RegexMatcher, StartsWithMatcher, WhitelistMatcher
+from splitio.models.grammar.matchers.misc import BooleanMatcher, DependencyMatcher
+
+
+MATCHER_TYPE_ALL_KEYS = 'ALL_KEYS'
+MATCHER_TYPE_IN_SEGMENT = 'IN_SEGMENT'
+MATCHER_TYPE_WHITELIST = 'WHITELIST'
+MATCHER_TYPE_EQUAL_TO = 'EQUAL_TO'
+MATCHER_TYPE_GREATER_THAN_OR_EQUAL_TO = 'GREATER_THAN_OR_EQUAL_TO'
+MATCHER_TYPE_LESS_THAN_OR_EQUAL_TO = 'LESS_THAN_OR_EQUAL_TO'
+MATCHER_TYPE_BETWEEN = 'BETWEEN'
+MATCHER_TYPE_EQUAL_TO_SET = 'EQUAL_TO_SET'
+MATCHER_TYPE_PART_OF_SET = 'PART_OF_SET'
+MATCHER_TYPE_CONTAINS_ALL_OF_SET = 'CONTAINS_ALL_OF_SET'
+MATCHER_TYPE_CONTAINS_ANY_OF_SET = 'CONTAINS_ANY_OF_SET'
+MATCHER_TYPE_STARTS_WITH = 'STARTS_WITH'
+MATCHER_TYPE_ENDS_WITH = 'ENDS_WITH'
+MATCHER_TYPE_CONTAINS_STRING = 'CONTAINS_STRING'
+MATCHER_TYPE_IN_SPLIT_TREATMENT = 'IN_SPLIT_TREATMENT'
+MATCHER_TYPE_EQUAL_TO_BOOLEAN = 'EQUAL_TO_BOOLEAN'
+MATCHER_TYPE_MATCHES_STRING = 'MATCHES_STRING'
+
+
+_MATCHER_BUILDERS = {
+    MATCHER_TYPE_ALL_KEYS: AllKeysMatcher,
+    MATCHER_TYPE_IN_SEGMENT: UserDefinedSegmentMatcher,
+    MATCHER_TYPE_WHITELIST: WhitelistMatcher,
+    MATCHER_TYPE_EQUAL_TO: EqualToMatcher,
+    MATCHER_TYPE_GREATER_THAN_OR_EQUAL_TO: GreaterThanOrEqualMatcher,
+    MATCHER_TYPE_LESS_THAN_OR_EQUAL_TO: LessThanOrEqualMatcher,
+    MATCHER_TYPE_BETWEEN: BetweenMatcher,
+    MATCHER_TYPE_EQUAL_TO_SET: EqualToSetMatcher,
+    MATCHER_TYPE_PART_OF_SET: PartOfSetMatcher,
+    MATCHER_TYPE_CONTAINS_ALL_OF_SET: ContainsAllOfSetMatcher,
+    MATCHER_TYPE_CONTAINS_ANY_OF_SET: ContainsAnyOfSetMatcher,
+    MATCHER_TYPE_STARTS_WITH: StartsWithMatcher,
+    MATCHER_TYPE_ENDS_WITH: EndsWithMatcher,
+    MATCHER_TYPE_CONTAINS_STRING: ContainsStringMatcher,
+    MATCHER_TYPE_IN_SPLIT_TREATMENT: DependencyMatcher,
+    MATCHER_TYPE_EQUAL_TO_BOOLEAN: BooleanMatcher,
+    MATCHER_TYPE_MATCHES_STRING: RegexMatcher
+}
+
+
+def from_raw(raw_matcher):
+    """
+    Parse a matcher from a JSON portion of splitChanges.
+
+    :param raw_matcher: JSON object extracted from a condition's matcher array.
+    :type raw_matcher: dict
+
+    :return: A concrete Matcher object.
+    :rtype: Matcher
+    """
+    matcher_type = raw_matcher['matcherType']
+    try:
+        builder = _MATCHER_BUILDERS[matcher_type]
+    except KeyError:
+        raise ValueError('Invalid matcher type %s' % matcher_type)
+    return builder(raw_matcher)
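`from_raw` is a straight dictionary dispatch. A short sketch, using the EQUAL_TO matcher defined later in this patch:

```python
from splitio.models.grammar import matchers

raw = {
    'matcherType': 'EQUAL_TO',
    'negate': False,
    'keySelector': {'attribute': 'age'},
    'unaryNumericMatcherData': {'dataType': 'NUMBER', 'value': 30},
}

matcher = matchers.from_raw(raw)  # dispatches to EqualToMatcher
assert matcher.evaluate('user-key', {'age': 30}) is True
assert matcher.evaluate('user-key', {'age': 31}) is False
```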
diff --git a/splitio/models/grammar/matchers/base.py b/splitio/models/grammar/matchers/base.py
new file mode 100644
index 00000000..656c88c3
--- /dev/null
+++ b/splitio/models/grammar/matchers/base.py
@@ -0,0 +1,122 @@
+"""Abstract matcher module."""
+import abc
+from splitio.client.key import Key
+
+
+class Matcher(object):
+    """Matcher abstract class."""
+
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, raw_matcher):
+        """
+        Initialize generic data and call matcher-specific parser.
+
+        :param raw_matcher: raw matcher as read from splitChanges response.
+        :type raw_matcher: dict
+
+        :returns: A concrete matcher object.
+        :rtype: Matcher
+        """
+        self._negate = raw_matcher['negate']
+        self._matcher_type = raw_matcher['matcherType']
+        key_selector = raw_matcher.get('keySelector')
+        if key_selector is not None and 'attribute' in key_selector:
+            self._attribute_name = raw_matcher['keySelector']['attribute']
+        else:
+            self._attribute_name = None
+        self._build(raw_matcher)
+
+    def _get_matcher_input(self, key, attributes=None):
+        """
+        Examine split, attributes & key, and return the appropriate matching input.
+
+        :param key: User-submitted key
+        :type key: str | Key
+        :param attributes: User-submitted attributes
+        :type attributes: dict
+
+        :returns: data to use when matching
+        :rtype: str | set | int | bool
+        """
+        if self._attribute_name is not None:
+            if attributes is not None and attributes.get(self._attribute_name) is not None:
+                return attributes[self._attribute_name]
+            return None
+
+        if isinstance(key, Key):
+            return key.matching_key
+
+        return key
+
+    @abc.abstractmethod
+    def _build(self, raw_matcher):
+        """
+        Build the final matcher according to matcher specific data.
+
+        :param raw_matcher: raw matcher as read from splitChanges response.
+        :type raw_matcher: dict
+        """
+        pass
+
+    @abc.abstractmethod
+    def _match(self, key, attributes=None, context=None):
+        """
+        Evaluate user input against matcher and return whether the match is successful.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+
+        :returns: Whether the match is successful.
+        :rtype: bool
+        """
+        pass
+
+    def evaluate(self, key, attributes=None, context=None):
+        """
+        Perform the actual evaluation taking into account possible matcher negation.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+        """
+        return self._negate ^ self._match(key, attributes, context)
+
+    @abc.abstractmethod
+    def _add_matcher_specific_properties_to_json(self):
+        """
+        Add matcher specific properties to base dict before returning it.
+
+        :return: Dictionary with matcher specific properties.
+        :rtype: dict
+        """
+        pass
+
+    def to_json(self):
+        """
+        Reconstruct the original JSON representation of the matcher.
+
+        :return: JSON representation of a matcher.
+        :rtype: dict
+        """
+        base = {
+            "keySelector": {'attribute': self._attribute_name} if self._attribute_name else None,
+            "matcherType": self._matcher_type,
+            "negate": self._negate,
+            "userDefinedSegmentMatcherData": None,
+            "whitelistMatcherData": None,
+            "unaryNumericMatcherData": None,
+            "betweenMatcherData": None,
+            "dependencyMatcherData": None,
+            "booleanMatcherData": None,
+            "stringMatcherData": None,
+        }
+        base.update(self._add_matcher_specific_properties_to_json())
+        return base
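The template-method pattern above (`_build`/`_match`/`_add_matcher_specific_properties_to_json`, with negation applied via XOR in `evaluate()`) can be illustrated with a deliberately hypothetical subclass, not part of the SDK:

```python
from splitio.models.grammar.matchers.base import Matcher

class EvenKeyMatcher(Matcher):
    """Hypothetical matcher: matches keys of even length."""

    def _build(self, raw_matcher):
        pass  # no matcher-specific data to parse

    def _match(self, key, attributes=None, context=None):
        matching_data = self._get_matcher_input(key, attributes)
        return matching_data is not None and len(matching_data) % 2 == 0

    def _add_matcher_specific_properties_to_json(self):
        return {}

matcher = EvenKeyMatcher({'matcherType': 'EVEN_KEY', 'negate': True, 'keySelector': None})
# Negation is applied via XOR in evaluate(): True ^ True -> False.
assert matcher.evaluate('abcd') is False
assert matcher.evaluate('abc') is True
```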
diff --git a/splitio/models/grammar/matchers/keys.py b/splitio/models/grammar/matchers/keys.py
new file mode 100644
index 00000000..95741b72
--- /dev/null
+++ b/splitio/models/grammar/matchers/keys.py
@@ -0,0 +1,98 @@
+"""Keys matchers module."""
+
+from future.utils import python_2_unicode_compatible
+from splitio.models.grammar.matchers.base import Matcher
+
+
+class AllKeysMatcher(Matcher):
+    """A matcher that always returns True."""
+
+    def _build(self, raw_matcher):
+        """
+        Build an AllKeysMatcher.
+
+        :param raw_matcher: raw matcher as fetched from splitChanges response.
+        :type raw_matcher: dict
+        """
+        pass
+
+    def _match(self, key, attributes=None, context=None):
+        """
+        Evaluate user input against a matcher and return whether the match is successful.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+
+        :returns: Whether the match is successful.
+        :rtype: bool
+        """
+        return key is not None
+
+    @python_2_unicode_compatible
+    def __str__(self):
+        """Return string representation."""
+        return 'in segment all'
+
+    def _add_matcher_specific_properties_to_json(self):
+        """Add matcher specific properties to base dict before returning it."""
+        return {}
+
+
+class UserDefinedSegmentMatcher(Matcher):
+    """Matcher that returns true when the submitted key belongs to a segment."""
+
+    def _build(self, raw_matcher):
+        """
+        Build a UserDefinedSegmentMatcher.
+
+        :param raw_matcher: raw matcher as fetched from splitChanges response.
+        :type raw_matcher: dict
+        """
+        self._segment_name = raw_matcher['userDefinedSegmentMatcherData']['segmentName']
+
+    def _match(self, key, attributes=None, context=None):
+        """
+        Evaluate user input against a matcher and return whether the match is successful.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+
+        :returns: Whether the match is successful.
+        :rtype: bool
+        """
+        segment_storage = context.get('segment_storage')
+        if not segment_storage:
+            raise Exception('Segment storage not present in matcher context.')
+
+        matching_data = self._get_matcher_input(key, attributes)
+        if matching_data is None:
+            return False
+        return segment_storage.segment_contains(self._segment_name, matching_data)
+
+    def _add_matcher_specific_properties_to_json(self):
+        """Return UserDefinedSegment specific properties."""
+        return {
+            'userDefinedSegmentMatcherData': {
+                'segmentName': self._segment_name
+            }
+        }
+
+    @python_2_unicode_compatible
+    def __str__(self):
+        """Return string representation."""
+        return 'in segment {segment_name}'.format(
+            segment_name=self._segment_name
+        )
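A sketch of the segment matcher in isolation, with a hypothetical stand-in for the real segment storage that only implements the `segment_contains()` call used above:

```python
from splitio.models.grammar.matchers.keys import UserDefinedSegmentMatcher

class FakeSegmentStorage(object):
    """Illustrative stand-in for the storage passed via the evaluation context."""
    def __init__(self, segments):
        self._segments = segments

    def segment_contains(self, segment_name, key):
        return key in self._segments.get(segment_name, set())

matcher = UserDefinedSegmentMatcher({
    'matcherType': 'IN_SEGMENT',
    'negate': False,
    'keySelector': None,
    'userDefinedSegmentMatcherData': {'segmentName': 'beta_testers'},
})

context = {'segment_storage': FakeSegmentStorage({'beta_testers': {'user-1'}})}
assert matcher.evaluate('user-1', context=context) is True
assert matcher.evaluate('user-2', context=context) is False
```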
diff --git a/splitio/models/grammar/matchers/misc.py b/splitio/models/grammar/matchers/misc.py
new file mode 100644
index 00000000..00d5f31f
--- /dev/null
+++ b/splitio/models/grammar/matchers/misc.py
@@ -0,0 +1,102 @@
+"""Miscellaneous matchers that don't fall into other categories."""
+from __future__ import absolute_import, division, print_function, \
+    unicode_literals
+
+import json
+from future.utils import python_2_unicode_compatible
+from six import string_types
+
+from splitio.models.grammar.matchers.base import Matcher
+
+
+class DependencyMatcher(Matcher):
+    """Matcher that returns true if evaluating the referenced split for the user's key yields one of the expected treatments."""
+
+    def _build(self, raw_matcher):
+        """
+        Build a DependencyMatcher.
+
+        :param raw_matcher: raw matcher as fetched from splitChanges response.
+        :type raw_matcher: dict
+        """
+        self._split_name = raw_matcher['dependencyMatcherData']['split']
+        self._treatments = raw_matcher['dependencyMatcherData']['treatments']
+
+    def _match(self, key, attributes=None, context=None):
+        """
+        Evaluate user input against a matcher and return whether the match is successful.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+
+        :returns: Whether the match is successful.
+        :rtype: bool
+        """
+        evaluator = context.get('evaluator')
+        assert evaluator is not None
+
+        bucketing_key = context.get('bucketing_key')
+
+        result = evaluator.evaluate_treatment(self._split_name, key, bucketing_key, attributes)
+        return result['treatment'] in self._treatments
+
+    def _add_matcher_specific_properties_to_json(self):
+        """Return Dependency specific properties."""
+        return {
+            'dependencyMatcherData': {
+                'split': self._split_name,
+                'treatments': self._treatments
+            }
+        }
+
+
+class BooleanMatcher(Matcher):
+    """Matcher that returns true if the user-submitted value matches the stored boolean."""
+
+    def _build(self, raw_matcher):
+        """
+        Build a BooleanMatcher.
+
+        :param raw_matcher: raw matcher as fetched from splitChanges response.
+        :type raw_matcher: dict
+        """
+        self._data = raw_matcher['booleanMatcherData']
+
+    def _match(self, key, attributes=None, context=None):
+        """
+        Evaluate user input against a matcher and return whether the match is successful.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+
+        :returns: Whether the match is successful.
+        :rtype: bool
+        """
+        matching_data = self._get_matcher_input(key, attributes)
+        if matching_data is None:
+            return False
+        if isinstance(matching_data, bool):
+            decoded = matching_data
+        elif isinstance(matching_data, string_types):
+            try:
+                decoded = json.loads(matching_data.lower())
+                if not isinstance(decoded, bool):
+                    return False
+            except ValueError:
+                return False
+        else:
+            return False
+
+        return decoded == self._data
+
+    def _add_matcher_specific_properties_to_json(self):
+        """Return Boolean specific properties."""
+        return {'booleanMatcherData': self._data}
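A short sketch of the boolean matcher's parsing rules, using the raw payload shape consumed by `_build` above:

```python
from splitio.models.grammar.matchers.misc import BooleanMatcher

matcher = BooleanMatcher({
    'matcherType': 'EQUAL_TO_BOOLEAN',
    'negate': False,
    'keySelector': {'attribute': 'premium'},
    'booleanMatcherData': True,
})

# Booleans match directly; strings are lower-cased and parsed as JSON.
assert matcher.evaluate('user-1', {'premium': True}) is True
assert matcher.evaluate('user-1', {'premium': 'TRUE'}) is True
assert matcher.evaluate('user-1', {'premium': 'yes'}) is False  # not valid JSON
```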
diff --git a/splitio/models/grammar/matchers/numeric.py b/splitio/models/grammar/matchers/numeric.py
new file mode 100644
index 00000000..7c06ef3b
--- /dev/null
+++ b/splitio/models/grammar/matchers/numeric.py
@@ -0,0 +1,256 @@
+"""Numeric & Date based matchers."""
+import numbers
+
+import logging
+from future.utils import python_2_unicode_compatible
+from six import string_types
+
+from splitio.models.grammar.matchers.base import Matcher
+from splitio.models import datatypes
+
+
+class Sanitizer(object):  # pylint: disable=too-few-public-methods
+    """Numeric input sanitizer."""
+
+    _logger = logging.getLogger('InputSanitizer')
+
+    @classmethod
+    def ensure_int(cls, data):
+        """
+        Do a best-effort attempt to convert the input to an int.
+
+        :param data: user-supplied input.
+        :type data: mixed.
+        """
+        if data is None:  # Failed to fetch attribute; no need to convert.
+            return None
+
+        # bool is an integral type in Python. We want to avoid True being
+        # converted to 1 and False to 0 in numeric matchers, since that can
+        # be misleading.
+        if isinstance(data, numbers.Integral) and not isinstance(data, bool):
+            return data
+
+        if not isinstance(data, string_types):
+            cls._logger.error('Cannot convert %s to int. Failing.', type(data))
+            return None
+
+        cls._logger.warning(
+            'Supplied attribute is of type %s and should have been an int.',
+            type(data)
+        )
+
+        try:
+            return int(data)
+        except ValueError:
+            cls._logger.error('Cannot convert %s to int. Failing.', type(data))
+            return None
+
+
+class ZeroSecondDataMatcher(object):  # pylint: disable=too-few-public-methods
+    """Mixin to use in matchers that, when dealing with datetimes, truncate seconds."""
+
+    data_parsers = {
+        'NUMBER': lambda x: x,
+        'DATETIME': datatypes.java_ts_truncate_seconds
+    }
+
+    input_parsers = {
+        'NUMBER': lambda x: x,
+        'DATETIME': datatypes.ts_truncate_seconds
+    }
+
+
+class ZeroTimeDataMatcher(object):  # pylint: disable=no-init,too-few-public-methods
+    """Mixin to use in matchers that, when dealing with datetimes, truncate time."""
+
+    input_parsers = {
+        'NUMBER': lambda x: x,
+        'DATETIME': datatypes.ts_truncate_time
+    }
+
+    data_parsers = {
+        'NUMBER': lambda x: x,
+        'DATETIME': datatypes.java_ts_truncate_time
+    }
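The sanitizer's rules in a nutshell:

```python
from splitio.models.grammar.matchers.numeric import Sanitizer

assert Sanitizer.ensure_int(42) == 42       # ints pass through
assert Sanitizer.ensure_int('42') == 42     # strings are coerced (with a warning)
assert Sanitizer.ensure_int(True) is None   # bools are rejected on purpose
assert Sanitizer.ensure_int('4.2') is None  # not parseable as an int
```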
+
+
+class BetweenMatcher(Matcher, ZeroSecondDataMatcher):
+    """Matcher that returns true if user input is within a specified range."""
+
+    def _build(self, raw_matcher):
+        """
+        Build a BetweenMatcher.
+
+        :param raw_matcher: raw matcher as fetched from splitChanges response.
+        :type raw_matcher: dict
+        """
+        self._data_type = raw_matcher['betweenMatcherData']['dataType']
+        self._original_lower = raw_matcher['betweenMatcherData']['start']
+        self._original_upper = raw_matcher['betweenMatcherData']['end']
+        self._lower = self.data_parsers[self._data_type](self._original_lower)
+        self._upper = self.data_parsers[self._data_type](self._original_upper)
+
+    def _match(self, key, attributes=None, context=None):
+        """
+        Evaluate user input against matcher and return whether the match is successful.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+
+        :returns: Whether the match is successful.
+        :rtype: bool
+        """
+        matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes))
+        if matching_data is None:
+            return False
+        return self._lower <= self.input_parsers[self._data_type](matching_data) <= self._upper
+
+    @python_2_unicode_compatible
+    def __str__(self):
+        """Return string representation."""
+        return 'between {start} and {end}'.format(start=self._lower, end=self._upper)
+
+    def _add_matcher_specific_properties_to_json(self):
+        """Return BetweenMatcher specific properties."""
+        return {
+            'betweenMatcherData': {
+                'dataType': self._data_type,
+                'start': self._original_lower,
+                'end': self._original_upper
+            }
+        }
+
+
+class EqualToMatcher(Matcher, ZeroTimeDataMatcher):
+    """Return true if the provided input is equal to the value stored in the matcher."""
+
+    def _build(self, raw_matcher):
+        """
+        Build an EqualToMatcher.
+
+        :param raw_matcher: raw matcher as fetched from splitChanges response.
+        :type raw_matcher: dict
+        """
+        self._data_type = raw_matcher['unaryNumericMatcherData']['dataType']
+        self._original_value = raw_matcher['unaryNumericMatcherData']['value']
+        self._value = self.data_parsers[self._data_type](self._original_value)
+
+    def _match(self, key, attributes=None, context=None):
+        """
+        Evaluate user input against a matcher and return whether the match is successful.
+
+        :param key: User key.
+        :type key: str.
+        :param attributes: Custom user attributes.
+        :type attributes: dict.
+        :param context: Evaluation context
+        :type context: dict
+
+        :returns: Whether the match is successful.
+        :rtype: bool
+        """
+        matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes))
+        if matching_data is None:
+            return False
+        return self.input_parsers[self._data_type](matching_data) == self._value
+
+    def _add_matcher_specific_properties_to_json(self):
+        """Return EqualTo specific properties."""
+        return {
+            'unaryNumericMatcherData': {
+                'dataType': self._data_type,
+                'value': self._original_value,
+            }
+        }
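A sketch of a DATETIME BetweenMatcher, assuming (per the datatypes module above) that stored bounds are Java millisecond timestamps while user attributes carry Unix timestamps in seconds:

```python
from splitio.models.grammar.matchers.numeric import BetweenMatcher

# Window: 2019-04-12 08:00:00 .. 09:00:00 UTC, serialized the way
# splitChanges stores DATETIME data (milliseconds).
matcher = BetweenMatcher({
    'matcherType': 'BETWEEN',
    'negate': False,
    'keySelector': {'attribute': 'registered_at'},
    'betweenMatcherData': {
        'dataType': 'DATETIME',
        'start': 1555056000000,  # 2019-04-12 08:00:00 UTC
        'end': 1555059600000,    # 2019-04-12 09:00:00 UTC
    },
})

# Seconds are truncated on both sides before comparing.
assert matcher.evaluate('user-1', {'registered_at': 1555058679}) is True   # 08:44:39
assert matcher.evaluate('user-1', {'registered_at': 1555061000}) is False  # ~09:23
```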
+class GreaterThanOrEqualMatcher(Matcher, ZeroSecondDataMatcher):
+ """Return true if the provided input is >= the value stored in the matcher."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a GreaterThanOrEqualMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._data_type = raw_matcher['unaryNumericMatcherData']['dataType']
+ self._original_value = raw_matcher['unaryNumericMatcherData']['value']
+ self._value = self.data_parsers[self._data_type](self._original_value)
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes))
+ if matching_data is None:
+ return False
+ return self.input_parsers[self._data_type](matching_data) >= self._value
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return GreaterThanOrEqual specific properties."""
+ return {
+ 'unaryNumericMatcherData': {
+ 'dataType': self._data_type,
+ 'value': self._original_value,
+ }
+ }
+
+
+class LessThanOrEqualMatcher(Matcher, ZeroSecondDataMatcher):
+ """Return true if the provided input is <= the value stored in the matcher."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a LessThanOrEqualMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._data_type = raw_matcher['unaryNumericMatcherData']['dataType']
+ self._original_value = raw_matcher['unaryNumericMatcherData']['value']
+ self._value = self.data_parsers[self._data_type](self._original_value)
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = Sanitizer.ensure_int(self._get_matcher_input(key, attributes))
+ if matching_data is None:
+ return False
+ return self.input_parsers[self._data_type](matching_data) <= self._value
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return LessThanOrEqual specific properties."""
+ return {
+ 'unaryNumericMatcherData': {
+ 'dataType': self._data_type,
+ 'value': self._original_value,
+ }
+ }
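All four numeric matchers share the same pipeline: sanitize the user value to an int, normalize it through input_parsers[dataType], and compare against the already-normalized matcher data. For dataType NUMBER both parsers are the identity, so evaluation reduces to a plain comparison; a minimal sketch with hypothetical bounds:

    lower, upper = 10, 20          # betweenMatcherData 'start' / 'end'
    identity = lambda x: x         # input_parsers['NUMBER']

    def between(value):
        # Mirrors BetweenMatcher._match for dataType == 'NUMBER'.
        if value is None:          # Sanitizer.ensure_int failed
            return False
        return lower <= identity(value) <= upper

    assert between(15) and not between(25) and not between(None)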
diff --git a/splitio/models/grammar/matchers/sets.py b/splitio/models/grammar/matchers/sets.py
new file mode 100644
index 00000000..4d10c4bd
--- /dev/null
+++ b/splitio/models/grammar/matchers/sets.py
@@ -0,0 +1,208 @@
+"""Set based matchers module."""
+from __future__ import absolute_import, division, print_function, \
+ unicode_literals
+from future.utils import python_2_unicode_compatible
+
+from splitio.models.grammar.matchers.base import Matcher
+
+
+class ContainsAllOfSetMatcher(Matcher):
+ """Matcher that returns true if the matcher's data is a subset of the user-supplied set."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a ContainsAllOfSetMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = self._get_matcher_input(key, attributes)
+ if matching_data is None:
+ return False
+ try:
+ setkey = set(matching_data)
+ return self._whitelist.issubset(setkey)
+ except TypeError:
+ return False
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return ContainsAllOfSet specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'contains all of the following set: [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
+
+
+class ContainsAnyOfSetMatcher(Matcher):
+ """Matcher that returns true if the intersection of both sets is not empty."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a ContainsAnyOfSetMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = self._get_matcher_input(key, attributes)
+ if matching_data is None:
+ return False
+ try:
+ return len(self._whitelist.intersection(set(matching_data))) != 0
+ except TypeError:
+ return False
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return ContainsAnyOfSet specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'contains one of the following set: [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
+
+
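Seen side by side, the set matchers are the standard set relations applied to a frozenset whitelist; with a hypothetical whitelist of {'a', 'b'}:

    whitelist = frozenset(['a', 'b'])
    user = set(['a', 'b', 'c'])                          # user-supplied attribute

    assert whitelist.issubset(user)                      # ContainsAllOfSetMatcher -> True
    assert len(whitelist.intersection(user)) != 0        # ContainsAnyOfSetMatcher -> True
    assert whitelist != user                             # EqualToSetMatcher (below) -> False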
+class EqualToSetMatcher(Matcher):
+ """Matcher that returns true if the set provided by the user is equal to the matcher's one."""
+
+ def _build(self, raw_matcher):
+ """
+ Build an EqualToSetMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = self._get_matcher_input(key, attributes)
+ if matching_data is None:
+ return False
+ try:
+ return self._whitelist == set(matching_data)
+ except TypeError:
+ return False
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return EqualToSet specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'equals the following set: [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
+
+
+class PartOfSetMatcher(Matcher):
+ """Matcher that returns true if the user-supplied set is a non-empty subset of the matcher's set."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a PartOfSetMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = self._get_matcher_input(key, attributes)
+ if matching_data is None:
+ return False
+ try:
+ setkey = set(matching_data)
+ return len(setkey) > 0 and setkey.issubset(self._whitelist)
+ except TypeError:
+ return False
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return PartOfSet specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'is a subset of the following set: [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
diff --git a/splitio/models/grammar/matchers/string.py b/splitio/models/grammar/matchers/string.py
new file mode 100644
index 00000000..9688aacc
--- /dev/null
+++ b/splitio/models/grammar/matchers/string.py
@@ -0,0 +1,275 @@
+"""String matchers module."""
+from __future__ import absolute_import, division, print_function, \
+ unicode_literals
+
+import logging
+import json
+import re
+from future.utils import python_2_unicode_compatible
+from six import string_types
+
+from splitio.models.grammar.matchers.base import Matcher
+
+
+class Sanitizer(object): # pylint: disable=too-few-public-methods
+ """String input sanitizer."""
+
+ _logger = logging.getLogger('InputSanitizer')
+
+ @classmethod
+ def ensure_string(cls, data):
+ """
+ Do a best-effort attempt to convert the input to a string.
+
+ :param data: user supplied input.
+ :type data: mixed.
+
+ :return: String or None
+ :rtype: str
+ """
+ if data is None: # Failed to fetch attribute. no need to convert.
+ return None
+
+ if isinstance(data, string_types):
+ return data
+
+ cls._logger.warning(
+ 'Supplied attribute is of type %s and should have been a string. ',
+ type(data)
+ )
+ try:
+ return json.dumps(data)
+ except TypeError:
+ return None
+
+
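Unlike its numeric counterpart, this sanitizer does not reject non-string input outright: anything JSON-serializable is coerced with json.dumps (and a warning is logged). A sketch of the resulting contract, assuming the module is importable:

    from splitio.models.grammar.matchers.string import Sanitizer

    assert Sanitizer.ensure_string(None) is None         # attribute missing
    assert Sanitizer.ensure_string('on') == 'on'         # strings pass through
    assert Sanitizer.ensure_string([1, 2]) == '[1, 2]'   # coerced via json.dumps
    assert Sanitizer.ensure_string(object()) is None     # not serializable -> None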
+class WhitelistMatcher(Matcher):
+ """Matcher that returns true if the user input is within a whitelist."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a WhitelistMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes))
+ if matching_data is None:
+ return False
+ return matching_data in self._whitelist
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return Whitelist specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'in whitelist [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
+
+
+class StartsWithMatcher(Matcher):
+ """Matcher that returns true if the user input starts with one of the stored prefixes."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a StartsWithMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes))
+ if matching_data is None:
+ return False
+ return (isinstance(matching_data, string_types) and
+ any(matching_data.startswith(s) for s in self._whitelist))
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return StartsWith specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'has one of the following prefixes [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
+
+
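After sanitization the prefix check is a single any() over the whitelist; EndsWithMatcher below is the mirror image using endswith. A sketch with hypothetical prefixes:

    prefixes = frozenset(['user_', 'admin_'])

    def starts_with_any(value):
        # Mirrors StartsWithMatcher._match once the input is a string.
        return value is not None and any(value.startswith(p) for p in prefixes)

    assert starts_with_any('admin_jane')
    assert not starts_with_any('guest_bob')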
+class EndsWithMatcher(Matcher):
+ """Matcher that returns true if the user input ends with one of the stored suffixes."""
+
+ def _build(self, raw_matcher):
+ """
+ Build an EndsWithMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes))
+ if matching_data is None:
+ return False
+ return (isinstance(matching_data, string_types) and
+ any(matching_data.endswith(s) for s in self._whitelist))
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return EndsWith specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'has one of the following suffixes [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
+
+
+class ContainsStringMatcher(Matcher):
+ """Matcher that returns true if the user input contains one of the stored strings."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a ContainsStringMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._whitelist = frozenset(raw_matcher['whitelistMatcherData']['whitelist'])
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes))
+ if matching_data is None:
+ return False
+ return (isinstance(matching_data, string_types) and
+ any(s in matching_data for s in self._whitelist))
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return ContainsString specific properties."""
+ return {
+ 'whitelistMatcherData': {
+ 'whitelist': list(self._whitelist)
+ }
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation."""
+ return 'contains one of the following strings: [{whitelist}]'.format(
+ whitelist=','.join('"{}"'.format(item) for item in self._whitelist)
+ )
+
+
+class RegexMatcher(Matcher):
+ """Matcher that returns true if the user input matches the regex stored in the matcher."""
+
+ def _build(self, raw_matcher):
+ """
+ Build a RegexMatcher.
+
+ :param raw_matcher: raw matcher as fetched from splitChanges response.
+ :type raw_matcher: dict
+ """
+ self._data = raw_matcher['stringMatcherData']
+ self._regex = re.compile(self._data)
+
+ def _match(self, key, attributes=None, context=None):
+ """
+ Evaluate user input against a matcher and return whether the match is successful.
+
+ :param key: User key.
+ :type key: str.
+ :param attributes: Custom user attributes.
+ :type attributes: dict.
+ :param context: Evaluation context
+ :type context: dict
+
+ :returns: Whether the match is successful.
+ :rtype: bool
+ """
+ matching_data = Sanitizer.ensure_string(self._get_matcher_input(key, attributes))
+ if matching_data is None:
+ return False
+ try:
+ matches = re.search(self._regex, matching_data)
+ return matches is not None
+ except TypeError:
+ return False
+
+ def _add_matcher_specific_properties_to_json(self):
+ """Return Regex specific properties."""
+ return {'stringMatcherData': self._data}
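RegexMatcher compiles the pattern once at build time, and re.search scans anywhere in the input rather than anchoring at the start the way re.match would. A short sketch with a hypothetical pattern:

    import re

    pattern = re.compile(r'^[0-9a-f]{8}$')    # hypothetical stringMatcherData
    assert re.search(pattern, 'deadbeef') is not None
    assert re.search(pattern, 'not-a-hash') is None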
diff --git a/splitio/models/grammar/partitions.py b/splitio/models/grammar/partitions.py
new file mode 100644
index 00000000..e38d5d98
--- /dev/null
+++ b/splitio/models/grammar/partitions.py
@@ -0,0 +1,58 @@
+"""Split partition module."""
+
+from future.utils import python_2_unicode_compatible
+
+
+class Partition(object):
+ """Partition object class."""
+
+ def __init__(self, treatment, size):
+ """
+ Class constructor.
+
+ :param treatment: The treatment for the partition
+ :type treatment: str
+ :param size: A number between 0 and 100
+ :type size: float
+ """
+ if size < 0 or size > 100:
+ raise ValueError('size MUST BE between 0 and 100')
+
+ self._treatment = treatment
+ self._size = size
+
+ @property
+ def treatment(self):
+ """Return the treatment associated with this partition."""
+ return self._treatment
+
+ @property
+ def size(self):
+ """Return the percentage owned by this partition."""
+ return self._size
+
+ def to_json(self):
+ """Return a JSON representation of a partition."""
+ return {
+ 'treatment': self._treatment,
+ 'size': self._size
+ }
+
+ @python_2_unicode_compatible
+ def __str__(self):
+ """Return string representation of a partition."""
+ return '{size}%:{treatment}'.format(size=self._size,
+ treatment=self._treatment)
+
+
+def from_raw(raw_partition):
+ """
+ Build a partition object from a splitChanges partition portion.
+
+ :param raw_partition: JSON snippet of a partition.
+ :type raw_partition: dict
+
+ :return: New partition object.
+ :rtype: Partition
+ """
+ return Partition(raw_partition['treatment'], raw_partition['size'])
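Within a split condition, partitions carry the treatment percentages, so the sizes of a condition's partitions are expected to add up to 100. A usage sketch with illustrative values:

    from splitio.models.grammar.partitions import from_raw

    raw = [{'treatment': 'on', 'size': 50}, {'treatment': 'off', 'size': 50}]
    partitions = [from_raw(part) for part in raw]
    assert sum(part.size for part in partitions) == 100
    assert str(partitions[0]) == '50%:on'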
diff --git a/splitio/models/impressions.py b/splitio/models/impressions.py
new file mode 100644
index 00000000..e7288a78
--- /dev/null
+++ b/splitio/models/impressions.py
@@ -0,0 +1,48 @@
+"""Impressions model module."""
+from __future__ import absolute_import, division, print_function, \
+ unicode_literals
+
+from collections import namedtuple
+
+
+Impression = namedtuple(
+ 'Impression',
+ [
+ 'matching_key',
+ 'feature_name',
+ 'treatment',
+ 'label',
+ 'change_number',
+ 'bucketing_key',
+ 'time'
+ ]
+)
+
+
+class Label(object): #pylint: disable=too-few-public-methods
+ """Impressions labels."""
+
+ # Condition: Split Was Killed
+ # Treatment: Default treatment
+ # Label: killed
+ KILLED = 'killed'
+
+ # Condition: No condition matched
+ # Treatment: Default Treatment
+ # Label: default rule
+ NO_CONDITION_MATCHED = 'default rule'
+
+ # Condition: Split definition was not found
+ # Treatment: control
+ # Label: definition not found
+ SPLIT_NOT_FOUND = 'definition not found'
+
+ # Condition: Traffic allocation failed
+ # Treatment: Default Treatment
+ # Label: not in split
+ NOT_IN_SPLIT = 'not in split'
+
+ # Condition: There was an exception
+ # Treatment: control
+ # Label: exception
+ EXCEPTION = 'exception'
diff --git a/splitio/models/segments.py b/splitio/models/segments.py
new file mode 100644
index 00000000..3d7ab5c9
--- /dev/null
+++ b/splitio/models/segments.py
@@ -0,0 +1,86 @@
+"""Segment module."""
+
+
+class Segment(object):
+ """Segment object class."""
+
+ def __init__(self, name, keys, change_number):
+ """
+ Class constructor.
+
+ :param name: Segment name.
+ :type name: str
+
+ :param keys: List of keys belonging to the segment.
+ :type keys: list
+
+ :param change_number: Segment change number.
+ :type change_number: int
+ """
+ self._name = name
+ self._keys = set(keys)
+ self._change_number = change_number
+
+ @property
+ def name(self):
+ """Return segment name."""
+ return self._name
+
+ def contains(self, key):
+ """
+ Return whether the supplied key belongs to the segment.
+
+ :param key: User key.
+ :type key: str
+
+ :return: True if the user is in the segment. False otherwise.
+ :rtype: bool
+ """
+ return key in self._keys
+
+ def update(self, to_add, to_remove):
+ """
+ Add and remove keys from the segment.
+
+ :param to_add: List of keys to add.
+ :type to_add: list
+ :param to_remove: List of keys to remove.
+ :type to_remove: list
+ """
+ self._keys = self._keys.union(set(to_add)).difference(to_remove)
+
+ @property
+ def keys(self):
+ """
+ Return the segment keys.
+
+ :return: A set of the segment keys
+ :rtype: set
+ """
+ return self._keys
+
+ @property
+ def change_number(self):
+ """Return segment change number."""
+ return self._change_number
+
+ @change_number.setter
+ def change_number(self, new_value):
+ """
+ Set new change number.
+
+ :param new_value: New change number.
+ :type new_value: int
+ """
+ self._change_number = new_value
+
+
+def from_raw(raw_segment):
+ """
+ Parse a new segment from a raw segment_changes response.
+
+ :param raw_segment: Segment parsed from segment changes response.
+ :type raw_segment: dict
+
+ :return: New segment model object
+ :rtype: splitio.models.segments.Segment
+ """
+ keys = set(raw_segment['added']).difference(raw_segment['removed'])
+ return Segment(raw_segment['name'], keys, raw_segment['till'])
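Segment stores its keys in a plain set, so update is a union with the additions followed by a difference with the removals, mirroring the added/removed arrays of a segmentChanges payload. A usage sketch with made-up keys:

    from splitio.models.segments import Segment

    segment = Segment('beta_users', ['k1', 'k2'], 123)
    segment.update(to_add=['k3'], to_remove=['k1'])
    assert segment.contains('k3') and not segment.contains('k1')
    segment.change_number = 124    # bumped after applying a newer change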
diff --git a/splitio/models/splits.py b/splitio/models/splits.py
new file mode 100644
index 00000000..170fe285
--- /dev/null
+++ b/splitio/models/splits.py
@@ -0,0 +1,223 @@
+"""Splits module."""
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+from enum import Enum
+from collections import namedtuple
+from future.utils import python_2_unicode_compatible
+
+from splitio.models.grammar import condition
+
+
+SplitView = namedtuple(
+ 'SplitView',
+ ['name', 'traffic_type', 'killed', 'treatments', 'change_number']
+)
+
+
+class Status(Enum):
+ """Split status."""
+
+ ACTIVE = "ACTIVE"
+ ARCHIVED = "ARCHIVED"
+
+
+class HashAlgorithm(Enum):
+ """Hash algorithm names."""
+
+ LEGACY = 1
+ MURMUR = 2
+
+
+class Split(object): #pylint: disable=too-many-instance-attributes
+ """Split model object."""
+
+ def __init__( #pylint: disable=too-many-arguments
+ self,
+ name,
+ seed,
+ killed,
+ default_treatment,
+ traffic_type_name,
+ status,
+ change_number,
+ conditions=None,
+ algo=None,
+ traffic_allocation=None,
+ traffic_allocation_seed=None
+ ):
+ """
+ Class constructor.
+
+ :param name: Name of the feature
+ :type name: unicode
+ :param seed: Seed
+ :type seed: int
+ :param killed: Whether the split is killed or not
+ :type killed: bool
+ :param default_treatment: Default treatment for the split
+ :type default_treatment: str
+ :param traffic_type_name: Traffic type of the split
+ :type traffic_type_name: str
+ :param status: Status of the split (ACTIVE/ARCHIVED)
+ :type status: str
+ :param change_number: Change number of the split
+ :type change_number: int
+ :param conditions: Set of conditions to test
+ :type conditions: list
+ :param algo: Hash algorithm to use when splitting.
+ :type algo: HashAlgorithm
+ :param traffic_allocation: Percentage of traffic to consider.
+ :type traffic_allocation: int
+ :param traffic_allocation_seed: Seed used to hash traffic allocation.
+ :type traffic_allocation_seed: int + """ + self._name = name + self._seed = seed + self._killed = killed + self._default_treatment = default_treatment + self._traffic_type_name = traffic_type_name + try: + self._status = Status(status) + except ValueError: + self._status = Status.ARCHIVED + + self._change_number = change_number + self._conditions = conditions if conditions is not None else [] + + if traffic_allocation is None: + self._traffic_allocation = 100 + elif traffic_allocation >= 0 and traffic_allocation <= 100: + self._traffic_allocation = traffic_allocation + else: + self._traffic_allocation = 100 + + self._traffic_allocation_seed = traffic_allocation_seed + try: + self._algo = HashAlgorithm(algo) + except ValueError: + self._algo = HashAlgorithm.LEGACY + + @property + def name(self): + """Return name.""" + return self._name + + @property + def seed(self): + """Return seed.""" + return self._seed + + @property + def algo(self): + """Return hash algorithm.""" + return self._algo + + @property + def killed(self): + """Return whether the split has been killed.""" + return self._killed + + @property + def default_treatment(self): + """Return the default treatment.""" + return self._default_treatment + + @property + def traffic_type_name(self): + """Return the traffic type of the split.""" + return self._traffic_type_name + + @property + def status(self): + """Return the status of the split.""" + return self._status + + @property + def change_number(self): + """Return the change number of the split.""" + return self._change_number + + @property + def conditions(self): + """Return the condition list of the split.""" + return self._conditions + + @property + def traffic_allocation(self): + """Return the traffic allocation percentage of the split.""" + return self._traffic_allocation + + @property + def traffic_allocation_seed(self): + """Return the traffic allocation seed of the split.""" + return self._traffic_allocation_seed + + def get_segment_names(self): + """ + Return a list of segment names referenced in all matchers from this split. + + :return: List of segment names. + :rtype: list(string) + """ + return [name for cond in self.conditions for name in cond.get_segment_names()] + + def to_json(self): + """Return a JSON representation of this split.""" + return { + 'changeNumber': self.change_number, + 'trafficTypeName': self.traffic_type_name, + 'name': self.name, + 'trafficAllocation': self.traffic_allocation, + 'trafficAllocationSeed': self.traffic_allocation_seed, + 'seed': self.seed, + 'status': self.status.value, + 'killed': self.killed, + 'defaultTreatment': self.default_treatment, + 'algo': self.algo.value, + 'conditions': [c.to_json() for c in self.conditions] + } + + def to_split_view(self): + """ + Return a SplitView for the manager. + + :return: A portion of the split useful for inspecting by the user. 
+ :rtype: SplitView + """ + return SplitView( + self.name, + self.traffic_type_name, + self.killed, + list(set(part.treatment for cond in self.conditions for part in cond.partitions)), + self.change_number + ) + + @python_2_unicode_compatible + def __str__(self): + """Return string representation.""" + return 'name: {name}, seed: {seed}, killed: {killed}, ' \ + 'default treatment: {default_treatment}, ' \ + 'conditions: {conditions}'.format( + name=self._name, seed=self._seed, killed=self._killed, + default_treatment=self._default_treatment, + conditions=','.join(map(str, self._conditions)) + ) + + +def from_raw(raw_split): + """ + Parse a split from a JSON portion of splitChanges. + + :param raw_split: JSON object extracted from a splitChange's split array (splitChanges response) + :type raw_split: dict + + :return: A parsed Split object capable of performing evaluations. + :rtype: Split + """ + return Split( + raw_split['name'], + raw_split['seed'], + raw_split['killed'], + raw_split['defaultTreatment'], + raw_split['trafficTypeName'], + raw_split['status'], + raw_split['changeNumber'], + [condition.from_raw(c) for c in raw_split['conditions']], + raw_split.get('algo'), + traffic_allocation=raw_split.get('trafficAllocation'), + traffic_allocation_seed=raw_split.get('trafficAllocationSeed') + ) diff --git a/splitio/models/telemetry.py b/splitio/models/telemetry.py new file mode 100644 index 00000000..e4739328 --- /dev/null +++ b/splitio/models/telemetry.py @@ -0,0 +1,27 @@ +"""SDK Telemetry helpers.""" +from bisect import bisect_left + + +BUCKETS = ( + 1000, 1500, 2250, 3375, 5063, + 7594, 11391, 17086, 25629, 38443, + 57665, 86498, 129746, 194620, 291929, + 437894, 656841, 985261, 1477892, 2216838, + 3325257, 4987885, 7481828 +) +MAX_LATENCY = 7481828 + + +def get_latency_bucket_index(micros): + """ + Find the bucket index for a measured latency. + + :param micros: Measured latency in microseconds + :type micros: int + :return: Bucket index for the given latency + :rtype: int + """ + if micros > MAX_LATENCY: + return len(BUCKETS) - 1 + + return bisect_left(BUCKETS, micros) diff --git a/splitio/prefix_decorator.py b/splitio/prefix_decorator.py deleted file mode 100644 index 851edee7..00000000 --- a/splitio/prefix_decorator.py +++ /dev/null @@ -1,137 +0,0 @@ -from six import string_types, binary_type - - -class PrefixDecorator: - ''' - Instance decorator for Redis clients such as StrictRedis. - Adds an extra layer handling addition/removal of user prefix when handling - keys - ''' - - def __init__(self, decorated, prefix=None): - ''' - Stores the user prefix and the redis client instance. - - :param decorated: Instance of redis cache client to decorate. - :param prefix: User prefix to add. - ''' - self._prefix = prefix - self._decorated = decorated - - def _add_prefix(self, k): - ''' - Add a prefix to the contents of k. - 'k' may be: - - a single key (of type string or unicode in python2, or type string - in python 3. In which case we simple add a prefix with a dot. - - a list, in which the prefix is applied to element. - If no user prefix is stored, the key/list of keys will be returned as is - - :param k: single (string) or list of (list) keys. 
- :returns: Key(s) with prefix if applicable - ''' - if self._prefix: - if isinstance(k, string_types): - return '{prefix}.{key}'.format(prefix=self._prefix, key=k) - elif isinstance(k, list) and len(k) > 0: - if isinstance(k[0], binary_type): - return [ - '{prefix}.{key}'.format(prefix=self._prefix, key=key.decode("utf8")) - for key in k - ] - elif isinstance(k[0], string_types): - return [ - '{prefix}.{key}'.format(prefix=self._prefix, key=key.decode("utf8")) - for key in k - ] - - else: - return k - - def _remove_prefix(self, k): - ''' - Removes the user prefix from a key before handling it back - to the requester. - Similar to _add_prefix, this class will handle single strings as well - as lists. If no _prefix is set, the original key/keys will be returned. - - :param k: key(s) whose prefix will be removed. - :returns: prefix-less key(s) - ''' - if self._prefix: - if isinstance(k, string_types): - return k[len(self._prefix)+1:] - elif isinstance(k, list): - return [key[len(self._prefix)+1:] for key in k] - else: - return k - - # Below starts a list of methods that implement the interface of a standard - # redis client. - - def keys(self, pattern): - return self._remove_prefix( - self._decorated.keys(self._add_prefix(pattern)) - ) - - def set(self, name, value, *args, **kwargs): - return self._decorated.set( - self._add_prefix(name), value, *args, **kwargs - ) - - def get(self, name): - return self._decorated.get(self._add_prefix(name)) - - def setex(self, name, time, value): - return self._decorated.setex(self._add_prefix(name), time, value) - - def delete(self, names): - return self._decorated.delete(self._add_prefix(names)) - - def exists(self, name): - return self._decorated.exists(self._add_prefix(name)) - - def mget(self, names): - return self._decorated.mget(self._add_prefix(names)) - - def smembers(self, name): - return self._decorated.smembers(self._add_prefix(name)) - - def sadd(self, name, *values): - return self._decorated.sadd(self._add_prefix(name), *values) - - def srem(self, name, *values): - return self._decorated.srem(self._add_prefix(name), *values) - - def sismember(self, name, value): - return self._decorated.sismember(self._add_prefix(name), value) - - def eval(self, *args): - script = args[0] - num_keys = args[1] - keys = list(args[2:]) - return self._decorated.eval(script, num_keys, *self._add_prefix(keys)) - - def hset(self, name, key, value): - return self._decorated.hset(self._add_prefix(name), key, value) - - def hget(self, name, key): - return self._decorated.hget(self._add_prefix(name), key) - - def incr(self, name, amount=1): - return self._decorated.incr(self._add_prefix(name), amount) - - def getset(self, name, value): - return self._decorated.getset(self._add_prefix(name), value) - - def rpush(self, key, values): - return self._decorated.rpush(self._add_prefix(key), *values) - - def expire(self, key, value): - return self._decorated.expire(self._add_prefix(key), value) - - def rpop(self, key): - return self._decorated.rpop(self._add_prefix(key)) - - def ttl(self, key): - return self._decorated.ttl(self._add_prefix(key)) diff --git a/splitio/redis_support.py b/splitio/redis_support.py deleted file mode 100644 index 0d9dfd7d..00000000 --- a/splitio/redis_support.py +++ /dev/null @@ -1,937 +0,0 @@ -'''This module contains everything related to redis cache implementations''' -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -import re -import logging - -from builtins import zip -from itertools import groupby, islice - 
-try: - from jsonpickle import decode, encode - from redis import StrictRedis - from redis.sentinel import Sentinel -except ImportError: - def missing_redis_dependencies(*args, **kwargs): - raise NotImplementedError('Missing Redis support dependencies.') - decode = encode = StrictRedis = missing_redis_dependencies - -from six import iteritems - -from splitio.config import import_from_string, GLOBAL_KEY_PARAMETERS -from splitio.cache import SegmentCache, SplitCache, ImpressionsCache, \ - MetricsCache -from splitio.matchers import UserDefinedSegmentMatcher -from splitio.metrics import BUCKETS -from splitio.segments import Segment -from splitio.splits import Split, SplitParser -from splitio.impressions import Impression -from splitio.utils import bytes_to_string -from splitio.prefix_decorator import PrefixDecorator -# Template for Split.io related Cache keys -_SPLITIO_CACHE_KEY_TEMPLATE = 'SPLITIO.{suffix}' - -IMPRESSIONS_QUEUE_KEY = 'SPLITIO.impressions' -IMPRESSION_KEY_DEFAULT_TTL = 3600 - - -class SentinelConfigurationException(Exception): - pass - - -class RedisSegmentCache(SegmentCache): - ''' - ''' - _KEY_TEMPLATE = _SPLITIO_CACHE_KEY_TEMPLATE.format( - suffix='segments.{suffix}' - ) - _DISABLED_KEY = _KEY_TEMPLATE.format(suffix='__disabled__') - _SEGMENT_KEY_SET_KEY_TEMPLATE = _SPLITIO_CACHE_KEY_TEMPLATE.format( - suffix='segment.{segment_name}' - ) - _SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE = _SPLITIO_CACHE_KEY_TEMPLATE.format( - suffix='segment.{segment_name}.till' - ) - - def __init__(self, redis): - ''' - A Segment Cache implementation that uses Redis as its back-end - :param redis: The redis client - :rtype redis: StringRedis - ''' - self._redis = redis - - def register_segment(self, segment_name): - ''' - Register a segment for inclusion in the automatic update process. - :param segment_name: Name of the segment. - :type segment_name: str - ''' - # self._redis.sadd( - # RedisSegmentCache._KEY_TEMPLATE.format(suffix='registered'), - # segment_name - # ) - # @TODO The Segment logic for redis should be removed. - pass - - def unregister_segment(self, segment_name): - ''' - Unregister a segment from the automatic update process. - :param segment_name: Name of the segment. - :type segment_name: str - ''' - # self._redis.srem( - # RedisSegmentCache._KEY_TEMPLATE.format(suffix='registered'), - # segment_name - # ) - # @TODO The Segment logic for redis should be removed. - pass - - def get_registered_segments(self): - ''' - :return: All segments included in the automatic update process. - :rtype: set - ''' - return self._redis.smembers( - RedisSegmentCache._KEY_TEMPLATE.format(suffix='registered') - ) - - def _get_segment_key_set_key(self, segment_name): - ''' - Build cache key for a given segment key set. - :param segment_name: The name of the segment - :type segment_name: str - :return: The cache key for the segment key set - :rtype: str - ''' - segment_name = bytes_to_string(segment_name) - return RedisSegmentCache._SEGMENT_KEY_SET_KEY_TEMPLATE.format( - segment_name=segment_name - ) - - def _get_segment_change_number_key(self, segment_name): - ''' - Build cache key for a given segment change_number. 
- :param segment_name: The name of the segment - :type segment_name: str - :return: The cache key for the segment change number - :rtype: str - ''' - segment_name = bytes_to_string(segment_name) - return RedisSegmentCache._SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE.format( - segment_name=segment_name - ) - - def add_keys_to_segment(self, segment_name, segment_keys): - self._redis.sadd( - self._get_segment_key_set_key(segment_name), - *segment_keys - ) - - def remove_keys_from_segment(self, segment_name, segment_keys): - self._redis.srem( - self._get_segment_key_set_key(segment_name), - *segment_keys - ) - - def is_in_segment(self, segment_name, key): - return self._redis.sismember( - self._get_segment_key_set_key(segment_name), - key - ) - - def set_change_number(self, segment_name, change_number): - self._redis.set( - self._get_segment_change_number_key(segment_name), - change_number - ) - - def get_change_number(self, segment_name): - change_number = self._redis.get( - self._get_segment_change_number_key(segment_name) - ) - return int(change_number) if change_number is not None else -1 - - -class RedisSplitCache(SplitCache): - _KEY_TEMPLATE = _SPLITIO_CACHE_KEY_TEMPLATE.format(suffix='split.{suffix}') - _KEY_TILL_TEMPLATE = _SPLITIO_CACHE_KEY_TEMPLATE.format( - suffix='splits.{suffix}' - ) - - def __init__(self, redis): - ''' - A SplitCache implementation that uses Redis as its back-end. - :param redis: The redis client - :type redis: StrictRedis - ''' - self._redis = redis - self._logger = logging.getLogger(self.__class__.__name__) - - def get_splits_keys(self): - return self._redis.keys( - RedisSplitCache._KEY_TEMPLATE.format(suffix='*') - ) - - def _get_split_key(self, split_name): - ''' - Builds a Redis key cache for a given split (feature) name, - :param split_name: Name of the split (feature) - :type split_name: str - :return: The split key - :rtype: str - ''' - return RedisSplitCache._KEY_TEMPLATE.format(suffix=split_name) - - def get_change_number(self): - change_number = self._redis.get( - RedisSplitCache._KEY_TILL_TEMPLATE.format(suffix='till') - ) - return int(change_number) if change_number is not None else -1 - - def set_change_number(self, change_number): - self._redis.set( - RedisSplitCache._KEY_TILL_TEMPLATE.format(suffix='till'), - change_number, - None - ) - - def add_split(self, split_name, split): - self._redis.set(self._get_split_key(split_name), encode(split)) - - def get_split(self, split_name): - - to_decode = self._redis.get(self._get_split_key(split_name)) - if to_decode is None: - return None - - to_decode = bytes_to_string(to_decode) - - split_dump = decode(to_decode) - - if split_dump is not None: - segment_cache = RedisSegmentCache(self._redis) - split_parser = RedisSplitParser(segment_cache) - split = split_parser.parse(split_dump) - return split - - return None - - def get_splits(self): - keys = self.get_splits_keys() - - splits = self._redis.mget(keys) - - to_return = [] - - segment_cache = RedisSegmentCache(self._redis) - split_parser = RedisSplitParser(segment_cache) - - for split in splits: - try: - split = bytes_to_string(split) - split_dump = decode(split) - if split_dump is not None: - to_return.append(split_parser.parse(split_dump)) - except Exception: - self._logger.error( - 'Error decoding/parsing fetched split or invalid split' - ' format: %s' % split - ) - - return to_return - - def remove_split(self, split_name): - self._redis.delete(self._get_split_key(split_name)) - - -class RedisEventsCache(ImpressionsCache): - _KEY_TEMPLATE = ( - 'SPLITIO.events' 
- ) - - def __init__(self, redis): - ''' - An ImpressionsCache implementation that uses Redis as its back-end - :param redis: The redis client - :type redis: StrictRedis - ''' - self._logger = logging.getLogger(self.__class__.__name__) - self._redis = redis - - def log_event(self, event): - """ - Adds an event to the redis storage - """ - key = self._KEY_TEMPLATE - to_store = { - 'e': dict(event._asdict()), - 'm': { - 's': GLOBAL_KEY_PARAMETERS['sdk-language-version'], - 'n': GLOBAL_KEY_PARAMETERS['instance-id'], - 'i': GLOBAL_KEY_PARAMETERS['ip-address'], - } - } - try: - self._redis.rpush(key, [encode(to_store)]) - return True - except Exception: - self._logger.error("Something went wrong when trying to add event to redis") - return False - - -class RedisImpressionsCache(ImpressionsCache): - _KEY_TEMPLATE = ( - 'SPLITIO/{sdk-language-version}/{instance-id}/impressions.{feature}' - ) - - @classmethod - def _get_impressions_key(cls, feature_name): - ''' - ''' - return cls._KEY_TEMPLATE.format( - **dict(GLOBAL_KEY_PARAMETERS, feature=feature_name) - ) - - @classmethod - def _get_impressions_clear_key(cls): - ''' - ''' - return cls._get_impressions_key('impressions_to_clear') - - def __init__(self, redis): - ''' - An ImpressionsCache implementation that uses Redis as its back-end - :param redis: The redis client - :type redis: StrictRedis - ''' - self._redis = redis - self._logger = logging.getLogger(self.__class__.__name__) - - key_params = GLOBAL_KEY_PARAMETERS.copy() - key_params['suffix'] = '{feature_name}' - - def _build_impressions_dict(self, impressions): - ''' - Buils a dictionary of impressions that groups them based on their - feature name. - :param impressions: List of impression tuples - :type impressions: list - :return: Dictionary of impressions grouped by feature name - :rtype: dict - ''' - sorted_impressions = sorted( - impressions, - key=lambda impression: impression.feature_name - ) - grouped_impressions = groupby( - sorted_impressions, - key=lambda impression: impression.feature_name - ) - return dict( - (feature_name, list(group)) - for feature_name, group in grouped_impressions - ) - - def fetch_all(self): - ''' - Fetches all impressions from the cache. It returns a dictionary with the - impressions grouped by feature name. - :return: All cached impressions so far grouped by feature name - :rtype: dict - ''' - impressions_list = list() - impressions_keys = self._redis.keys(self._get_impressions_key('*')) - - for impression_key in impressions_keys: - impression_key = bytes_to_string(impression_key) - if (impression_key.replace(self._get_impressions_key(''), '') == 'impressions'): - continue - - feature_name = impression_key.replace( - self._get_impressions_key(''), - '' - ) - - for impression in self._redis.smembers(impression_key): - impression = bytes_to_string(impression) - impression_decoded = decode(impression) - impression_tuple = Impression( - key=impression_decoded['keyName'], - feature_name=feature_name, - treatment=impression_decoded['treatment'], - time=impression_decoded['time'] - ) - impressions_list.append(impression_tuple) - - if not impressions_list: - return dict() - - return self._build_impressions_dict(impressions_list) - - def clear(self): - ''' - Clears all cached impressions - ''' - self._redis.eval( - "return redis.call('del', unpack(redis.call('keys', ARGV[1])))", - 0, - self._get_impressions_key('*') - ) - - def add_impressions(self, impressions): - ''' - Adds impressions to the queue if it is enabled, otherwise the impressions - are dropped. 
- :param impressions: The impression bulk - :type impressions: list - ''' - bulk_impressions = [] - for impression in impressions: - if isinstance(impression, Impression): - to_store = { - 'm': { # METADATA PORTION - 's': GLOBAL_KEY_PARAMETERS['sdk-language-version'], - 'n': GLOBAL_KEY_PARAMETERS['instance-id'], - 'i': GLOBAL_KEY_PARAMETERS['ip-address'], - }, - 'i': { # IMPRESSION PORTION - 'k': impression.matching_key, - 'b': impression.bucketing_key, - 'f': impression.feature_name, - 't': impression.treatment, - 'r': impression.label, - 'c': impression.change_number, - 'm': impression.time, - } - } - bulk_impressions.append(encode(to_store)) - try: - inserted = self._redis.rpush(IMPRESSIONS_QUEUE_KEY, bulk_impressions) - if inserted == len(bulk_impressions): - self._logger.debug("SET EXPIRE KEY FOR QUEUE") - self._redis.expire(IMPRESSIONS_QUEUE_KEY, IMPRESSION_KEY_DEFAULT_TTL) - return True - except Exception: - self._logger.error("Something went wrong when trying to add impression to redis") - return False - - def fetch_all_and_clear(self): - ''' - Fetches all impressions from the cache and clears it. - It returns a dictionary with the impressions grouped by feature name. - :return: All cached impressions so far grouped by feature name - :rtype: dict - ''' - impressions_list = list() - impressions_keys = self._redis.keys(self._get_impressions_key('*')) - - for impression_key in impressions_keys: - - impression_key = bytes_to_string(impression_key) - - if (impression_key.replace(self._get_impressions_key(''), '') == 'impressions'): - continue - - feature_name = impression_key.replace( - self._get_impressions_key(''), - '' - ) - - to_remove = list() - for impression in self._redis.smembers(impression_key): - to_remove.append(impression) - - impression = bytes_to_string(impression) - - impression_decoded = decode(impression) - - label = '' - if 'label' in impression_decoded: - label = impression_decoded['label'] - - change_number = -1 - if 'changeNumber' in impression_decoded: - change_number = impression_decoded['changeNumber'] - - bucketing_key = '' - if 'bucketingKey' in impression_decoded: - bucketing_key = impression_decoded['bucketingKey'] - - impression_tuple = Impression( - matching_key=impression_decoded['keyName'], - feature_name=feature_name, - treatment=impression_decoded['treatment'], - label=label, - change_number=change_number, - bucketing_key=bucketing_key, - time=impression_decoded['time'] - ) - impressions_list.append(impression_tuple) - - self._redis.srem(impression_key, *set(to_remove)) - - if not impressions_list: - return dict() - - return self._build_impressions_dict(impressions_list) - - -class RedisMetricsCache(MetricsCache): - _KEY_TEMPLATE = _SPLITIO_CACHE_KEY_TEMPLATE.format( - suffix='metrics.{suffix}' - ) - - _KEY_LATENCY = ( - 'SPLITIO/{sdk-language-version}/{instance-id}/' - 'latency.{metric_name}.bucket.{bucket_number}' - ) - - _KEY_COUNT = 'SPLITIO/{sdk-language-version}/{instance-id}/count.{counter}' - _KEY_GAUGE = 'SPLITIO/{sdk-language-version}/{instance-id}/gauge.{gauge}' - - @classmethod - def _get_latency_bucket_key(cls, metric_name, bucket_number): - ''' - Returns the latency bucket - ''' - return cls._KEY_LATENCY.format(**dict( - GLOBAL_KEY_PARAMETERS, - metric_name=metric_name, - bucket_number=bucket_number - )) - - @classmethod - def _get_count_key(cls, counter): - ''' - Returns the count key - ''' - return cls._KEY_COUNT.format(**dict( - GLOBAL_KEY_PARAMETERS, - counter=counter - )) - - @classmethod - def _get_gauge_key(cls, gauge): - ''' - 
Returns the gauge key - ''' - return cls._KEY_GAUGE.format(**dict( - GLOBAL_KEY_PARAMETERS, - gauge=gauge - )) - - @classmethod - def _get_latency_field_re(cls): - return ("^%s$" % (cls._KEY_LATENCY - .replace('.', '\.') - .replace('{metric_name}', '(?P.+)') - .replace('{bucket_number}', '(?P.+)')) - ).format(**GLOBAL_KEY_PARAMETERS) - - @classmethod - def _get_count_field_re(cls): - return ("^%s$" % (cls._KEY_COUNT - .replace('.', '\.') - .replace('{counter}', '(?P.+)')) - ).format(**GLOBAL_KEY_PARAMETERS) - - @classmethod - def _get_gauge_field_re(cls): - return ("^%s$" % (cls._KEY_GAUGE - .replace('.', '\.') - .replace('{gauge}', '(?P.+)')) - ).format(**GLOBAL_KEY_PARAMETERS) - - def __init__(self, redis): - ''' - A MetricsCache implementation that uses Redis as its back-end - :param redis: The redis client - :type redis: StrictRedis - ''' - super(RedisMetricsCache, self).__init__() - self._redis = redis - - def _get_time_field(self, operation, bucket_index): - ''' - Builds the field name for a latency counting bucket ont the metrics - redis hash. - :param operation: Name of the operation - :type operation: str - :param bucket_index: Latency bucket index as returned by - get_latency_bucket_index - :type bucket_index: int - :return: Name of the field on the metrics hash for the latency bucket - counter - :rtype: str - ''' - return RedisMetricsCache._TIME_FIELD_TEMPLATE.format( - operation=operation, - bucket_index=bucket_index - ) - - def _get_all_buckets_time_fields(self, operation): - ''' - Builds a list of all the fields in the metrics hash for the latency - buckets for a given operation. - :param operation: Name of the operation - :type operation: str - :return: List of field names - :rtype: list - ''' - return [ - self._get_time_field(operation, bucket) - for bucket in range(0, len(BUCKETS)) - ] - - def _build_metrics_counter_data(self, count_metrics): - ''' - Build metrics counter data in the format expected by the API from the - contents of the cache. - :param count_metrics: A dictionary of name/value counter metrics - :param count_metrics: dict - :return: A list of of counter metrics - :rtype: list - ''' - return [{'name': name, 'delta': delta} - for name, delta in iteritems(count_metrics)] - - def _build_metrics_times_data(self, time_metrics): - ''' - Build metrics times data in the format expected by the API from the - contents of the cache. - :param time_metrics: A dictionary of name/latencies time metrics - :param time_metrics: dict - :return: A list of of time metrics - :rtype: list - ''' - return [{'name': name, 'latencies': latencies} - for name, latencies in iteritems(time_metrics)] - - def _build_metrics_gauge_data(self, gauge_metrics): - ''' - Build metrics gauge data in the format expected by the API from the - contents of the cache. - :param gauge_metrics: A dictionary of name/value gauge metrics - :param gauge_metrics: dict - :return: A list of of gauge metrics - :rtype: list - ''' - return [{'name': name, 'value': value} - for name, value in iteritems(gauge_metrics)] - - def _build_metrics_from_cache_response(self, response): - ''' - Builds a dictionary with time, count and gauge metrics based on the - result of calling fetch_all_and_clear (list of name/value pairs). - Each entry in the dictionary is in the format accepted by the events - API. 
- :param response: Response given by the fetch_all_and_clear method - :type response: lsit - :return: Dictionary with time, count and gauge metrics - :rtype: dict - ''' - if response is None: - return {'count': [], 'gauge': []} - - count = dict() - gauge = dict() - - for field, value in zip(islice(response, 0, None, 2), islice(response, 1, None, 2)): - count_match = re.match( - RedisMetricsCache._get_count_field_re(), field - ) - if count_match is not None: - count[count_match.group('counter')] = value - continue - - gauge_match = re.match( - RedisMetricsCache._get_gauge_field_re(), field - ) - if gauge_match is not None: - gauge[gauge_match.group('gauge')] = value - continue - - return { - 'count': self._build_metrics_counter_data(count), - 'gauge': self._build_metrics_gauge_data(gauge) - } - - def increment_count(self, counter, delta=1): - ''' - ''' - return self._redis.incr(self._get_count_key(counter), delta) - - def get_latency(self, operation): - return [ - 0 if count is None else count - for count in (self._redis.get(self._get_latency_bucket_key(operation, bucket)) - for bucket in range(0, len(BUCKETS)))] - - def get_latency_bucket_counter(self, operation, bucket_index): - count = self._redis.get( - self._get_latency_bucket_key(operation, bucket_index) - ) - return int(count) if count is not None else 0 - - def set_gauge(self, gauge, value): - return self._redis.set(self._get_gauge_key(gauge), value) - - def set_latency_bucket_counter(self, operation, bucket_index, value): - self._redis.set( - self._get_latency_bucket_key(operation, bucket_index), value - ) - - def get_count(self, counter): - count = self._redis.get(self._get_count_key(counter)) - return count if count else 0 - - def set_count(self, counter, value): - return self._redis.set(self._get_count_key(counter), value) - - def increment_latency_bucket_counter(self, operation, bucket_index, delta=1): - self._redis.incr( - self._get_latency_bucket_key(operation, bucket_index), - delta - ) - - def get_gauge(self, gauge): - return self._redis.get(self._get_gauge_key(gauge)) - - -class RedisSplitParser(SplitParser): - def __init__(self, segment_cache): - ''' - A SplitParser implementation that registers the segments with the redis - segment cache implementation upon parsing an IN_SEGMENT matcher. 
- ''' - super(RedisSplitParser, self).__init__(None) - self._segment_cache = segment_cache - - def _parse_split(self, split, block_until_ready=False): - return RedisSplit( - split['name'], split['seed'], split['killed'], - split['defaultTreatment'], split['trafficTypeName'], - split['status'], split['changeNumber'], - segment_cache=self._segment_cache, - algo=split.get('algo'), - traffic_allocation=split.get('trafficAllocation'), - traffic_allocation_seed=split.get('trafficAllocationSeed') - ) - - def _parse_matcher_in_segment(self, partial_split, matcher, - block_until_ready=False, *args, **kwargs): - matcher_data = self._get_matcher_attribute( - 'userDefinedSegmentMatcherData', - matcher - ) - segment = RedisSplitBasedSegment( - matcher_data['segmentName'], - partial_split - ) - delegate = UserDefinedSegmentMatcher(segment) - self._segment_cache.register_segment(delegate.segment.name) - return delegate - - -class RedisSplit(Split): - def __init__(self, name, seed, killed, default_treatment, traffic_type_name, - status, change_number, conditions=None, segment_cache=None, - algo=None, traffic_allocation=None, - traffic_allocation_seed=None): - ''' - A split implementation that mantains a reference to the segment cache - so segments can be easily pickled and unpickled. - :param name: Name of the feature - :type name: unicode - :param seed: Seed - :type seed: int - :param killed: Whether the split is killed or not - :type killed: bool - :param default_treatment: Default treatment for the split - :type default_treatment: str - :param conditions: Set of conditions to test - :type conditions: list - :param segment_cache: A segment cache - :type segment_cache: SegmentCache - ''' - super(RedisSplit, self).__init__( - name, seed, killed, default_treatment, traffic_type_name, status, - change_number, conditions, algo, traffic_allocation, - traffic_allocation_seed - ) - - self._segment_cache = segment_cache - - @property - def segment_cache(self): - return self._segment_cache - - @segment_cache.setter - def segment_cache(self, segment_cache): - self._segment_cache = segment_cache - - def __getstate__(self): - old_dict = self.__dict__.copy() - del old_dict['_segment_cache'] - return old_dict - - def __setstate__(self, dict): - self.__dict__.update(dict) - self._segment_cache = None - - -class RedisSplitBasedSegment(Segment): - def __init__(self, name, split): - ''' - A Segment that uses a reference to a RedisSplit redis' instance to check - if a key is in a segment - :param name: The name of the segment - :type name: str - :param split: A RedisSplit instance - :type split: RedisSplit - ''' - super(RedisSplitBasedSegment, self).__init__(name) - self._split = split - - def contains(self, key): - return self._split.segment_cache.is_in_segment(self.name, key) - - -def get_redis(config): - ''' - Build a redis client based on the configuration. - :param config: Dictionary with the contents of the config file. - :type config: dict - :return: A redis client - ''' - if 'redisFactory' in config: - redis_factory = import_from_string( - config['redisFactory'], 'redisFactory' - ) - return redis_factory() - else: - if 'redisSentinels' in config: - return default_redis_sentinel_factory(config) - else: - return default_redis_factory(config) - - -def default_redis_factory(config): - ''' - Default redis client factory. 
- :param config: A dict with the Redis configuration parameters - :type config: dict - :return: A StrictRedis object using the provided config values - :rtype: StrictRedis - ''' - host = config.get('redisHost', 'localhost') - port = config.get('redisPort', 6379) - db = config.get('redisDb', 0) - password = config.get('redisPassword', None) - socket_timeout = config.get('redisSocketTimeout', None) - socket_connect_timeout = config.get('redisSocketConnectTimeout', None) - socket_keepalive = config.get('redisSocketKeepalive', None) - socket_keepalive_options = config.get('redisSocketKeepaliveOptions', None) - connection_pool = config.get('redisConnectionPool', None) - unix_socket_path = config.get('redisUnixSocketPath', None) - encoding = config.get('redisEncoding', 'utf-8') - encoding_errors = config.get('redisEncodingErrors', 'strict') - charset = config.get('redisCharset', None) - errors = config.get('redisErrors', None) - decode_responses = config.get('redisDecodeResponses', False) - retry_on_timeout = config.get('redisRetryOnTimeout', False) - ssl = config.get('redisSsl', False) - ssl_keyfile = config.get('redisSslKeyfile', None) - ssl_certfile = config.get('redisSslCertfile', None) - ssl_cert_reqs = config.get('redisSslCertReqs', None) - ssl_ca_certs = config.get('redisSslCaCerts', None) - max_connections = config.get('redisMaxConnections', None) - prefix = config.get('redisPrefix') - - redis = StrictRedis( - host=host, - port=port, - db=db, - password=password, - socket_timeout=socket_timeout, - socket_connect_timeout=socket_connect_timeout, - socket_keepalive=socket_keepalive, - socket_keepalive_options=socket_keepalive_options, - connection_pool=connection_pool, - unix_socket_path=unix_socket_path, - encoding=encoding, - encoding_errors=encoding_errors, - charset=charset, - errors=errors, - decode_responses=decode_responses, - retry_on_timeout=retry_on_timeout, - ssl=ssl, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ssl_cert_reqs=ssl_cert_reqs, - ssl_ca_certs=ssl_ca_certs, - max_connections=max_connections - ) - return PrefixDecorator(redis, prefix=prefix) - - -def default_redis_sentinel_factory(config): - ''' - Default redis client factory for sentinel mode. 
- :param config: A dict with the Redis configuration parameters - :type config: dict - :return: A Sentinel object using the provided config values - :rtype: Sentinel - ''' - sentinels = config.get('redisSentinels') - - if (sentinels is None): - raise SentinelConfigurationException('redisSentinels must be specified.') - if (not isinstance(sentinels, list)): - raise SentinelConfigurationException('Sentinels must be an array of elements in the form of' - ' [(ip, port)].') - if (len(sentinels) == 0): - raise SentinelConfigurationException('It must be at least one sentinel.') - if not all(isinstance(s, tuple) for s in sentinels): - raise SentinelConfigurationException('Sentinels must respect the tuple structure' - '[(ip, port)].') - - master_service = config.get('redisMasterService') - - if (master_service is None): - raise SentinelConfigurationException('redisMasterService must be specified.') - - db = config.get('redisDb', 0) - password = config.get('redisPassword', None) - socket_timeout = config.get('redisSocketTimeout', None) - socket_connect_timeout = config.get('redisSocketConnectTimeout', None) - socket_keepalive = config.get('redisSocketKeepalive', None) - socket_keepalive_options = config.get('redisSocketKeepaliveOptions', None) - connection_pool = config.get('redisConnectionPool', None) - encoding = config.get('redisEncoding', 'utf-8') - encoding_errors = config.get('redisEncodingErrors', 'strict') - decode_responses = config.get('redisDecodeResponses', False) - retry_on_timeout = config.get('redisRetryOnTimeout', False) - max_connections = config.get('redisMaxConnections', None) - prefix = config.get('redisPrefix') - - sentinel = Sentinel( - sentinels, - db=db, - password=password, - socket_timeout=socket_timeout, - socket_connect_timeout=socket_connect_timeout, - socket_keepalive=socket_keepalive, - socket_keepalive_options=socket_keepalive_options, - connection_pool=connection_pool, - encoding=encoding, - encoding_errors=encoding_errors, - decode_responses=decode_responses, - retry_on_timeout=retry_on_timeout, - max_connections=max_connections - ) - - redis = sentinel.master_for(master_service) - return PrefixDecorator(redis, prefix=prefix) diff --git a/splitio/segments.py b/splitio/segments.py deleted file mode 100644 index baef0c05..00000000 --- a/splitio/segments.py +++ /dev/null @@ -1,458 +0,0 @@ -"""This module contains everything related to segments""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -from concurrent.futures import ThreadPoolExecutor -from json import load -from threading import Timer, RLock -import six - - -class Segment(object): - def __init__(self, name): - """ - Basic interface of a Segment. - :param name: The name of the segment - :type name: unicode - """ - self._name = name - - @property - def name(self): - """ - :return: The name of the segment - :rtype: unicode - """ - return self._name - - def contains(self, key): - """ - Tests whether a key is in a segment - :param key: The key to test - :type key: unicode - :return: True if the key is contained by the segment, False otherwise - :rtype: boolean - """ - return False - - -class InMemorySegment(Segment): - def __init__(self, name, change_number=-1, key_set=None): - """ - An implementation of a Segment that holds keys in set in memory. 
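The sentinel factory above rejects bad configurations early rather than letting `Sentinel` fail later. A standalone sketch of those guard clauses together with a config that passes them (the validator name is illustrative, and plain `ValueError` stands in for `SentinelConfigurationException`):

    def validate_sentinel_config(config):
        # Mirrors default_redis_sentinel_factory()'s guard clauses.
        sentinels = config.get('redisSentinels')
        if sentinels is None:
            raise ValueError('redisSentinels must be specified.')
        if not isinstance(sentinels, list) or not sentinels:
            raise ValueError('redisSentinels must be a non-empty list.')
        if not all(isinstance(s, tuple) for s in sentinels):
            raise ValueError('Each sentinel must be an (ip, port) tuple.')
        if config.get('redisMasterService') is None:
            raise ValueError('redisMasterService must be specified.')

    validate_sentinel_config({
        'redisSentinels': [('10.0.0.1', 26379), ('10.0.0.2', 26379)],
        'redisMasterService': 'mymaster',
    })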
- :param name: The name of the segment - :type name: str - :param change_number: The change number for the segment - :type name: int - :param key_set: Set of keys contained by the segment - :type key_set: list - """ - super(InMemorySegment, self).__init__(name) - self._change_number = change_number - self._key_set = frozenset(key_set) if key_set is not None else frozenset() - - @property - def key_set(self): - return self._key_set - - @key_set.setter - def key_set(self, key_set): - self._key_set = key_set - - @property - def change_number(self): - return self._change_number - - @change_number.setter - def change_number(self, change_number): - self._change_number = change_number - - def contains(self, key): - """ - Tests whether a key is in a segment - :param key: The key to test - :type key: unicode - :return: True if the key is contained by the segment, False otherwise - :rtype: boolean - """ - return key in self._key_set - - -class SegmentFetcher(object): - """Basic segment fetcher interface.""" - def fetch(self, name, block_until_ready=False): - """ - Fetches an empty segment - :param name: The segment name - :type name: unicode - :param block_until_ready: Whether to wait until all the data is available - :type block_until_ready: bool - :return: An empty segment - :rtype: Segment - """ - return Segment(name) - - -class SelfRefreshingSegmentFetcher(object): - def __init__(self, segment_change_fetcher, interval=60, max_workers=5): - """ - A segment fetcher that generates self refreshing segments. - :param segment_change_fetcher: A segment change fetcher implementation - :type segment_change_fetcher: SegmentChangeFetcher - :param interval: An integer or callable that'll define the refreshing interval - :type interval: int - :param max_workers: The max number of workers used to fetch segment changes - :type max_workers: int - """ - self._segment_change_fetcher = segment_change_fetcher - self._executor = ThreadPoolExecutor(max_workers=max_workers) - self._interval = interval - self._segments = dict() - self._destroyed = False - - def destroy(self): - self._destroyed = True - for _, segment in six.iteritems(self._segments): - segment.destroy() - - def fetch(self, name, block_until_ready=False): - """ - Fetch self refreshing segment - :param name: The name of the segment - :type name: unicode - :param block_until_ready: Whether to wait until all the data is available - :type block_until_ready: bool - :return: A segment for the given name - :rtype: Segment - """ - if name in self._segments: - return self._segments[name] - - segment = SelfRefreshingSegment(name, self._segment_change_fetcher, self._executor, - self._interval) - self._segments[name] = segment - - if block_until_ready: - segment.refresh_segment() - segment.start(delayed_update=True) - else: - segment.start() - - return segment - - -class SelfRefreshingSegment(InMemorySegment): - def __init__(self, name, segment_change_fetcher, executor, interval, change_number=-1, - greedy=True, key_set=None): - """ - A segment implementation that refreshes itself periodically using a ThreadPoolExecutor. 
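`InMemorySegment` reduces membership checks to a frozenset lookup, which also makes the key set cheap and safe to swap atomically from the refresh thread. The same idea in miniature (the class name is illustrative):

    class TinySegment(object):
        """Membership backed by an immutable frozenset, as in InMemorySegment."""

        def __init__(self, name, key_set=None):
            self._name = name
            self._key_set = frozenset(key_set) if key_set is not None else frozenset()

        def contains(self, key):
            return key in self._key_set

    segment = TinySegment('beta_users', ['user1', 'user2'])
    assert segment.contains('user1')
    assert not segment.contains('user3')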
- :param name: The name of the segment - :type name: str - :param segment_change_fetcher: The segment change fetcher implementation - :type segment_change_fetcher: SegmentChangeFetcher - :param executor: A ThreadPoolExecutor that'll run the refreshing process - :type executor: ThreadPoolExecutor - :param interval: An integer or callable that'll define the refreshing interval - :type interval: int - :param change_number: An integer with the initial value for the "since" API argument - :type change_number: int - :param greedy: Request all changes until they are exhausted - :type greedy: bool - :param key_set: An optional initial set of keys - :type key_set: list - """ - super(SelfRefreshingSegment, self).__init__(name, change_number=change_number, - key_set=key_set) - self._segment_change_fetcher = segment_change_fetcher - self._executor = executor - self._interval = interval - self._greedy = greedy - self._stopped = True - self._rlock = RLock() - self._logger = logging.getLogger(self.__class__.__name__) - self._destroyed = False - - @property - def stopped(self): - """ - :return: Whether the refresh process has been stopped - :rtype: bool - """ - return self._stopped - - @stopped.setter - def stopped(self, stopped): - """ - :param stopped: Whether to stop the refreshing process - :type stopped: bool - """ - self._stopped = stopped - - def destroy(self): - self._destroyed = True - - def start(self, delayed_update=False): - """Starts the self-refreshing processes of the segment. - :param delayed_update: Whether to delay the update until the interval has passed - :type delayed_update: bool - """ - if not self._stopped: - return - - self._stopped = False - - if delayed_update: - self._timer_start() - else: - self._timer_refresh() - - def refresh_segment(self): - """ - The actual segment refresh process. - """ - if self._destroyed: - return - - try: - with self._rlock: - while True: - response = self._segment_change_fetcher.fetch(self._name, - self._change_number) - - # If the response fails, and doesn't return a dict, or - # returns a dict without the 'till' attribute, abort this - # execution. - if ( - not isinstance(response, dict) - or 'till' not in response - or self._change_number >= response['till'] - ): - return - - if len(response['added']) > 0 or len(response['removed']) > 0: - self._logger.info('%s added %s', self._name, - self._summarize_changes(response['added'])) - self._logger.info('%s removed %s', self._name, - self._summarize_changes(response['removed'])) - - new_key_set = ( - (self._key_set | frozenset(response['added'])) - - frozenset(response['removed']) - ) - self._key_set = new_key_set - - self._change_number = response['till'] - - if not self._greedy: - return - except: - self._logger.error('Error refreshing segment') - self._stopped = True - - def _summarize_changes(self, changes): - """Summarize the changes received from the segment change fetcher.""" - return '[{summary}{others}]'.format( - summary=','.join(changes[:min(3, len(changes))]), - others=',... {} others'.format(3 - len(changes)) if len(changes) > 3 else '' - ) - - def _timer_start(self): - try: - if hasattr(self._interval, '__call__'): - interval = self._interval() - else: - interval = self._interval - - timer = Timer(interval, self._timer_refresh) - timer.daemon = True - timer.start() - except: - self._logger.error('Error starting timer') - self._stopped = True - - def _timer_refresh(self): - """ - Responsible for setting the periodic calls to _refresh_segment using a - Timer thread. 
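Each iteration of `refresh_segment` above reduces to one pure set operation plus a change-number bump, repeated until the backend reports no newer changes. That core step extracted as a standalone function (the name is illustrative; the payload shape matches the segmentChanges response documented later in this file):

    def apply_segment_change(key_set, response):
        # New membership = (current | added) - removed; cursor moves to 'till'.
        updated = (frozenset(key_set) | frozenset(response['added'])) \
            - frozenset(response['removed'])
        return updated, response['till']

    keys, till = apply_segment_change(
        {'some_id_1', 'some_id_2'},
        {'added': ['some_id_6'], 'removed': ['some_id_1'], 'till': 1460890700906},
    )
    assert keys == frozenset({'some_id_2', 'some_id_6'})
    assert till == 1460890700906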
- """ - if self._destroyed: - return - - if self._stopped: - self._logger.error('Previous fetch failed, skipping this iteration ' - 'and rescheduling segment refresh.') - self._stopped = False - self._timer_start() - return - - try: - self._executor.submit(self.refresh_segment) - self._timer_start() - except: - self._logger.error('Error refreshing timer') - self._stopped = True - - -class JSONFileSegmentFetcher(SegmentFetcher): - def __init__(self, file_name): - """ - A segment fetcher that retrieves the information from a file with the JSON response of a - segmentChanges resource. - :param file_name: The name of the file - :type file_name: str - """ - with open(file_name) as f: - self._json = load(f) - - self._added = frozenset(self._json['added']) - self._removed = frozenset(self._json['removed']) - - def fetch(self, name, block_until_ready=False): - """ - Fetch in memory segment - :param name: The name of the segment - :type name: str - :param block_until_ready: Whether to wait until all the data is available - :type block_until_ready: bool - :return: A segment for the given name - :rtype: Segment - """ - segment = InMemorySegment(name, key_set=self._added - self._removed) - return segment - - -class SegmentChangeFetcher(object): - def __init__(self): - """Fetches changes in the segment since a reference point.""" - self._logger = logging.getLogger(self.__class__.__name__) - - def fetch_from_backend(self, name, since): - """ - Fetches changes for a given segment. - :param name: The name of the segment - :type name: unicode - :param since: An integer that indicates that we want the changes that occurred AFTER this - last change number. A value less than zero implies that the client is - requesting information on this segment for the first time. - :type since: int - :return: A dictionary with the changes for the segment - :rtype: dict - """ - raise NotImplementedError() - - def build_empty_segment_change(self, name, since): - """ - Builds an "empty" segment change response. Used in case of exceptions or other unforseen - problems. - :param name: The name of the segment - :type name: unicode - :param since: "till" value of the last segment change. - :type since: int - :return: A dictionary with an empty (.e.g. no change) response - :rtype: dict - """ - return { - 'name': name, - 'since': since, - 'till': since, - 'added': [], - 'removed': [] - } - - def fetch(self, name, since): - """ - Fetch changes for a segment. If the segment does not exist, or if there were problems with - the request, the method returns an empty segment change with the latest change number set - to a value less than 0. - - If no changes have happened since the change number requested, then return an empty segment - change with the latest change number equal to the requested change number. - - This is a sample response: - - { - "name": "demo_segment", - "added": [ - "some_id_6" - ], - "removed": [ - "some_id_1", "some_id_2" - ], - "since": 1460890700905, - "till": 1460890700906 - } - - :param name: The name of the segment - :type name: unicode - :param since: An integer that indicates that we want the changes that occurred AFTER this - last change number. A value less than zero implies that the client is - requesting information on this segment for the first time. 
- :type since: int - :return: A dictionary with the changes - :rtype: dict - """ - if type(name).__name__ == 'bytes': - name = str(name, 'utf-8') - - try: - segment_change = self.fetch_from_backend(name, since) - except: - self._logger.error('Error fetching segment changes') - segment_change = self.build_empty_segment_change(name, since) - - return segment_change - - -class ApiSegmentChangeFetcher(SegmentChangeFetcher): - def __init__(self, api): - """ - A SegmentChangeFetcher implementation that retrieves the changes from Split.io's RESTful - SDK API. - :param api: The API client to use - :type api: SdkApi - """ - super(ApiSegmentChangeFetcher, self).__init__() - self._api = api - - def fetch_from_backend(self, name, since): - return self._api.segment_changes(name, since) - - -class CacheBasedSegmentFetcher(SegmentFetcher): - def __init__(self, segment_cache): - """ - A segment fetcher based on a segments cache - :param segment_cache: The segment cache to use - :type segment_cache: SegmentCache - """ - self._segment_cache = segment_cache - - def fetch(self, name, block_until_ready=False): - """ - Fetch cache based segment - :param name: The name of the segment - :type name: str - :param block_until_ready: Whether to wait until all the data is available - :type block_until_ready: bool - :return: A segment for the given name - :rtype: Segment - """ - segment = CacheBasedSegment(name, self._segment_cache) - return segment - - -class CacheBasedSegment(Segment): - def __init__(self, name, segment_cache): - """ - A SegmentCached based implementation of a Segment - :param name: The name of the segment - :type name: str - :param segment_cache: The segment cache backend - :type segment_cache: SegmentCache - """ - super(CacheBasedSegment, self).__init__(name) - self._segment_cache = segment_cache - - def contains(self, key): - return self._segment_cache.is_in_segment(self._name, key) diff --git a/splitio/splits.py b/splitio/splits.py deleted file mode 100644 index dad4171a..00000000 --- a/splitio/splits.py +++ /dev/null @@ -1,1197 +0,0 @@ -"""This module contains everything related to splits""" -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -import logging - -from builtins import dict -from enum import Enum -from json import load -from requests.exceptions import HTTPError -from threading import Thread, Timer, RLock -from collections import namedtuple - -from future.utils import python_2_unicode_compatible - -from splitio.matchers import CombiningMatcher, AndCombiner, AllKeysMatcher, \ - UserDefinedSegmentMatcher, WhitelistMatcher, EqualToMatcher, \ - GreaterThanOrEqualToMatcher, LessThanOrEqualToMatcher, BetweenMatcher, \ - AttributeMatcher, DataType, StartsWithMatcher, EndsWithMatcher, \ - ContainsStringMatcher, ContainsAllOfSetMatcher, ContainsAnyOfSetMatcher, \ - EqualToSetMatcher, PartOfSetMatcher, DependencyMatcher, RegexMatcher, \ - BooleanMatcher - -SplitView = namedtuple( - 'SplitView', - ['name', 'traffic_type', 'killed', 'treatments', 'change_number'] -) - - -class Status(Enum): - """Split status""" - ACTIVE = "ACTIVE" - ARCHIVED = "ARCHIVED" - - -class HashAlgorithm(Enum): - """ - Hash algorithm names - """ - LEGACY = 1 - MURMUR = 2 - - -class ConditionType(Enum): - """ - Split possible condition types - """ - WHITELIST = 'WHITELIST' - ROLLOUT = 'ROLLOUT' - - -class Split(object): - def __init__(self, name, seed, killed, default_treatment, traffic_type_name, - status, change_number, conditions=None, algo=None, - traffic_allocation=None, 
traffic_allocation_seed=None): - """ - A class that represents a split. It associates a feature name with a set - of matchers (responsible of telling which condition to use) and - conditions (which determines which treatment to use) - :param name: Name of the feature - :type name: unicode - :param seed: Seed - :type seed: int - :param killed: Whether the split is killed or not - :type killed: bool - :param default_treatment: Default treatment for the split - :type default_treatment: str - :param conditions: Set of conditions to test - :type conditions: list - """ - self._name = name - self._seed = seed - self._killed = killed - self._default_treatment = default_treatment - self._traffic_type_name = traffic_type_name - self._status = status - self._change_number = change_number - self._conditions = conditions if conditions is not None else [] - - if traffic_allocation is None: - self._traffic_allocation = 100 - elif traffic_allocation >= 0 and traffic_allocation <= 100: - self._traffic_allocation = traffic_allocation - else: - self._traffic_allocation = 100 - - self._traffic_allocation_seed = traffic_allocation_seed - try: - self._algo = HashAlgorithm(algo) - except ValueError: - self._algo = HashAlgorithm.LEGACY - - @property - def name(self): - return self._name - - @property - def seed(self): - return self._seed - - @property - def algo(self): - return self._algo - - @property - def killed(self): - return self._killed - - @property - def default_treatment(self): - return self._default_treatment - - @property - def traffic_type_name(self): - return self._traffic_type_name - - @property - def status(self): - return self._status - - @property - def change_number(self): - return self._change_number - - @property - def conditions(self): - return self._conditions - - @property - def traffic_allocation(self): - return self._traffic_allocation - - @property - def traffic_allocation_seed(self): - return self._traffic_allocation_seed - - @python_2_unicode_compatible - def __str__(self): - return 'name: {name}, seed: {seed}, killed: {killed}, ' \ - 'default treatment: {default_treatment}, ' \ - 'conditions: {conditions}'.format( - name=self._name, seed=self._seed, killed=self._killed, - default_treatment=self._default_treatment, - conditions=','.join(map(str, self._conditions)) - ) - - -class AllKeysSplit(Split): - def __init__(self, name, treatment): - """ - A split implementation that matches everything to a single treatment. - :param name: Name of the feature - :type name: str - :param treatment: The treatment for the feature - :type treatment: str - """ - super(AllKeysSplit, self).__init__( - name, None, False, treatment, None, None, None, - [Condition(AttributeMatcher(None, AllKeysMatcher(), False), - [Partition(treatment, 100)], - None)]) - - -class Condition(object): - def __init__(self, matcher, partitions, label, - condition_type=ConditionType.WHITELIST): - """ - A class that represents a split condition. It associates a matcher with - a set of partitions. 
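`Split.__init__` above quietly repairs bad `trafficAllocation` values instead of raising: `None` and out-of-range numbers both collapse to 100, and an unrecognized `algo` falls back to the legacy hash. The numeric rule in isolation (function name is illustrative):

    def normalize_traffic_allocation(value):
        # None or anything outside [0, 100] falls back to 100, as in Split.__init__.
        if value is None or not 0 <= value <= 100:
            return 100
        return value

    assert normalize_traffic_allocation(None) == 100
    assert normalize_traffic_allocation(150) == 100
    assert normalize_traffic_allocation(-1) == 100
    assert normalize_traffic_allocation(42) == 42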
- :param matcher: A combining matcher - :type matcher: CombiningMatcher - :param partitions: A list of partitions - :type partitions: list - """ - self._matcher = matcher - self._partitions = tuple(partitions) - self._label = label - self._confition_type = condition_type - - @property - def matcher(self): - return self._matcher - - @property - def partitions(self): - return self._partitions - - @property - def label(self): - return self._label - - @property - def condition_type(self): - return self._confition_type - - @python_2_unicode_compatible - def __str__(self): - return '{matcher} then split {partitions}'.format( - matcher=self._matcher, partitions=','.join( - '{size}:{treatment}'.format(size=partition.size, - treatment=partition.treatment) - for partition in self._partitions)) - - -class Partition(object): - def __init__(self, treatment, size): - """ - A class that represents a partition of a split condition - :param treatment: The treatment for the partition - :type treatment: str - :param size: A number between 0 a 100 - :type size: float - """ - if size < 0 or size > 100: - raise ValueError('size MUST BE between 0 and 100') - - self._treatment = treatment - self._size = size - - @property - def treatment(self): - return self._treatment - - @property - def size(self): - return self._size - - @python_2_unicode_compatible - def __str__(self): - return '{size}%:{treatment}'.format(size=self._size, - treatment=self._treatment) - - -class SplitFetcher(object): # pragma: no cover - def __init__(self): - """ - The interface for a SplitFetcher. - It provides access to Split implementations. - """ - self._logger = logging.getLogger(self.__class__.__name__) - self._destroyed = False - - @property - def change_number(self): - return None - - def fetch(self, feature): - """ - Fetches the split for a given feature - :param feature: The name of the feature - :type feature: str - :return: A split associated with the feature - :rtype: Split - """ - return None - - def fetch_all(self): - """ - Feches all splits - :return: All the know splits so far - :rtype: list - """ - return None - - def destroy(self): - self._destroyed = True - - -class InMemorySplitFetcher(SplitFetcher): - def __init__(self, splits=None): - """ - A basic implementation of a split fetcher. It's responsible for - providing access to the client to the Split representations. - :param splits: An optional dictionary of feature to split entries - :type splits: dict - """ - super(InMemorySplitFetcher, self).__init__() - self._splits = splits if splits is not None else dict() - - @property - def change_number(self): - return -1 - - def fetch(self, feature): - """ - Fetches the split for a given feature - :param feature: The name of the feature - :type feature: str - :return: A split associated with the feature - :rtype: Split - """ - if self._destroyed: - return None - - return self._splits.get(feature) - - def fetch_all(self): - """ - Feches all splits - :return: All the know splits so far - :rtype: list - """ - if self._destroyed: - return [] - else: - return list(self._splits.values()) - - -class JSONFileSplitFetcher(InMemorySplitFetcher): - def __init__(self, file_name, split_parser, splits=None): - """ - A split fetcher that gets the split information from a file with the - JSON response of a call - to the splitChanges resource. 
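The fetchers above all honor a `_destroyed` flag, so a torn-down client degrades to empty answers instead of touching dead resources. The pattern in miniature (class name is illustrative):

    class TinySplitFetcher(object):
        """Destroy-guard pattern used by InMemorySplitFetcher above."""

        def __init__(self, splits=None):
            self._splits = dict(splits) if splits is not None else {}
            self._destroyed = False

        def destroy(self):
            self._destroyed = True

        def fetch(self, feature):
            return None if self._destroyed else self._splits.get(feature)

    fetcher = TinySplitFetcher({'my_feature': object()})
    assert fetcher.fetch('my_feature') is not None
    fetcher.destroy()
    assert fetcher.fetch('my_feature') is None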
- :param file_name: Name of the file with the splitChanges response - :type file_name: str - :param split_parser: The parser used to parse the responses - :type split_parser: SplitParser - :param splits: An optional dictionary of feature to split entries - :type splits: dict - """ - super(JSONFileSplitFetcher, self).__init__(splits=splits) - - self._split_parser = split_parser - with open(file_name) as f: - self._json = load(f) - - for split_change in self._json['splits']: - parsed_split = self._split_parser.parse(split_change) - self._splits[parsed_split.name] = parsed_split - - -class SelfRefreshingSplitFetcher(InMemorySplitFetcher): - def __init__(self, split_change_fetcher, split_parser, interval=30, - greedy=True, change_number=-1, splits=None): - """ - A SplitFetcher implementation that refreshes itself periodically. - :param split_change_fetcher: The split change fetcher used to fetch - changes - :type split_change_fetcher: SplitChangeFetcher - :param split_parser: The split parser - :type split_parser: SplitParser - :param interval: An integer or callable that'll define the refreshing - interval - :type interval: int - :param greedy: Request all changes until they are exhausted - :type greedy: bool - :param change_number: An integer with the initial value for the "since" - API argument - :type change_number: int - :param splits: An optional dictionary of feature to split entries - :type splits: dict - """ - super(SelfRefreshingSplitFetcher, self).__init__(splits=splits) - - self._split_change_fetcher = split_change_fetcher - self._split_parser = split_parser - self._interval = interval - self._greedy = greedy - self._change_number = change_number - self._stopped = True - self._rlock = RLock() - - @property - def stopped(self): - """ - :return: Whether the refresh process has been stopped - :rtype: bool - """ - return self._stopped - - @stopped.setter - def stopped(self, stopped): - """ - :param stopped: Whether to stop the refreshing process - :type stopped: bool - """ - self._stopped = stopped - - @property - def change_number(self): - return self._change_number - - def destroy(self): - """ - Disables the split-refreshing task, by preventing it from being - re-scheduled. - """ - super(SelfRefreshingSplitFetcher, self).destroy() - self._split_parser.destroy() - - def start(self, delayed_update=False): - """Starts the self-refreshing processes of the splits - :param delayed_update: Whether to delay the update until the interval - has passed - :type delayed_update: bool - """ - if not self._stopped: - return - - self._stopped = False - - if delayed_update: - self._timer_start() - else: - self._timer_refresh() - - def _update_splits_from_change_fetcher_response(self, response, - block_until_ready=False): - """ - Updates the splits from the response of the split_change_fetcher - :param response: A JSON with the response of - split_change_fetcher.fetch() - :type response: dict - :param block_until_ready: Whether to block until all data is available - :param block_until_ready: bool - """ - added_features = [] - removed_features = [] - - for split_change in response['splits']: - if Status(split_change['status']) != Status.ACTIVE: - self._splits.pop(split_change['name'], None) - removed_features.append(split_change['name']) - continue - - parsed_split = self._split_parser.parse( - split_change, block_until_ready=block_until_ready - ) - if parsed_split is None: - self._logger.warning( - 'We could not parse the split definition for %s. 
' - 'Removing split to be safe.', split_change['name']) - self._splits.pop(split_change['name'], None) - removed_features.append(split_change['name']) - continue - - added_features.append(split_change['name']) - self._splits[split_change['name']] = parsed_split - - if len(added_features) > 0: - self._logger.info('Updated features: %s', added_features) - - if len(removed_features) > 0: - self._logger.info('Deleted features: %s', removed_features) - - def refresh_splits(self, block_until_ready=False): - """The actual split fetcher refresh process. - :param block_until_ready: Whether to block until all data is available - :param block_until_ready: bool - """ - if self._destroyed: - return - - change_number_before = self._change_number - - try: - with self._rlock: - while True: - response = self._split_change_fetcher.fetch( - self._change_number) - - # If the response fails, and doesn't return a dict, or - # returns a dict without the 'till' attribute, abort this - # execution. - if ( - not isinstance(response, dict) - or 'till' not in response - or self._change_number >= response['till'] - ): - return - - if 'splits' in response and len(response['splits']) > 0: - self._update_splits_from_change_fetcher_response( - response, block_until_ready=block_until_ready) - self._change_number = response['till'] - - if not self._greedy: - return - except: - self._logger.info('Error refreshing splits') - self._stopped = True - finally: - self._logger.info('split fetch before: %s, after: %s', - change_number_before, - self._change_number) - - def _timer_start(self): - try: - if hasattr(self._interval, '__call__'): - interval = self._interval() - else: - interval = self._interval - - timer = Timer(interval, self._timer_refresh) - timer.daemon = True - timer.start() - except: - self._logger.error('Error refreshing timer') - self._stopped = True - - def _timer_refresh(self): - """ - Responsible for setting the periodic calls to _refresh_splits using a - Timer thread - """ - if self._destroyed: - return - - if self._stopped: - self._logger.error('Previous fetch failed, skipping this iteration ' - 'and rescheduling segment refresh.') - self._stopped = False - self._timer_start() - return - - try: - thread = Thread(target=self.refresh_splits) - thread.daemon = True - thread.start() - except: - self._logger.error('Error starting splits update thread') - - self._timer_start() - - -class SplitChangeFetcher(object): - def __init__(self): - """Fetches changes in splits since a reference point.""" - self._logger = logging.getLogger(self.__class__.__name__) - - def fetch_from_backend(self, since): - """ - Fetches changes in splits since a reference point. - :param since: An integer that indicates that we want the changes that - occurred AFTER this last change number. A value less than zero implies - that the client is requesting information for the first time. - :type since: int - :return: A dictionary with the changes for splits - :rtype: dict - """ - raise NotImplementedError() - - def build_empty_response(self, since): - """ - Builds an "empty" split change response. Used in case of exceptions or - other unforseen problems. - :param since: "till" value of the last split change. - :type since: int - :return: A dictionary with an empty (.e.g. no change) response - :rtype: dict - """ - return { - 'since': since, - 'till': since, - 'splits': [] - } - - def fetch(self, since): - """ - Fetches changes in splits since a reference point. - - This method never raises exceptions or returns None. 
- - The list of "splits" in the returned value contains at most one split - per name. If multiple changes occurred between the requested and last - change numbers, only the latest change is returned. - - If no changes occurred, we return an empty list of changes with the same - change number as the one requested. - - If the client is asking for split changes for the first time, only the - active partitions are returned. - - :param since: An integer that indicates that we want the changes that - occurred AFTER this last change number. A value less than zero - implies that the client is requesting information for the first - time. - :type since: int - :return: A dictionary with the changes for splits - :rtype: dict - """ - try: - split_change = self.fetch_from_backend(since) - except: - self._logger.error('Error fetching split changes') - split_change = self.build_empty_response(since) - - return split_change - - -class ApiSplitChangeFetcher(SplitChangeFetcher): - def __init__(self, api): - """ - A SplitChangeFetcher implementation that retrieves the changes from - Split.io's RESTful SDK API. - :param api: The API client to use - :type api: SdkApi - """ - super(ApiSplitChangeFetcher, self).__init__() - self._api = api - - def fetch_from_backend(self, since): - try: - return self._api.split_changes(since) - except HTTPError as e: - # We handle HTTP error here to allow for status code metrics once - # they're implemented - raise e - - -class SplitParser(object): - def __init__(self, segment_fetcher): - """ - A parser for the response of the splitChanges. - :param segment_fetcher: The segment fetcher to use with user segment - conditions - :type segment_fetcher: SegmentFetcher - """ - self._logger = logging.getLogger(self.__class__.__name__) - self._segment_fetcher = segment_fetcher - self._destroyed = False - - def parse(self, split, block_until_ready=False): - """ - Parse a "split" item of the response of the splitChanges endpoint. - - If the split is archived, this method returns None/ This method will - never raise an exception. If theres a problem with the process, it'll - return None. - :param split: A dictionary with a parsed JSON of a split item - :type split: dict - :param block_until_ready: Whether to block until all data is available - :type block_until_ready: bool - :return: A parsed split - :rtype: Split - """ - try: - return self._parse(split, block_until_ready=block_until_ready) - except: - self._logger.error('Error parsing split') - return None - - def _parse(self, split, block_until_ready=False): - """ - Parse a "split" item of the response of the splitChanges endpoint. - :param split: A dictionary with a parsed JSON of a split item - :type split: dict - :param block_until_ready: Whether to block until all data is available - :type block_until_ready: bool - :return: A parsed split - :rtype: Split - """ - if Status[split['status']] != Status.ACTIVE: - return None - - partial_split = self._parse_split( - split, block_until_ready=block_until_ready - ) - self._parse_conditions( - partial_split, split, block_until_ready=block_until_ready - ) - - return partial_split - - def _parse_split(self, split, block_until_ready=False): - """Parse split properties. 
- :param split: A dictionary with a parsed JSON of a split item - :type split: dict - :param block_until_ready: Whether to block until all data is available - :type block_until_ready: bool - :return: A partial parsed split - :rtype: Split - """ - return Split( - split['name'], - split['seed'], - split['killed'], - split['defaultTreatment'], - split['trafficTypeName'], - split['status'], - split['changeNumber'], - algo=split.get('algo'), - traffic_allocation=split.get('trafficAllocation'), - traffic_allocation_seed=split.get('trafficAllocationSeed') - ) - - def _parse_conditions(self, partial_split, split, block_until_ready=False): - """Parse split conditions - :param partial_split: The partially parsed split - :param partial_split: Split - :param split: A dictionary with a parsed JSON of a split item - :type split: dict - :param block_until_ready: Whether to block until all data is available - :type block_until_ready: bool - :return: - """ - for condition in split['conditions']: - parsed_partitions = [ - Partition(partition['treatment'], partition['size']) - for partition in condition['partitions'] - ] - combining_matcher = self._parse_matcher_group( - partial_split, condition['matcherGroup'], - block_until_ready=block_until_ready - ) - label = None - if 'label' in condition: - label = condition['label'] - - try: - condition_type = ConditionType(condition.get('conditionType')) - except: - condition_type = ConditionType.WHITELIST - - partial_split.conditions.append( - Condition( - combining_matcher, - parsed_partitions, - label, - condition_type - ) - ) - - def _parse_matcher_group(self, partial_split, matcher_group, - block_until_ready=False): - """ - Parses a matcher group - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher_group: A list of dictionaries with the JSON - representation of a matcher - :type matcher_group: list - :param block_until_ready: Whether to block until all data is available - :type block_until_ready: bool - :return: A combining matcher - :rtype: CombiningMatcher - """ - if ('matchers' not in matcher_group or - len(matcher_group['matchers']) == 0): - raise ValueError('Missing or empty matchers') - - delegates = [self._parse_matcher(partial_split, matcher, - block_until_ready=block_until_ready) - for matcher in matcher_group['matchers']] - combiner = self._parse_combiner(matcher_group['combiner']) - - return CombiningMatcher(combiner, delegates) - - def _get_matcher_attribute(self, attribute, matcher): - """ - Validates the presence of an attribute on a matcher dictionarry and - returns its value. 
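`_parse_conditions` above treats any missing or unknown `conditionType` as `WHITELIST`, so older splitChanges payloads keep parsing. That fallback in isolation, with a local copy of the enum for self-containment:

    from enum import Enum

    class ConditionType(Enum):
        WHITELIST = 'WHITELIST'
        ROLLOUT = 'ROLLOUT'

    def parse_condition_type(raw):
        # Unknown/missing values collapse to WHITELIST, as in _parse_conditions().
        try:
            return ConditionType(raw)
        except ValueError:
            return ConditionType.WHITELIST

    assert parse_condition_type('ROLLOUT') is ConditionType.ROLLOUT
    assert parse_condition_type(None) is ConditionType.WHITELIST
    assert parse_condition_type('SOMETHING_NEW') is ConditionType.WHITELIST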
- :param attribute: The name of the attribute - :type attribute: str - :param matcher: A dictionary with the JSON representation of a matcher - :type matcher: dict - :return: The value of matcher[attribute] - :rtype: object - """ - if attribute not in matcher or matcher[attribute] is None: - raise ValueError( - 'Null or missing matcher attribute {}'.format(attribute) - ) - - return matcher[attribute] - - def _get_matcher_data_data_type(self, matcher_data): - """ - Gets the data type for a matcher data dictionary - :param matcher_data: A dictionary with the JSON representation of a - matcher data - :type matcher_data: dict - :return: The data type associated with the matcher data - :rtype: DataType - """ - try: - return DataType[matcher_data.get('dataType', None)] - except KeyError: - raise ValueError('Invalid data type value: {}'.format( - matcher_data.get('dataType', None) - )) - - def _parse_combiner(self, combiner, *args, **kwargs): - """ - Parses a combiner - :param combiner: The identifier of a combiner - :type combiner: str - :return: The combiner associated with the identifier - :rtype: AndCombiner - """ - if combiner == 'AND': - return AndCombiner() - - raise ValueError('Invalid combiner type: {}'.format(combiner)) - - def _parse_matcher_all_keys(self, partial_split, matcher, *args, **kwargs): - """ - Parses an ALL_KEYS matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an - ALL_KEYS matcher - :type matcher: dict - :return: The parsed matcher - :rtype: AllKeysMatcher - """ - delegate = AllKeysMatcher() - return delegate - - def _parse_matcher_in_segment(self, partial_split, matcher, - block_until_ready=False, *args, **kwargs): - """ - Parses an IN_SEGMENT matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an - IN_SEGMENT matcher - :type matcher: dict - :return: The parsed matcher - :rtype: UserDefinedSegmentMatcher - """ - matcher_data = self._get_matcher_attribute( - 'userDefinedSegmentMatcherData', matcher - ) - segment = self._segment_fetcher.fetch( - matcher_data['segmentName'], block_until_ready=block_until_ready - ) - delegate = UserDefinedSegmentMatcher(segment) - return delegate - - def _parse_matcher_whitelist(self, partial_split, matcher, *args, **kwargs): - """ - Parses a WHITELIST matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - WHITELIST matcher - :type matcher: dict - :return: The parsed matcher - :rtype: WhitelistMatcher - """ - matcher_data = self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = WhitelistMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_equal_to(self, partial_split, matcher, *args, **kwargs): - """ - Parses an EQUAL_TO matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an EQUAL_TO - matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: EqualToMatcher - """ - matcher_data = self._get_matcher_attribute( - 'unaryNumericMatcherData', matcher - ) - data_type = self._get_matcher_data_data_type(matcher_data) - - delegate = EqualToMatcher.for_data_type( - data_type, matcher_data.get('value', None) - ) - return delegate - - def 
_parse_matcher_greater_than_or_equal_to(self, partial_split, matcher, - *args, **kwargs): - """ - Parses a GREATER_THAN_OR_EQUAL_TO matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - GREATER_THAN_OR_EQUAL_TO - matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: GreaterThanOrEqualToMatcher - """ - matcher_data = self._get_matcher_attribute( - 'unaryNumericMatcherData', matcher - ) - data_type = self._get_matcher_data_data_type(matcher_data) - - delegate = GreaterThanOrEqualToMatcher.for_data_type( - data_type, matcher_data.get('value', None) - ) - return delegate - - def _parse_matcher_less_than_or_equal_to(self, partial_split, matcher, - *args, **kwargs): - """ - Parses a LESS_THAN_OR_EQUAL_TO matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an - LESS_THAN_OR_EQUAL_TO matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: LessThanOrEqualToMatcher - """ - matcher_data = self._get_matcher_attribute( - 'unaryNumericMatcherData', matcher - ) - data_type = self._get_matcher_data_data_type(matcher_data) - - delegate = LessThanOrEqualToMatcher.for_data_type(data_type, - matcher_data['value']) - return delegate - - def _parse_matcher_starts_with(self, partial_split, matcher, *args, - **kwargs): - """ - Parses a STARTS_WITH matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - STARTS_WITH matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = StartsWithMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_ends_with(self, partial_split, matcher, *args, - **kwargs): - """ - Parses a ENDS_WITH matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - ENDS_WITH matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = EndsWithMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_contains_string(self, partial_split, matcher, *args, - **kwargs): - """ - Parses a CONTAINS_STRING matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - CONTAINS_STRING matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = ContainsStringMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_contains_all_of_set(self, partial_split, matcher, *args, - **kwargs): - """ - Parses a CONTAINS_ALL_OF_SET matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - CONTAINS_ALL_OF_SET matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = 
self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = ContainsAllOfSetMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_contains_any_of_set(self, partial_split, matcher, *args, - **kwargs): - """ - Parses a CONTAINS_ANY_OF_SET matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - CONTAINS_ANY_OF_SET matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = ContainsAnyOfSetMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_equal_to_set(self, partial_split, matcher, *args, - **kwargs): - """ - Parses a EQUAL_TO_SET matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - EQUAL_TO_SET matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = EqualToSetMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_part_of_set(self, partial_split, matcher, *args, - **kwargs): - """ - Parses a PART_OF_SET matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a - PART_OF_SET matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'whitelistMatcherData', - matcher - ) - delegate = PartOfSetMatcher(matcher_data['whitelist']) - return delegate - - def _parse_matcher_between(self, partial_split, matcher, *args, **kwargs): - """ - Parses a BETWEEN matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an BETWEEN - matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'betweenMatcherData', matcher - ) - data_type = self._get_matcher_data_data_type(matcher_data) - - delegate = BetweenMatcher.for_data_type(data_type, - matcher_data.get('start', None), - matcher_data.get('end', None)) - return delegate - - def _parse_matcher_in_split_treatment(self, partial_split, matcher, *args, **kwargs): - """ - Parses an IN_SPLIT_TREATMENT matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an BETWEEN - matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'dependencyMatcherData', matcher - ) - - delegate = DependencyMatcher(matcher_data) - return delegate - - def _parse_matcher_equal_to_boolean(self, partial_split, matcher, *args, **kwargs): - """ - Parses an EQUAL_TO_BOOLEAN matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an BETWEEN - matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 
'booleanMatcherData', matcher - ) - - delegate = BooleanMatcher(matcher_data) - return delegate - - def _parse_matcher_matches_string(self, partial_split, matcher, *args, **kwargs): - """ - Parses an MATCHER_STRING matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of an BETWEEN - matcher - :type matcher: dict - :return: The parsed matcher (dependent on data type) - :rtype: BetweenMatcher - """ - matcher_data = self._get_matcher_attribute( - 'stringMatcherData', matcher - ) - - delegate = RegexMatcher(matcher_data) - return delegate - - def _parse_matcher(self, partial_split, matcher, block_until_ready=False): - """ - Parses a matcher - :param partial_split: The partially parsed split - :param partial_split: Split - :param matcher: A dictionary with the JSON representation of a matcher - :type matcher: dict - :param block_until_ready: Whether to block until all data is available - :type block_until_ready: bool - :return: A parsed attribute matcher (with a delegate dependent on type) - :rtype: AttributeMatcher - """ - if 'matcherType' not in matcher or matcher['matcherType'] is None: - raise ValueError('Missing matcher type value') - - matcher_type = matcher['matcherType'] - - try: - matcher_parse_method = getattr( - self, '_parse_matcher_{}'.format(matcher_type.strip().lower())) - delegate = matcher_parse_method(partial_split, matcher, - block_until_ready=block_until_ready) - except AttributeError: - raise ValueError('Invalid matcher type: {}'.format(matcher_type)) - - if delegate is None: - raise ValueError( - 'Unable to create matcher for matcher type: {}' - .format(matcher_type) - ) - - attribute = None - if 'keySelector' in matcher and matcher['keySelector'] and \ - 'attribute' in matcher['keySelector']: - attribute = matcher['keySelector']['attribute'] - - return AttributeMatcher( - attribute, delegate, matcher.get('negate', False) - ) - - def destroy(self): - self._segment_fetcher.destroy() - - -class CacheBasedSplitFetcher(SplitFetcher): - def __init__(self, split_cache): - """ - A cache based SplitFetcher implementation - :param split_cache: The split cache - :type split_cache: SplitCache - """ - super(CacheBasedSplitFetcher, self).__init__() - - self._split_cache = split_cache - - def fetch(self, feature): - if self._destroyed: - return None - - return self._split_cache.get_split(feature) - - def fetch_all(self): - """ - Feches all splits - :return: All the know splits so far - :rtype: list - """ - if self._destroyed: - return [] - else: - return self._split_cache.get_splits() - - @property - def change_number(self): - return self._split_cache.get_change_number() diff --git a/splitio/storage/__init__.py b/splitio/storage/__init__.py new file mode 100644 index 00000000..21204cf8 --- /dev/null +++ b/splitio/storage/__init__.py @@ -0,0 +1,293 @@ +"""Base storage interfaces.""" +from __future__ import absolute_import + +import abc + +class SplitStorage(object): + """Split storage interface implemented as an abstract class.""" + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def get(self, split_name): + """ + Retrieve a split. + + :param split_name: Name of the feature to fetch. + :type split_name: str + + :rtype: str + """ + pass + + @abc.abstractmethod + def put(self, split): + """ + Store a split. 
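`_parse_matcher`, shown earlier in this hunk, routes each `matcherType` to a `_parse_matcher_<type>` method via `getattr`, turning unknown types into `ValueError`s instead of leaking `AttributeError`s. The dispatch skeleton on its own (class and method names are illustrative):

    class MatcherDispatcher(object):
        """Name-based dispatch in the style of SplitParser._parse_matcher."""

        def _parse_matcher_all_keys(self, matcher):
            return 'all-keys matcher'

        def parse(self, matcher):
            matcher_type = matcher['matcherType']
            try:
                # 'ALL_KEYS' -> '_parse_matcher_all_keys', and so on.
                method = getattr(
                    self, '_parse_matcher_{}'.format(matcher_type.strip().lower())
                )
            except AttributeError:
                raise ValueError('Invalid matcher type: {}'.format(matcher_type))
            return method(matcher)

    assert MatcherDispatcher().parse({'matcherType': ' ALL_KEYS '}) == 'all-keys matcher'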
+
+        :param split: Split object to store
+        :type split: splitio.models.splits.Split
+        """
+        pass
+
+    @abc.abstractmethod
+    def remove(self, split_name):
+        """
+        Remove a split from storage.
+
+        :param split_name: Name of the feature to remove.
+        :type split_name: str
+
+        :return: True if the split was found and removed. False otherwise.
+        :rtype: bool
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_change_number(self):
+        """
+        Retrieve latest split change number.
+
+        :rtype: int
+        """
+        pass
+
+    @abc.abstractmethod
+    def set_change_number(self, new_change_number):
+        """
+        Set the latest change number.
+
+        :param new_change_number: New change number.
+        :type new_change_number: int
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_split_names(self):
+        """
+        Retrieve a list of all split names.
+
+        :return: List of split names.
+        :rtype: list(str)
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_all_splits(self):
+        """
+        Return all the splits.
+
+        :return: List of all the splits.
+        :rtype: list
+        """
+        pass
+
+    def get_segment_names(self):
+        """
+        Return a set of all segments referenced by splits in storage.
+
+        :return: Set of all segment names.
+        :rtype: set(string)
+        """
+        return set([name for spl in self.get_all_splits() for name in spl.get_segment_names()])
+
+
+class SegmentStorage(object):
+    """Segment storage interface implemented as an abstract class."""
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def get(self, segment_name):
+        """
+        Retrieve a segment.
+
+        :param segment_name: Name of the segment to fetch.
+        :type segment_name: str
+
+        :rtype: splitio.models.segments.Segment
+        """
+        pass
+
+    @abc.abstractmethod
+    def put(self, segment):
+        """
+        Store a segment.
+
+        :param segment: Segment to store.
+        :type segment: splitio.models.segments.Segment
+        """
+        pass
+
+    @abc.abstractmethod
+    def update(self, segment_name, to_add, to_remove, change_number=None):
+        """
+        Update a segment's members.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+        :param to_add: List of members to add to the segment.
+        :type to_add: list
+        :param to_remove: List of members to remove from the segment.
+        :type to_remove: list
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_change_number(self, segment_name):
+        """
+        Retrieve latest change number for a segment.
+
+        :param segment_name: Name of the segment.
+        :type segment_name: str
+
+        :rtype: int
+        """
+        pass
+
+    @abc.abstractmethod
+    def set_change_number(self, segment_name, new_change_number):
+        """
+        Set the latest change number.
+
+        :param segment_name: Name of the segment.
+        :type segment_name: str
+        :param new_change_number: New change number.
+        :type new_change_number: int
+        """
+        pass
+
+    @abc.abstractmethod
+    def segment_contains(self, segment_name, key):
+        """
+        Check whether a specific key belongs to a segment in storage.
+
+        :param segment_name: Name of the segment to search in.
+        :type segment_name: str
+        :param key: Key to search for.
+        :type key: str
+
+        :return: True if the segment contains the key. False otherwise.
+        :rtype: bool
+        """
+        pass
+
+
+class ImpressionStorage(object):
+    """Impressions storage interface."""
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def put(self, impressions):
+        """
+        Put one or more impressions in storage.
+
+        :param impressions: List of one or more impressions to store.
+        :type impressions: list
+        """
+        pass
+
+    @abc.abstractmethod
+    def pop_many(self, count):
+        """
+        Pop the oldest N impressions from storage.
+
+        :param count: Number of impressions to pop.
+        :type count: int
+        """
+        pass
+
+
+class EventStorage(object):
+    """Events storage interface."""
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def put(self, events):
+        """
+        Put one or more events in storage.
+
+        :param events: List of one or more events to store.
+        :type events: list
+        """
+        pass
+
+    @abc.abstractmethod
+    def pop_many(self, count):
+        """
+        Pop the oldest N events from storage.
+
+        :param count: Number of events to pop.
+        :type count: int
+        """
+        pass
+
+
+class TelemetryStorage(object):
+    """Telemetry storage interface."""
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def inc_latency(self, name, bucket):
+        """
+        Add a latency.
+
+        :param name: Name of the latency metric.
+        :type name: str
+        :param bucket: Latency bucket to increment.
+        :type bucket: int
+        """
+        pass
+
+    @abc.abstractmethod
+    def inc_counter(self, name):
+        """
+        Increment a counter.
+
+        :param name: Name of the counter metric.
+        :type name: str
+        """
+        pass
+
+    @abc.abstractmethod
+    def put_gauge(self, name, value):
+        """
+        Add a gauge metric.
+
+        :param name: Name of the gauge metric.
+        :type name: str
+        :param value: Value of the gauge metric.
+        :type value: int
+        """
+        pass
+
+    @abc.abstractmethod
+    def pop_counters(self):
+        """
+        Get all the counters.
+
+        :rtype: list
+        """
+        pass
+
+    @abc.abstractmethod
+    def pop_gauges(self):
+        """
+        Get all the gauges.
+
+        :rtype: list
+
+        """
+        pass
+
+    @abc.abstractmethod
+    def pop_latencies(self):
+        """
+        Get all latencies.
+
+        :rtype: list
+        """
+        pass
diff --git a/splitio/storage/adapters/__init__.py b/splitio/storage/adapters/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py
new file mode 100644
index 00000000..c9345fc6
--- /dev/null
+++ b/splitio/storage/adapters/redis.py
@@ -0,0 +1,410 @@
+"""Redis client wrapper with prefix support."""
+from __future__ import absolute_import, division, print_function, \
+    unicode_literals
+
+from builtins import str
+from six import string_types, binary_type
+from splitio.exceptions import SentinelConfigurationException
+
+try:
+    from redis import StrictRedis
+    from redis.sentinel import Sentinel
+    from redis.exceptions import RedisError
+except ImportError:
+    def missing_redis_dependencies(*_, **__):
+        """Fail if missing dependencies are used."""
+        raise NotImplementedError(
+            'Missing Redis support dependencies. '
+            'Please use `pip install splitio_client[redis]` to install the sdk with redis support'
+        )
+    StrictRedis = Sentinel = missing_redis_dependencies
+
+
+def _bytes_to_string(maybe_bytes, encode='utf-8'):
+    if type(maybe_bytes).__name__ == 'bytes':
+        return str(maybe_bytes, encode)
+    return maybe_bytes
+
+
+class RedisAdapterException(Exception):
+    """Exception raised when a redis command fails."""
+
+    def __init__(self, message, original_exception=None):
+        """
+        Exception constructor.
+
+        :param message: Custom exception message.
+        :type message: str
+        :param original_exception: Original exception object.
+        :type original_exception: Exception
+        """
+        Exception.__init__(self, message)
+        self._original_exception = original_exception
+
+    @property
+    def original_exception(self):
+        """Return original exception."""
+        return self._original_exception
+
+
+class RedisAdapter(object):
+    """
+    Instance decorator for Redis clients such as StrictRedis.
+
+    Adds an extra layer that handles addition/removal of the user prefix when
+    dealing with keys.
+    """
+
+    def __init__(self, decorated, prefix=None):
+        """
+        Store the user prefix and the redis client instance.
+
+        :param decorated: Instance of redis cache client to decorate.
+        :param prefix: User prefix to add.
+        """
+        self._prefix = prefix
+        self._decorated = decorated
+
+    def _add_prefix(self, k):
+        """
+        Add a prefix to the contents of k.
+
+        'k' may be:
+        - a single key (of type string or unicode in Python 2, or string in
+          Python 3), in which case the prefix is prepended with a dot.
+        - a list, in which case the prefix is applied to each element.
+        If no user prefix is stored, the key/list of keys is returned as-is.
+
+        :param k: single key (string) or list of keys (list).
+        :returns: Key(s) with prefix if applicable.
+        """
+        if self._prefix:
+            if isinstance(k, string_types):
+                return '{prefix}.{key}'.format(prefix=self._prefix, key=k)
+            elif isinstance(k, list) and k:
+                if isinstance(k[0], binary_type):
+                    return [
+                        '{prefix}.{key}'.format(prefix=self._prefix, key=key.decode("utf8"))
+                        for key in k
+                    ]
+                elif isinstance(k[0], string_types):
+                    return [
+                        '{prefix}.{key}'.format(prefix=self._prefix, key=key)
+                        for key in k
+                    ]
+        else:
+            return k
+
+        raise RedisAdapterException(
+            "Cannot append prefix correctly. Wrong type for key(s) provided"
+        )
+
+    def _remove_prefix(self, k):
+        """
+        Remove the user prefix from a key before handing it back to the requester.
+
+        Similar to _add_prefix, this method handles single strings as well as
+        lists. If no _prefix is set, the original key/keys will be returned.
+
+        :param k: key(s) whose prefix will be removed.
+        :returns: prefix-less key(s)
+        """
+        if self._prefix:
+            if isinstance(k, string_types):
+                return k[len(self._prefix)+1:]
+            elif isinstance(k, list):
+                return [key[len(self._prefix)+1:] for key in k]
+        else:
+            return k
+
+        raise RedisAdapterException(
+            "Cannot remove prefix correctly. Wrong type for key(s) provided"
+        )
+
+    # Below starts a list of methods that implement the interface of a standard
+    # redis client.
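+    #
+    # A minimal usage sketch (illustrative only; 'myapp' and the key names are
+    # made-up values, not part of the SDK's contract):
+    #
+    #   adapter = RedisAdapter(StrictRedis(), prefix='myapp')
+    #   adapter.set('SPLITIO.splits.till', 1234)  # stored as 'myapp.SPLITIO.splits.till'
+    #   adapter.keys('SPLITIO.split.*')           # queried as 'myapp.SPLITIO.split.*',
+    #                                             # results returned without the prefix.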
+ + def keys(self, pattern): + """Mimic original redis function but using user custom prefix.""" + try: + return _bytes_to_string(self._remove_prefix( + self._decorated.keys(self._add_prefix(pattern)) + )) + except RedisError as exc: + raise RedisAdapterException('Failed to execute keys operation', exc) + + def set(self, name, value, *args, **kwargs): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.set( + self._add_prefix(name), value, *args, **kwargs + ) + except RedisError as exc: + raise RedisAdapterException('Failed to execute set operation', exc) + + def get(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return _bytes_to_string(self._decorated.get(self._add_prefix(name))) + except RedisError as exc: + raise RedisAdapterException('Error executing get operation', exc) + + def setex(self, name, time, value): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.setex(self._add_prefix(name), time, value) + except RedisError as exc: + raise RedisAdapterException('Error executing setex operation', exc) + + def delete(self, names): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.delete(self._add_prefix(names)) + except RedisError as exc: + raise RedisAdapterException('Error executing delete operation', exc) + + def exists(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.exists(self._add_prefix(name)) + except RedisError as exc: + raise RedisAdapterException('Error executing exists operation', exc) + + def mget(self, names): + """Mimic original redis function but using user custom prefix.""" + try: + return _bytes_to_string(self._decorated.mget(self._add_prefix(names))) + except RedisError as exc: + raise RedisAdapterException('Error executing mget operation', exc) + + def smembers(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return [ + _bytes_to_string(item) + for item in self._decorated.smembers(self._add_prefix(name)) + ] + except RedisError as exc: + raise RedisAdapterException('Error executing smembers operation', exc) + + def sadd(self, name, *values): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.sadd(self._add_prefix(name), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing sadd operation', exc) + + def srem(self, name, *values): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.srem(self._add_prefix(name), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing srem operation', exc) + + def sismember(self, name, value): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.sismember(self._add_prefix(name), value) + except RedisError as exc: + raise RedisAdapterException('Error executing sismember operation', exc) + + def eval(self, script, number_of_keys, *keys): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.eval(script, number_of_keys, *self._add_prefix(list(keys))) + except RedisError as exc: + raise RedisAdapterException('Error executing eval operation', exc) + + def hset(self, name, key, value): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.hset(self._add_prefix(name), 
key, value) + except RedisError as exc: + raise RedisAdapterException('Error executing hset operation', exc) + + def hget(self, name, key): + """Mimic original redis function but using user custom prefix.""" + try: + return _bytes_to_string(self._decorated.hget(self._add_prefix(name), key)) + except RedisError as exc: + raise RedisAdapterException('Error executing hget operation', exc) + + def incr(self, name, amount=1): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.incr(self._add_prefix(name), amount) + except RedisError as exc: + raise RedisAdapterException('Error executing incr operation', exc) + + def getset(self, name, value): + """Mimic original redis function but using user custom prefix.""" + try: + return _bytes_to_string(self._decorated.getset(self._add_prefix(name), value)) + except RedisError as exc: + raise RedisAdapterException('Error executing getset operation', exc) + + def rpush(self, key, *values): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.rpush(self._add_prefix(key), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing rpush operation', exc) + + def expire(self, key, value): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.expire(self._add_prefix(key), value) + except RedisError as exc: + raise RedisAdapterException('Error executing expire operation', exc) + + def rpop(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return _bytes_to_string(self._decorated.rpop(self._add_prefix(key))) + except RedisError as exc: + raise RedisAdapterException('Error executing rpop operation', exc) + + def ttl(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.ttl(self._add_prefix(key)) + except RedisError as exc: + raise RedisAdapterException('Error executing ttl operation', exc) + + +def _build_default_client(config): #pylint: disable=too-many-locals + """ + Build a redis adapter. 
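+
+    All keys are optional and fall back to the defaults shown in the body.
+    An illustrative (not exhaustive) configuration::
+
+        adapter = _build_default_client({
+            'redisHost': 'localhost',
+            'redisPort': 6379,
+            'redisDb': 0,
+            'redisPrefix': 'myapp',
+        })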
+ + :param config: Redis configuration properties + :type config: dict + + :return: A wrapped StrictRedis object + :rtype: splitio.storage.adapters.redis.RedisAdapter + """ + host = config.get('redisHost', 'localhost') + port = config.get('redisPort', 6379) + database = config.get('redisDb', 0) + password = config.get('redisPassword', None) + socket_timeout = config.get('redisSocketTimeout', None) + socket_connect_timeout = config.get('redisSocketConnectTimeout', None) + socket_keepalive = config.get('redisSocketKeepalive', None) + socket_keepalive_options = config.get('redisSocketKeepaliveOptions', None) + connection_pool = config.get('redisConnectionPool', None) + unix_socket_path = config.get('redisUnixSocketPath', None) + encoding = config.get('redisEncoding', 'utf-8') + encoding_errors = config.get('redisEncodingErrors', 'strict') + charset = config.get('redisCharset', None) + errors = config.get('redisErrors', None) + decode_responses = config.get('redisDecodeResponses', False) + retry_on_timeout = config.get('redisRetryOnTimeout', False) + ssl = config.get('redisSsl', False) + ssl_keyfile = config.get('redisSslKeyfile', None) + ssl_certfile = config.get('redisSslCertfile', None) + ssl_cert_reqs = config.get('redisSslCertReqs', None) + ssl_ca_certs = config.get('redisSslCaCerts', None) + max_connections = config.get('redisMaxConnections', None) + prefix = config.get('redisPrefix') + + redis = StrictRedis( + host=host, + port=port, + db=database, + password=password, + socket_timeout=socket_timeout, + socket_connect_timeout=socket_connect_timeout, + socket_keepalive=socket_keepalive, + socket_keepalive_options=socket_keepalive_options, + connection_pool=connection_pool, + unix_socket_path=unix_socket_path, + encoding=encoding, + encoding_errors=encoding_errors, + charset=charset, + errors=errors, + decode_responses=decode_responses, + retry_on_timeout=retry_on_timeout, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + max_connections=max_connections + ) + return RedisAdapter(redis, prefix=prefix) + + +def _build_sentinel_client(config): #pylint: disable=too-many-locals + """ + Build a redis client with sentinel replication. + + :param config: Redis configuration properties. 
+    :type config: dict
+
+    :return: A wrapped redis-sentinel client.
+    :rtype: splitio.storage.adapters.redis.RedisAdapter
+    """
+    sentinels = config.get('redisSentinels')
+
+    if sentinels is None:
+        raise SentinelConfigurationException('redisSentinels must be specified.')
+    if not isinstance(sentinels, list):
+        raise SentinelConfigurationException('Sentinels must be an array of elements in the form'
+                                             ' of [(ip, port)].')
+    if not sentinels:
+        raise SentinelConfigurationException('At least one sentinel must be provided.')
+    if not all(isinstance(s, tuple) for s in sentinels):
+        raise SentinelConfigurationException('Sentinels must respect the tuple structure'
+                                             ' [(ip, port)].')
+
+    master_service = config.get('redisMasterService')
+
+    if master_service is None:
+        raise SentinelConfigurationException('redisMasterService must be specified.')
+
+    database = config.get('redisDb', 0)
+    password = config.get('redisPassword', None)
+    socket_timeout = config.get('redisSocketTimeout', None)
+    socket_connect_timeout = config.get('redisSocketConnectTimeout', None)
+    socket_keepalive = config.get('redisSocketKeepalive', None)
+    socket_keepalive_options = config.get('redisSocketKeepaliveOptions', None)
+    connection_pool = config.get('redisConnectionPool', None)
+    encoding = config.get('redisEncoding', 'utf-8')
+    encoding_errors = config.get('redisEncodingErrors', 'strict')
+    decode_responses = config.get('redisDecodeResponses', False)
+    retry_on_timeout = config.get('redisRetryOnTimeout', False)
+    max_connections = config.get('redisMaxConnections', None)
+    prefix = config.get('redisPrefix')
+
+    sentinel = Sentinel(
+        sentinels,
+        db=database,
+        password=password,
+        socket_timeout=socket_timeout,
+        socket_connect_timeout=socket_connect_timeout,
+        socket_keepalive=socket_keepalive,
+        socket_keepalive_options=socket_keepalive_options,
+        connection_pool=connection_pool,
+        encoding=encoding,
+        encoding_errors=encoding_errors,
+        decode_responses=decode_responses,
+        retry_on_timeout=retry_on_timeout,
+        max_connections=max_connections
+    )
+
+    redis = sentinel.master_for(master_service)
+    return RedisAdapter(redis, prefix=prefix)
+
+
+def build(config):
+    """
+    Build a redis storage according to the configuration received.
+
+    :param config: SDK configuration parameters with redis properties.
+    :type config: dict
+
+    :return: A redis client.
+    :rtype: splitio.storage.adapters.redis.RedisAdapter
+    """
+    if 'redisSentinels' in config:
+        return _build_sentinel_client(config)
+    return _build_default_client(config)
diff --git a/splitio/storage/adapters/uwsgi_cache.py b/splitio/storage/adapters/uwsgi_cache.py
new file mode 100644
index 00000000..f6b908b6
--- /dev/null
+++ b/splitio/storage/adapters/uwsgi_cache.py
@@ -0,0 +1,134 @@
+"""UWSGI Cache Storage adapter module."""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+import time
+
+try:
+    # uwsgi is loaded at runtime by the uwsgi app.
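+    # The cache namespaces declared below must exist in the server for these
+    # storages to work. One illustrative (not prescriptive) way to create them
+    # is via repeated uWSGI cache2 options, with sizes as placeholders:
+    #   uwsgi --cache2 name=splitio_splits,items=1000 \
+    #         --cache2 name=splitio_locks,items=100 ...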
+ import uwsgi +except ImportError: + def missing_uwsgi_dependencies(*args, **kwargs): #pylint: disable=unused-argument + """Only complain for missing deps if they're used.""" + raise NotImplementedError('Missing uWSGI support dependencies.') + uwsgi = missing_uwsgi_dependencies + +# Cache used for locking & signaling keys +_SPLITIO_LOCK_CACHE_NAMESPACE = 'splitio_locks' + +# Cache where split definitions are stored +_SPLITIO_SPLITS_CACHE_NAMESPACE = 'splitio_splits' + +# Cache where segments are stored +_SPLITIO_SEGMENTS_CACHE_NAMESPACE = 'splitio_segments' + +# Cache where impressions are stored +_SPLITIO_IMPRESSIONS_CACHE_NAMESPACE = 'splitio_impressions' + +# Cache where metrics are stored +_SPLITIO_METRICS_CACHE_NAMESPACE = 'splitio_metrics' + +# Cache where events are stored (1 key with lots of blocks) +_SPLITIO_EVENTS_CACHE_NAMESPACE = 'splitio_events' + +# Cache where changeNumbers are stored +_SPLITIO_CHANGE_NUMBERS = 'splitio_changeNumbers' + +# Cache with a big block size used for lists +_SPLITIO_MISC_NAMESPACE = 'splitio_misc' + + +class UWSGILock(object): + """Context manager to be used for locking a key in the cache.""" + + def __init__(self, adapter, key, overwrite_lock_seconds=5): + """ + Initialize a lock with the key `key` and waits up to `overwrite_lock_seconds` to release. + + :param key: Key to be used. + :type key: str + + :param overwrite_lock_seconds: How many seconds to wait before force-releasing. + :type overwrite_lock_seconds: int + """ + self._key = key + self._overwrite_lock_seconds = overwrite_lock_seconds + self._uwsgi = adapter + + def __enter__(self): + """Loop until the lock is manually released or timeout occurs.""" + initial_time = time.time() + while True: + if not self._uwsgi.cache_exists(self._key, _SPLITIO_LOCK_CACHE_NAMESPACE): + self._uwsgi.cache_set(self._key, str('locked'), 0, _SPLITIO_LOCK_CACHE_NAMESPACE) + return + else: + if time.time() - initial_time > self._overwrite_lock_seconds: + return + time.sleep(0.1) + + def __exit__(self, *args): + """Remove lock.""" + self._uwsgi.cache_del(self._key, _SPLITIO_LOCK_CACHE_NAMESPACE) + + +class UWSGICacheEmulator(object): + """UWSGI mock.""" + + def __init__(self): + """ + UWSGI Cache Emulator for unit tests. Implements uwsgi cache framework interface. 
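+
+        A quick illustrative round-trip (values are arbitrary)::
+
+            cache = UWSGICacheEmulator()
+            cache.cache_set('key', 'value', 0, 'some_namespace')
+            assert cache.cache_get('key', 'some_namespace') == 'value'
+            cache.cache_del('key', 'some_namespace')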
+ + http://uwsgi-docs.readthedocs.io/en/latest/Caching.html#accessing-the-cache-from-your-applications-using-the-cache-api + """ + self._cache = dict() + + @staticmethod + def _check_string_data_type(value): + if type(value).__name__ == 'str': + return True + raise TypeError( + 'The value to add into uWSGI cache must be string and %s given' % type(value).__name__ + ) + + def cache_get(self, key, cache_namespace='default'): + """Get an element from cache.""" + if self.cache_exists(key, cache_namespace): + return self._cache[cache_namespace][key] + return None + + def cache_set(self, key, value, expires=0, cache_namespace='default'): #pylint: disable=unused-argument + """Set an elemen in the cache.""" + self._check_string_data_type(value) + + if cache_namespace in self._cache: + self._cache[cache_namespace][key] = value + else: + self._cache[cache_namespace] = {key:value} + + def cache_update(self, key, value, expires=0, cache_namespace='default'): + """Update an element.""" + self.cache_set(key, value, expires, cache_namespace) + + def cache_exists(self, key, cache_namespace='default'): + """Return whether the element exists.""" + if cache_namespace in self._cache: + if key in self._cache[cache_namespace]: + return True + return False + + def cache_del(self, key, cache_namespace='default'): + """Delete an item from the cache.""" + if cache_namespace in self._cache: + self._cache[cache_namespace].pop(key, None) + + def cache_clear(self, cache_namespace='default'): + """Delete all elements in cache.""" + self._cache.pop(cache_namespace, None) + + +def get_uwsgi(emulator=False): + """Return a uwsgi imported module or an emulator to use in unit test.""" + if emulator: + return UWSGICacheEmulator() + + return uwsgi diff --git a/splitio/storage/inmemmory.py b/splitio/storage/inmemmory.py new file mode 100644 index 00000000..655acee0 --- /dev/null +++ b/splitio/storage/inmemmory.py @@ -0,0 +1,395 @@ +"""In memory storage classes.""" +from __future__ import absolute_import + +import logging +import threading +from six.moves import queue +from splitio.models.segments import Segment +from splitio.storage import SplitStorage, SegmentStorage, ImpressionStorage, EventStorage, \ + TelemetryStorage + + +class InMemorySplitStorage(SplitStorage): + """InMemory implementation of a split storage.""" + + def __init__(self): + """Constructor.""" + self._lock = threading.RLock() + self._splits = {} + self._change_number = -1 + + def get(self, split_name): + """ + Retrieve a split. + + :param split_name: Name of the feature to fetch. + :type split_name: str + + :rtype: splitio.models.splits.Split + """ + with self._lock: + return self._splits.get(split_name) + + def put(self, split): + """ + Store a split. + + :param split: Split object. + :type split: splitio.models.split.Split + """ + with self._lock: + self._splits[split.name] = split + + def remove(self, split_name): + """ + Remove a split from storage. + + :param split_name: Name of the feature to remove. + :type split_name: str + + :return: True if the split was found and removed. False otherwise. + :rtype: bool + """ + with self._lock: + try: + self._splits.pop(split_name) + return True + except KeyError: + return False + + def get_change_number(self): + """ + Retrieve latest split change number. + + :rtype: int + """ + with self._lock: + return self._change_number + + def set_change_number(self, new_change_number): + """ + Set the latest change number. + + :param new_change_number: New change number. 
+        :type new_change_number: int
+        """
+        with self._lock:
+            self._change_number = new_change_number
+
+    def get_split_names(self):
+        """
+        Retrieve a list of all split names.
+
+        :return: List of split names.
+        :rtype: list(str)
+        """
+        with self._lock:
+            return list(self._splits.keys())
+
+    def get_all_splits(self):
+        """
+        Return all the splits.
+
+        :return: List of all the splits.
+        :rtype: list(splitio.models.splits.Split)
+        """
+        with self._lock:
+            return list(self._splits.values())
+
+
+class InMemorySegmentStorage(SegmentStorage):
+    """In-memory implementation of a segment storage."""
+
+    def __init__(self):
+        """Constructor."""
+        self._segments = {}
+        self._change_numbers = {}
+        self._lock = threading.RLock()
+
+    def get(self, segment_name):
+        """
+        Retrieve a segment.
+
+        :param segment_name: Name of the segment to fetch.
+        :type segment_name: str
+
+        :rtype: splitio.models.segments.Segment
+        """
+        with self._lock:
+            return self._segments.get(segment_name)
+
+    def put(self, segment):
+        """
+        Store a segment.
+
+        :param segment: Segment to store.
+        :type segment: splitio.models.segments.Segment
+        """
+        with self._lock:
+            self._segments[segment.name] = segment
+
+    def update(self, segment_name, to_add, to_remove, change_number=None):
+        """
+        Update a segment. Create it if it doesn't exist.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+        :param to_add: Set of members to add to the segment.
+        :type to_add: set
+        :param to_remove: Set of members to remove from the segment.
+        :type to_remove: set
+        """
+        with self._lock:
+            if segment_name not in self._segments:
+                self._segments[segment_name] = Segment(segment_name, to_add, change_number)
+                return
+
+            self._segments[segment_name].update(to_add, to_remove)
+            if change_number is not None:
+                self._segments[segment_name].change_number = change_number
+
+    def get_change_number(self, segment_name):
+        """
+        Retrieve latest change number for a segment.
+
+        :param segment_name: Name of the segment.
+        :type segment_name: str
+
+        :rtype: int
+        """
+        with self._lock:
+            if segment_name not in self._segments:
+                return None
+            return self._segments[segment_name].change_number
+
+    def set_change_number(self, segment_name, new_change_number):
+        """
+        Set the latest change number.
+
+        :param segment_name: Name of the segment.
+        :type segment_name: str
+        :param new_change_number: New change number.
+        :type new_change_number: int
+        """
+        with self._lock:
+            if segment_name not in self._segments:
+                return
+            self._segments[segment_name].change_number = new_change_number
+
+    def segment_contains(self, segment_name, key):
+        """
+        Check whether a specific key belongs to a segment in storage.
+
+        :param segment_name: Name of the segment to search in.
+        :type segment_name: str
+        :param key: Key to search for.
+        :type key: str
+
+        :return: True if the segment contains the key. False otherwise.
+        :rtype: bool
+        """
+        with self._lock:
+            return segment_name in self._segments and self._segments[segment_name].contains(key)
+
+
+class InMemoryImpressionStorage(ImpressionStorage):
+    """In memory implementation of an impressions storage."""
+
+    def __init__(self, queue_size):
+        """
+        Construct an instance.
+
+        :param queue_size: How many impressions to queue before forcing a submission.
+        :type queue_size: int
+        """
+        self._impressions = queue.Queue(maxsize=queue_size)
+        self._lock = threading.Lock()
+        self._queue_full_hook = None
+
+    def set_queue_full_hook(self, hook):
+        """
+        Set a hook to be called when the queue is full.
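+
+        For illustration (the flushing task here is hypothetical; any
+        zero-argument callable works)::
+
+            storage = InMemoryImpressionStorage(queue_size=5000)
+            storage.set_queue_full_hook(lambda: impressions_task.flush())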
+
+        :param hook: Hook to be called when the queue is full.
+        """
+        if callable(hook):
+            self._queue_full_hook = hook
+
+    def put(self, impressions):
+        """
+        Put one or more impressions in storage.
+
+        :param impressions: List of one or more impressions to store.
+        :type impressions: list
+        """
+        try:
+            with self._lock:
+                for impression in impressions:
+                    self._impressions.put(impression, False)
+            return True
+        except queue.Full:
+            if self._queue_full_hook is not None and callable(self._queue_full_hook):
+                self._queue_full_hook()
+            return False
+
+    def pop_many(self, count):
+        """
+        Pop the oldest N impressions from storage.
+
+        :param count: Number of impressions to pop.
+        :type count: int
+        """
+        impressions = []
+        with self._lock:
+            while not self._impressions.empty() and count > 0:
+                impressions.append(self._impressions.get(False))
+                count -= 1
+        return impressions
+
+
+class InMemoryEventStorage(EventStorage):
+    """
+    In memory storage for events.
+
+    Supports adding and popping events.
+    """
+
+    def __init__(self, eventsQueueSize):
+        """
+        Construct an instance.
+
+        :param eventsQueueSize: How many events to queue before forcing a submission.
+        """
+        self._lock = threading.Lock()
+        self._events = queue.Queue(maxsize=eventsQueueSize)
+        self._queue_full_hook = None
+
+    def set_queue_full_hook(self, hook):
+        """
+        Set a hook to be called when the queue is full.
+
+        :param hook: Hook to be called when the queue is full.
+        """
+        if callable(hook):
+            self._queue_full_hook = hook
+
+    def put(self, events):
+        """
+        Add one or more events to storage.
+
+        :param events: List of events to be added to the storage.
+        """
+        try:
+            with self._lock:
+                for event in events:
+                    self._events.put(event, False)
+            return True
+        except queue.Full:
+            if self._queue_full_hook is not None and callable(self._queue_full_hook):
+                self._queue_full_hook()
+            return False
+
+    def pop_many(self, count):
+        """
+        Pop multiple items from the storage.
+
+        :param count: Number of items to be retrieved and removed from the queue.
+        """
+        events = []
+        with self._lock:
+            while not self._events.empty() and count > 0:
+                events.append(self._events.get(False))
+                count -= 1
+        return events
+
+
+class InMemoryTelemetryStorage(TelemetryStorage):
+    """In-memory implementation of the telemetry storage interface."""
+
+    def __init__(self):
+        """Constructor."""
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._latencies = {}
+        self._gauges = {}
+        self._counters = {}
+        self._latencies_lock = threading.Lock()
+        self._gauges_lock = threading.Lock()
+        self._counters_lock = threading.Lock()
+
+    def inc_latency(self, name, bucket):
+        """
+        Add a latency.
+
+        :param name: Name of the latency metric.
+        :type name: str
+        :param bucket: Bucket of the latency metric to increment.
+        :type bucket: int
+        """
+        if not 0 <= bucket <= 21:
+            self._logger.error('Incorrect bucket "%d" for latency "%s". Ignoring.', bucket, name)
+            return
+
+        with self._latencies_lock:
+            latencies = self._latencies.get(name, [0] * 22)
+            latencies[bucket] += 1
+            self._latencies[name] = latencies
+
+    def inc_counter(self, name):
+        """
+        Increment a counter.
+
+        :param name: Name of the counter metric.
+        :type name: str
+        """
+        with self._counters_lock:
+            counter = self._counters.get(name, 0)
+            counter += 1
+            self._counters[name] = counter
+
+    def put_gauge(self, name, value):
+        """
+        Add a gauge metric.
+
+        :param name: Name of the gauge metric.
+        :type name: str
+        :param value: Value of the gauge metric.
+ :type value: int + """ + with self._gauges_lock: + self._gauges[name] = value + + def pop_counters(self): + """ + Get all the counters. + + :rtype: list + """ + with self._counters_lock: + try: + return self._counters + finally: + self._counters = {} + + def pop_gauges(self): + """ + Get all the gauges. + + :rtype: list + + """ + with self._gauges_lock: + try: + return self._gauges + finally: + self._gauges = {} + + def pop_latencies(self): + """ + Get all latencies. + + :rtype: list + """ + with self._latencies_lock: + try: + return self._latencies + finally: + self._latencies = {} diff --git a/splitio/storage/redis.py b/splitio/storage/redis.py new file mode 100644 index 00000000..59ce3d39 --- /dev/null +++ b/splitio/storage/redis.py @@ -0,0 +1,535 @@ +"""Redis storage module.""" +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +import json +import logging + +from splitio.models.impressions import Impression +from splitio.models import splits, segments +from splitio.storage import SplitStorage, SegmentStorage, ImpressionStorage, EventStorage +from splitio.storage.adapters.redis import RedisAdapterException + + +class RedisSplitStorage(SplitStorage): + """Redis-based storage for splits.""" + + _SPLIT_KEY = 'SPLITIO.split.{split_name}' + _SPLIT_TILL_KEY = 'SPLITIO.splits.till' + + def __init__(self, redis_client): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + self._logger = logging.getLogger(self.__class__.__name__) + self._redis = redis_client + + def _get_key(self, split_name): + """ + Use the provided split_name to build the appropriate redis key. + + :param split_name: Name of the split to interact with in redis. + :type split_name: str + + :return: Redis key. + :rtype: str. + """ + return self._SPLIT_KEY.format(split_name=split_name) + + def get(self, split_name): + """ + Retrieve a split. + + :param split_name: Name of the feature to fetch. + :type split_name: str + + :return: A split object parsed from redis if the key exists. None otherwise + :rtype: splitio.models.splits.Split + """ + try: + raw = self._redis.get(self._get_key(split_name)) + return splits.from_raw(json.loads(raw)) if raw is not None else None + except RedisAdapterException: + self._logger.error('Error fetching split from storage', exc_info=True) + return None + + def put(self, split): + """ + Store a split. + + :param split: Split object to store + :type split_name: splitio.models.splits.Split + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def remove(self, split_name): + """ + Remove a split from storage. + + :param split_name: Name of the feature to remove. + :type split_name: str + + :return: True if the split was found and removed. False otherwise. + :rtype: bool + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def get_change_number(self): + """ + Retrieve latest split change number. + + :rtype: int + """ + try: + return self._redis.get(self._SPLIT_TILL_KEY) + except RedisAdapterException: + self._logger.exception('Error fetching split change number from storage') + return None + + def set_change_number(self, new_change_number): + """ + Set the latest change number. + + :param new_change_number: New change number. 
+        :type new_change_number: int
+        """
+        raise NotImplementedError('Only redis-consumer mode is supported.')
+
+    def get_split_names(self):
+        """
+        Retrieve a list of all split names.
+
+        :return: List of split names.
+        :rtype: list(str)
+        """
+        try:
+            keys = self._redis.keys(self._get_key('*'))
+            return [key.replace(self._get_key(''), '') for key in keys]
+        except RedisAdapterException:
+            self._logger.exception('Error fetching split names from redis.')
+            return []
+
+    def get_all_splits(self):
+        """
+        Return all the splits in cache.
+
+        :return: List of all splits in cache.
+        :rtype: list(splitio.models.splits.Split)
+        """
+        try:
+            keys = self._redis.keys(self._get_key('*'))
+            return [
+                splits.from_raw(raw_split)
+                for raw_split in self._redis.mget(keys)
+                if raw_split is not None
+            ]
+        except RedisAdapterException:
+            self._logger.exception('Error fetching all splits from redis.')
+            return []
+
+
+class RedisSegmentStorage(SegmentStorage):
+    """Redis based segment storage class."""
+
+    _SEGMENTS_KEY = 'SPLITIO.segment.{segment_name}'
+    _SEGMENTS_TILL_KEY = 'SPLITIO.segment.{segment_name}.till'
+
+    def __init__(self, redis_client):
+        """
+        Class constructor.
+
+        :param redis_client: Redis client or compliant interface.
+        :type redis_client: splitio.storage.adapters.redis.RedisAdapter
+        """
+        self._redis = redis_client
+        self._logger = logging.getLogger(self.__class__.__name__)
+
+    def _get_till_key(self, segment_name):
+        """
+        Use the provided segment_name to build the appropriate redis key.
+
+        :param segment_name: Name of the segment to interact with in redis.
+        :type segment_name: str
+
+        :return: Redis key.
+        :rtype: str
+        """
+        return self._SEGMENTS_TILL_KEY.format(segment_name=segment_name)
+
+    def _get_key(self, segment_name):
+        """
+        Use the provided segment_name to build the appropriate redis key.
+
+        :param segment_name: Name of the segment to interact with in redis.
+        :type segment_name: str
+
+        :return: Redis key.
+        :rtype: str
+        """
+        return self._SEGMENTS_KEY.format(segment_name=segment_name)
+
+    def get(self, segment_name):
+        """
+        Retrieve a segment.
+
+        :param segment_name: Name of the segment to fetch.
+        :type segment_name: str
+
+        :return: Segment object if the key exists. None otherwise.
+        :rtype: splitio.models.segments.Segment
+        """
+        try:
+            keys = self._redis.smembers(self._get_key(segment_name))
+            till = self._redis.get(self._get_till_key(segment_name))
+            if keys is None or till is None:
+                return None
+            return segments.Segment(segment_name, keys, till)
+        except RedisAdapterException:
+            self._logger.exception('Error fetching segment from redis.')
+            return None
+
+    def update(self, segment_name, to_add, to_remove, change_number=None):
+        """
+        Update a segment.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+        :param to_add: List of members to add to the segment.
+        :type to_add: list
+        :param to_remove: List of members to remove from the segment.
+        :type to_remove: list
+        """
+        raise NotImplementedError('Only redis-consumer mode is supported.')
+
+    def get_change_number(self, segment_name):
+        """
+        Retrieve latest change number for a segment.
+
+        :param segment_name: Name of the segment.
+        :type segment_name: str
+
+        :rtype: int
+        """
+        try:
+            return self._redis.get(self._get_till_key(segment_name))
+        except RedisAdapterException:
+            self._logger.exception('Unable to fetch segment change number from redis.')
+            return None
+
+    def set_change_number(self, segment_name, new_change_number):
+        """
+        Set the latest change number.
+
+        :param segment_name: Name of the segment.
+        :type segment_name: str
+        :param new_change_number: New change number.
+        :type new_change_number: int
+        """
+        raise NotImplementedError('Only redis-consumer mode is supported.')
+
+    def put(self, segment):
+        """
+        Store a segment.
+
+        :param segment: Segment to store.
+        :type segment: splitio.models.segments.Segment
+        """
+        raise NotImplementedError('Only redis-consumer mode is supported.')
+
+    def segment_contains(self, segment_name, key):
+        """
+        Check whether a specific key belongs to a segment in storage.
+
+        :param segment_name: Name of the segment to search in.
+        :type segment_name: str
+        :param key: Key to search for.
+        :type key: str
+
+        :return: True if the segment contains the key. False otherwise.
+        :rtype: bool
+        """
+        try:
+            return self._redis.sismember(self._get_key(segment_name), key)
+        except RedisAdapterException:
+            self._logger.exception('Unable to test segment members in redis.')
+            return False
+
+
+class RedisImpressionsStorage(ImpressionStorage):
+    """Redis based impressions storage class."""
+
+    IMPRESSIONS_QUEUE_KEY = 'SPLITIO.impressions'
+    IMPRESSIONS_KEY_DEFAULT_TTL = 3600
+
+    def __init__(self, redis_client, sdk_metadata):
+        """
+        Class constructor.
+
+        :param redis_client: Redis client or compliant interface.
+        :type redis_client: splitio.storage.adapters.redis.RedisAdapter
+        :param sdk_metadata: SDK & Machine information.
+        :type sdk_metadata: dict
+        """
+        self._redis = redis_client
+        self._sdk_metadata = sdk_metadata
+        self._logger = logging.getLogger(self.__class__.__name__)
+
+    def put(self, impressions):
+        """
+        Add one or more impressions to the redis storage.
+
+        :param impressions: Impressions to add to the queue.
+        :type impressions: list(splitio.models.impressions.Impression)
+
+        :return: Whether the impressions have been added or not.
+        :rtype: bool
+        """
+        bulk_impressions = []
+        for impression in impressions:
+            if isinstance(impression, Impression):
+                to_store = {
+                    'm': {  # METADATA PORTION
+                        's': self._sdk_metadata['sdk-language-version'],
+                        'n': self._sdk_metadata['instance-id'],
+                        'i': self._sdk_metadata['ip-address'],
+                    },
+                    'i': {  # IMPRESSION PORTION
+                        'k': impression.matching_key,
+                        'b': impression.bucketing_key,
+                        'f': impression.feature_name,
+                        't': impression.treatment,
+                        'r': impression.label,
+                        'c': impression.change_number,
+                        'm': impression.time,
+                    }
+                }
+                bulk_impressions.append(json.dumps(to_store))
+        try:
+            inserted = self._redis.rpush(self.IMPRESSIONS_QUEUE_KEY, bulk_impressions)
+            if inserted == len(bulk_impressions):
+                self._logger.debug("Setting expiration for impressions queue key")
+                self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL)
+            return True
+        except RedisAdapterException:
+            self._logger.exception('Something went wrong when trying to add impression to redis')
+            return False
+
+    def pop_many(self, count):
+        """
+        Pop the oldest N impressions from storage.
+
+        :param count: Number of impressions to pop.
+        :type count: int
+        """
+        raise NotImplementedError('Only redis-consumer mode is supported.')
+
+
+class RedisEventsStorage(EventStorage):
+    """Redis based event storage class."""
+
+    _KEY_TEMPLATE = 'SPLITIO.events'
+
+    def __init__(self, redis_client, sdk_metadata):
+        """
+        Class constructor.
+
+        :param redis_client: Redis client or compliant interface.
+        :type redis_client: splitio.storage.adapters.redis.RedisAdapter
+        :param sdk_metadata: SDK & Machine information.
+ :type sdk_metadata: dict + """ + self._redis = redis_client + self._sdk_metadata = sdk_metadata + self._logger = logging.getLogger(self.__class__.__name__) + + def put(self, events): + """ + Add an event to the redis storage. + + :param event: Event to add to the queue. + :type event: splitio.models.events.Event + + :return: Whether the event has been added or not. + :rtype: bool + """ + key = self._KEY_TEMPLATE + to_store = [ + json.dumps({ + 'i': { + 'key': event.key, + 'trafficTypeName': event.traffic_type_name, + 'eventTypeId': event.event_type_id, + 'value': event.value, + 'timestamp': event.timestamp + }, + 'm': { + 's': self._sdk_metadata['sdk-language-version'], + 'n': self._sdk_metadata['instance-id'], + 'i': self._sdk_metadata['ip-address'], + } + }) + for event in events + ] + try: + self._redis.rpush(key, to_store) + return True + except RedisAdapterException: + self._logger.exception('Something went wrong when trying to add event to redis') + return False + + def pop_many(self, count): + """ + Pop the oldest N events from storage. + + :param count: Number of events to pop. + :type count: int + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + +class RedisTelemetryStorage(object): + """Redis-based Telemetry storage.""" + + _LATENCY_KEY_TEMPLATE = "SPLITIO/{sdk}/{instance}/latency.{name}.bucket.{bucket}" + _COUNTER_KEY_TEMPLATE = "SPLITIO/{sdk}/{instance}/count.{name}" + _GAUGE_KEY_TEMPLATE = "SPLITIO/{sdk}/{instance}/gauge.{name}" + + def __init__(self, redis_client, sdk_metadata): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: dict + """ + self._redis = redis_client + self._metadata = sdk_metadata + self._logger = logging.getLogger(self.__class__.__name__) + + def _get_latency_key(self, name, bucket): + """ + Instantiate and return the latency key template. + + :param name: Name of the latency metric. + :type name: str + :param bucket: Number of bucket. + :type bucket: int + + :return: Redis latency key. + :rtype: str + """ + return self._LATENCY_KEY_TEMPLATE.format( + sdk=self._metadata['sdk-language-version'], + instance=self._metadata['instance-id'], + name=name, + bucket=bucket + ) + + def _get_counter_key(self, name): + """ + Instantiate and return the counter key template. + + :param name: Name of the counter metric. + :type name: str + + :return: Redis counter key. + :rtype: str + """ + return self._COUNTER_KEY_TEMPLATE.format( + sdk=self._metadata['sdk-language-version'], + instance=self._metadata['instance-id'], + name=name + ) + + def _get_gauge_key(self, name): + """ + Instantiate and return the latency key template. + + :param name: Name of the latency metric. + :type name: str + + :return: Redis latency key. + :rtype: str + """ + return self._GAUGE_KEY_TEMPLATE.format( + sdk=self._metadata['sdk-language-version'], + instance=self._metadata['instance-id'], + name=name, + ) + + def inc_latency(self, name, bucket): + """ + Add a latency. + + :param name: Name of the latency metric. + :type name: str + :param value: Value of the latency metric. + :tyoe value: int + """ + if not 0 <= bucket <= 21: + self._logger.error('Incorect bucket "%d" for latency "%s". 
Ignoring.', bucket, name) + return + + key = self._get_latency_key(name, bucket) + try: + self._redis.incr(key) + except RedisAdapterException: + self._logger.error("Error recording latency for metric \"%s\"", name) + + def inc_counter(self, name): + """ + Increment a counter. + + :param name: Name of the counter metric. + :type name: str + """ + key = self._get_counter_key(name) + try: + self._redis.incr(key) + except RedisAdapterException: + self._logger.error("Error recording counter for metric \"%s\"", name) + + def put_gauge(self, name, value): + """ + Add a gauge metric. + + :param name: Name of the gauge metric. + :type name: str + :param value: Value of the gauge metric. + :type value: int + """ + key = self._get_gauge_key(name) + try: + self._redis.set(key, value) + except RedisAdapterException: + self._logger.error("Error recording gauge for metric \"%s\"", name) + + def pop_counters(self): + """ + Get all the counters. + + :rtype: list + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def pop_gauges(self): + """ + Get all the gauges. + + :rtype: list + + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def pop_latencies(self): + """ + Get all latencies. + + :rtype: list + """ + raise NotImplementedError('Only redis-consumer mode is supported.') diff --git a/splitio/storage/uwsgi.py b/splitio/storage/uwsgi.py new file mode 100644 index 00000000..77c1da67 --- /dev/null +++ b/splitio/storage/uwsgi.py @@ -0,0 +1,583 @@ +"""UWSGI Cache based storages implementation module.""" +import logging +import json + +from splitio.storage import SplitStorage, SegmentStorage, ImpressionStorage, EventStorage, \ + TelemetryStorage +from splitio.models import splits, segments +from splitio.models.impressions import Impression +from splitio.models.events import Event +from splitio.storage.adapters.uwsgi_cache import _SPLITIO_CHANGE_NUMBERS, \ + _SPLITIO_EVENTS_CACHE_NAMESPACE, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE, \ + _SPLITIO_METRICS_CACHE_NAMESPACE, _SPLITIO_MISC_NAMESPACE, UWSGILock, \ + _SPLITIO_SEGMENTS_CACHE_NAMESPACE, _SPLITIO_SPLITS_CACHE_NAMESPACE, \ + _SPLITIO_LOCK_CACHE_NAMESPACE + + +class UWSGISplitStorage(SplitStorage): + """UWSGI-Cache based implementation of a split storage.""" + + _KEY_TEMPLATE = 'split.{suffix}' + _KEY_TILL = 'splits.till' + _KEY_FEATURE_LIST_LOCK = 'splits.list.lock' + _KEY_FEATURE_LIST = 'splits.list' + _OVERWRITE_LOCK_SECONDS = 5 + + def __init__(self, uwsgi_entrypoint): + """ + Class constructor. + + :param uwsgi_entrypoint: UWSGI module. Can be the actual module or a mock. + :type uwsgi_entrypoint: module + """ + self._uwsgi = uwsgi_entrypoint + + def get(self, split_name): + """ + Retrieve a split. + + :param split_name: Name of the feature to fetch. + :type split_name: str + + :rtype: str + """ + raw = self._uwsgi.cache_get( + self._KEY_TEMPLATE.format(suffix=split_name), + _SPLITIO_SPLITS_CACHE_NAMESPACE + ) + return splits.from_raw(json.loads(raw)) if raw is not None else None + + def put(self, split): + """ + Store a split. 
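+
+        Besides caching the split body under its own key, this also registers
+        the split name in a shared, lock-protected feature list so that
+        get_split_names() stays consistent across workers.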
+ + :param split: Split object to store + :type split: splitio.models.splits.Split + """ + self._uwsgi.cache_update( + self._KEY_TEMPLATE.format(suffix=split.name), + json.dumps(split.to_json()), + 0, + _SPLITIO_SPLITS_CACHE_NAMESPACE + ) + + with UWSGILock(self._uwsgi, self._KEY_FEATURE_LIST_LOCK): + try: + current = set(json.loads( + self._uwsgi.cache_get(self._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE) + )) + except TypeError: + current = set() + current.add(split.name) + self._uwsgi.cache_update( + self._KEY_FEATURE_LIST, + json.dumps(list(current)), + 0, + _SPLITIO_MISC_NAMESPACE + ) + + def remove(self, split_name): + """ + Remove a split from storage. + + :param split_name: Name of the feature to remove. + :type split_name: str + + :return: True if the split was found and removed. False otherwise. + :rtype: bool + """ + with UWSGILock(self._uwsgi, self._KEY_FEATURE_LIST_LOCK): + try: + current = set(json.loads( + self._uwsgi.cache_get(self._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE) + )) + current.remove(split_name) + self._uwsgi.cache_update( + self._KEY_FEATURE_LIST, + json.dumps(list(current)), + 0, + _SPLITIO_MISC_NAMESPACE + ) + except TypeError: + # Split list not found, no need to delete anything + pass + + return self._uwsgi.cache_del( + self._KEY_TEMPLATE.format(suffix=split_name), + _SPLITIO_SPLITS_CACHE_NAMESPACE + ) + + def get_change_number(self): + """ + Retrieve latest split change number. + + :rtype: int + """ + try: + return json.loads(self._uwsgi.cache_get(self._KEY_TILL, _SPLITIO_CHANGE_NUMBERS)) + except TypeError: + return None + + def set_change_number(self, new_change_number): + """ + Set the latest change number. + + :param new_change_number: New change number. + :type new_change_number: int + """ + self._uwsgi.cache_update(self._KEY_TILL, str(new_change_number), 0, _SPLITIO_CHANGE_NUMBERS) + + def get_split_names(self): + """ + Return a list of all the split names. + + :return: List of split names in cache. + :rtype: list(str) + """ + if self._uwsgi.cache_exists(self._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE): + try: + return json.loads( + self._uwsgi.cache_get(self._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE) + ) + except TypeError: # Thrown by json.loads when passing none + pass # Fall back to default return statement (empty list) + return [] + + def get_all_splits(self): + """ + Return a list of all splits in cache. + + :return: List of splits. + :rtype: list(splitio.models.splits.Split) + """ + return [self.get(split_name) for split_name in self.get_split_names()] + + +class UWSGISegmentStorage(SegmentStorage): + """UWSGI-Cache based implementation of a split storage.""" + + _KEY_TEMPLATE = 'segments.{suffix}' + _SEGMENT_DATA_KEY_TEMPLATE = 'segmentData.{segment_name}' + _SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE = 'segment.{segment_name}.till' + _SEGMENT_REGISTERED = _KEY_TEMPLATE.format(suffix='registered') + + def __init__(self, uwsgi_entrypoint): + """ + Class constructor. + + :param uwsgi_entrypoint: UWSGI module. Can be the actual module or a mock. + :type uwsgi_entrypoint: module + """ + self._uwsgi = uwsgi_entrypoint + + def get(self, segment_name): + """ + Retrieve a segment. + + :param segment_name: Name of the segment to fetch. + :type segment_name: str + + :return: Parsed segment if present. None otherwise. 
+ :rtype: splitio.models.segments.Segment + """ + key = self._SEGMENT_DATA_KEY_TEMPLATE.format(segment_name=segment_name) + cn_key = self._SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE.format(segment_name=segment_name) + try: + segment_data = json.loads(self._uwsgi.cache_get(key, _SPLITIO_SEGMENTS_CACHE_NAMESPACE)) + change_number = json.loads(self._uwsgi.cache_get(cn_key, _SPLITIO_CHANGE_NUMBERS)) + return segments.from_raw({ + 'name': segment_name, + 'added': segment_data, + 'removed': [], + 'till': change_number + }) + except TypeError: + return None + + def update(self, segment_name, to_add, to_remove, change_number=None): + """ + Update a segment. + + :param segment_name: Name of the segment to update. + :type segment_name: str + :param to_add: List of members to add to the segment. + :type to_add: list + :param to_remove: List of members to remove from the segment. + :type to_remove: list + """ + key = self._SEGMENT_DATA_KEY_TEMPLATE.format(segment_name=segment_name) + try: + segment_data = json.loads(self._uwsgi.cache_get(key, _SPLITIO_SEGMENTS_CACHE_NAMESPACE)) + except TypeError: + segment_data = [] + updated = set(segment_data).union(set(to_add)).difference(to_remove) + self._uwsgi.cache_update( + key, + json.dumps(list(updated)), + 0, + _SPLITIO_SEGMENTS_CACHE_NAMESPACE + ) + + if change_number is not None: + self.set_change_number(segment_name, change_number) + + def put(self, segment): + """ + Put a new segment in storage. + + :param segment: Segment to store. + :type segment: splitio.models.segments.Segent + """ + key = self._SEGMENT_DATA_KEY_TEMPLATE.format(segment_name=segment.name) + self._uwsgi.cache_update( + key, + json.dumps(list(segment.keys)), + 0, + _SPLITIO_SEGMENTS_CACHE_NAMESPACE + ) + self.set_change_number(segment.name, segment.change_number) + + + def get_change_number(self, segment_name): + """ + Retrieve latest change number for a segment. + + :param segment_name: Name of the segment. + :type segment_name: str + + :rtype: int + """ + cnkey = self._SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE.format(segment_name=segment_name) + try: + return json.loads(self._uwsgi.cache_get(cnkey, _SPLITIO_CHANGE_NUMBERS)) + + except TypeError: + return None + + def set_change_number(self, segment_name, new_change_number): + """ + Set the latest change number. + + :param segment_name: Name of the segment. + :type segment_name: str + :param new_change_number: New change number. + :type new_change_number: int + """ + cn_key = self._SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE.format(segment_name=segment_name) + self._uwsgi.cache_update(cn_key, json.dumps(new_change_number), 0, _SPLITIO_CHANGE_NUMBERS) + + def segment_contains(self, segment_name, key): + """ + Check whether a specific key belongs to a segment in storage. + + :param segment_name: Name of the segment to search in. + :type segment_name: str + :param key: Key to search for. + :type key: str + + :return: True if the segment contains the key. False otherwise. + :rtype: bool + """ + segment = self.get(segment_name) + return segment.contains(key) + + +class UWSGIImpressionStorage(ImpressionStorage): + """Impressions storage interface.""" + + _IMPRESSIONS_KEY = 'SPLITIO.impressions.' + _LOCK_IMPRESSION_KEY = 'SPLITIO.impressions_lock' + _IMPRESSIONS_FLUSH = 'SPLITIO.impressions_flush' + _OVERWRITE_LOCK_SECONDS = 5 + + def __init__(self, adapter): + """ + Class Constructor. + + :param adapter: UWSGI Adapter/Emulator/Module. + :type: object + """ + self._uwsgi = adapter + + def put(self, impressions): + """ + Put one or more impressions in storage. 
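+
+        Impressions are serialized with _asdict() and appended, under lock, to
+        a single JSON-encoded list. A hypothetical call (field order follows
+        the Impression model)::
+
+            storage = UWSGIImpressionStorage(get_uwsgi())
+            storage.put([
+                Impression('user_1', 'my_feature', 'on', 'default rule',
+                           1234567890, None, 1234567891),
+            ])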
+ + :param impressions: List of one or more impressions to store. + :type impressions: list + """ + with UWSGILock(self._uwsgi, self._LOCK_IMPRESSION_KEY): + try: + current = json.loads(self._uwsgi.cache_get( + self._IMPRESSIONS_KEY, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE + )) + except TypeError: + current = [] + + self._uwsgi.cache_update( + self._IMPRESSIONS_KEY, + json.dumps(current + [i._asdict() for i in impressions]), + 0, + _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE + ) + + def pop_many(self, count): + """ + Pop the oldest N impressions from storage. + + :param count: Number of impressions to pop. + :type count: int + """ + with UWSGILock(self._uwsgi, self._LOCK_IMPRESSION_KEY): + try: + current = json.loads(self._uwsgi.cache_get( + self._IMPRESSIONS_KEY, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE + )) + except TypeError: + current = [] + self._uwsgi.cache_update( + self._IMPRESSIONS_KEY, + json.dumps(current[count:]), + 0, + _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE + ) + + return [ + Impression( + impression['matching_key'], + impression['feature_name'], + impression['treatment'], + impression['label'], + impression['change_number'], + impression['bucketing_key'], + impression['time'] + ) for impression in current[:count] + ] + + def request_flush(self): + """Set a marker in the events cache to indicate that a flush has been requested.""" + self._uwsgi.cache_set(self._IMPRESSIONS_FLUSH, 'ok', 0, _SPLITIO_LOCK_CACHE_NAMESPACE) + + def should_flush(self): + """ + Return True if a flush has been requested. + + :return: Whether a flush has been requested. + :rtype: bool + """ + value = self._uwsgi.cache_get(self._IMPRESSIONS_FLUSH, _SPLITIO_LOCK_CACHE_NAMESPACE) + return True if value is not None else False + + def acknowledge_flush(self): + """Acknowledge that a flush has been requested.""" + self._uwsgi.cache_del(self._IMPRESSIONS_FLUSH, _SPLITIO_LOCK_CACHE_NAMESPACE) + + +class UWSGIEventStorage(EventStorage): + """Events storage interface.""" + + _EVENTS_KEY = 'events' + _LOCK_EVENTS_KEY = 'events_lock' + _EVENTS_FLUSH = 'events_flush' + _OVERWRITE_LOCK_SECONDS = 5 + + def __init__(self, adapter): + """ + Class Constructor. + + :param adapter: UWSGI Adapter/Emulator/Module. + :type: object + """ + self._uwsgi = adapter + + def put(self, events): + """ + Put one or more events in storage. + + :param events: List of one or more events to store. + :type events: list + """ + with UWSGILock(self._uwsgi, self._LOCK_EVENTS_KEY): + try: + current = json.loads(self._uwsgi.cache_get( + self._EVENTS_KEY, _SPLITIO_EVENTS_CACHE_NAMESPACE + )) + except TypeError: + current = [] + self._uwsgi.cache_update( + self._EVENTS_KEY, + json.dumps(current + [e._asdict() for e in events]), + 0, + _SPLITIO_EVENTS_CACHE_NAMESPACE + ) + + def pop_many(self, count): + """ + Pop the oldest N events from storage. + + :param count: Number of events to pop. 
+ :type count: int + """ + with UWSGILock(self._uwsgi, self._LOCK_EVENTS_KEY): + try: + current = json.loads(self._uwsgi.cache_get( + self._EVENTS_KEY, _SPLITIO_EVENTS_CACHE_NAMESPACE + )) + except TypeError: + current = [] + self._uwsgi.cache_update( + self._EVENTS_KEY, + json.dumps(current[count:]), + 0, + _SPLITIO_EVENTS_CACHE_NAMESPACE + ) + return [ + Event( + event['key'], + event['traffic_type_name'], + event['event_type_id'], + event['value'], + event['timestamp'] + ) for event in current[:count] + ] + + def request_flush(self): + """Set a marker in the events cache to indicate that a flush has been requested.""" + self._uwsgi.cache_set(self._EVENTS_FLUSH, 'requested', 0, _SPLITIO_LOCK_CACHE_NAMESPACE) + + def should_flush(self): + """ + Return True if a flush has been requested. + + :return: Whether a flush has been requested. + :rtype: bool + """ + value = self._uwsgi.cache_get(self._EVENTS_FLUSH, _SPLITIO_LOCK_CACHE_NAMESPACE) + return True if value is not None else False + + def acknowledge_flush(self): + """Acknowledge that a flush has been requested.""" + self._uwsgi.cache_del(self._EVENTS_FLUSH, _SPLITIO_LOCK_CACHE_NAMESPACE) + + +class UWSGITelemetryStorage(TelemetryStorage): + """Telemetry storage interface.""" + + _LATENCIES_KEY = 'SPLITIO.latencies' + _GAUGES_KEY = 'SPLITIO.gauges' + _COUNTERS_KEY = 'SPLITIO.counters' + + _LATENCIES_LOCK_KEY = 'SPLITIO.latencies.lock' + _GAUGES_LOCK_KEY = 'SPLITIO.gauges.lock' + _COUNTERS_LOCK_KEY = 'SPLITIO.counters.lock' + + def __init__(self, uwsgi_entrypoint): + """ + Class constructor. + + :param uwsgi_entrypoint: uwsgi module/emulator + :type uwsgi_entrypoint: object + """ + self._uwsgi = uwsgi_entrypoint + self._logger = logging.getLogger(self.__class__.__name__) + + + def inc_latency(self, name, bucket): + """ + Add a latency. + + :param name: Name of the latency metric. + :type name: str + :param value: Value of the latency metric. + :tyoe value: int + """ + if not 0 <= bucket <= 21: + self._logger.error('Incorect bucket "%d" for latency "%s". Ignoring.', bucket, name) + return + + with UWSGILock(self._uwsgi, self._LATENCIES_LOCK_KEY): + latencies_raw = self._uwsgi.cache_get(self._LATENCIES_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + latencies = json.loads(latencies_raw) if latencies_raw else {} + to_update = latencies.get(name, [0] * 22) + to_update[bucket] += 1 + latencies[name] = to_update + self._uwsgi.cache_set( + self._LATENCIES_KEY, + json.dumps(latencies), + 0, + _SPLITIO_METRICS_CACHE_NAMESPACE + ) + + def inc_counter(self, name): + """ + Increment a counter. + + :param name: Name of the counter metric. + :type name: str + """ + with UWSGILock(self._uwsgi, self._COUNTERS_LOCK_KEY): + counters_raw = self._uwsgi.cache_get(self._COUNTERS_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + counters = json.loads(counters_raw) if counters_raw else {} + value = counters.get(name, 0) + value += 1 + counters[name] = value + self._uwsgi.cache_set( + self._COUNTERS_KEY, + json.dumps(counters), + 0, + _SPLITIO_METRICS_CACHE_NAMESPACE + ) + + def put_gauge(self, name, value): + """ + Add a gauge metric. + + :param name: Name of the gauge metric. + :type name: str + :param value: Value of the gauge metric. 
+ :type value: int + """ + with UWSGILock(self._uwsgi, self._GAUGES_LOCK_KEY): + gauges_raw = self._uwsgi.cache_get(self._GAUGES_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + gauges = json.loads(gauges_raw) if gauges_raw else {} + gauges[name] = value + self._uwsgi.cache_set( + self._GAUGES_KEY, + json.dumps(gauges), + 0, + _SPLITIO_METRICS_CACHE_NAMESPACE + ) + + def pop_counters(self): + """ + Get all the counters. + + :rtype: list + """ + with UWSGILock(self._uwsgi, self._COUNTERS_LOCK_KEY): + counters_raw = self._uwsgi.cache_get(self._COUNTERS_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + self._uwsgi.cache_del(self._COUNTERS_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + return json.loads(counters_raw) if counters_raw else {} + + def pop_gauges(self): + """ + Get all the gauges. + + :rtype: list + + """ + with UWSGILock(self._uwsgi, self._GAUGES_LOCK_KEY): + gauges_raw = self._uwsgi.cache_get(self._GAUGES_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + self._uwsgi.cache_del(self._GAUGES_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + return json.loads(gauges_raw) if gauges_raw else {} + + def pop_latencies(self): + """ + Get all latencies. + + :rtype: list + """ + with UWSGILock(self._uwsgi, self._LATENCIES_LOCK_KEY): + latencies_raw = self._uwsgi.cache_get(self._LATENCIES_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + self._uwsgi.cache_del(self._LATENCIES_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE) + return json.loads(latencies_raw) if latencies_raw else {} diff --git a/splitio/tasks.py b/splitio/tasks.py deleted file mode 100644 index 03a9d2f9..00000000 --- a/splitio/tasks.py +++ /dev/null @@ -1,273 +0,0 @@ -""" -This module contains everything related to update tasks -""" - -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -import logging -from traceback import print_exc - -from six.moves import queue - -from .splits import Status -from .impressions import build_impressions_data -from . import asynctask -from . import events - - -_logger = logging.getLogger(__name__) - - -def update_segments(segment_cache, segment_change_fetcher): - """ - If updates are enabled, this function updates all the segments listed - by the segment cache get_registered_segments() method making requests to - the Split.io SDK API. If an exception is raised, the process is stopped and - it won't try to update segments again until enabled_updates - is called on the segment cache. - """ - try: - if not segment_cache.is_enabled(): - return - - registered_segments = segment_cache.get_registered_segments() - for name in registered_segments: - update_segment(segment_cache, name, segment_change_fetcher) - except: - _logger.error('Error updating segment definitions') - segment_cache.disable() - - -def update_segment(segment_cache, segment_name, segment_change_fetcher): - """ - Updates a segment. It will eagerly request all changes until the change - number is the same "till" value in the response. 
- :param segment_name: The name of the segment - :type segment_name: str - """ - till = segment_cache.get_change_number(segment_name) - _logger.info("Updating segment %s" % segment_name) - while True: - response = segment_change_fetcher.fetch(segment_name, till) - _logger.info("SEGMENT RESPONSE %s" % response) - if 'till' not in response: - return - - if till >= response['till']: - return - - if len(response['removed']) > 0: - segment_cache.remove_keys_from_segment( - segment_name, - response['removed'] - ) - - if len(response['added']) > 0: - segment_cache.add_keys_to_segment(segment_name, response['added']) - - segment_cache.set_change_number(segment_name, response['till']) - - till = response['till'] - - -def update_splits(split_cache, split_change_fetcher, split_parser): - """ - If updates are enabled, this function updates (or initializes) the current - cached split configuration. It can be called by periodic update tasks or - directly to force an unscheduled update. - If an exception is raised, the process is stopped and it won't try to - update splits again until enabled_updates is called on the splits cache. - """ - added_features = [] - removed_features = [] - try: - till = split_cache.get_change_number() - - while True: - response = split_change_fetcher.fetch(till) - - if 'till' not in response: - break - - if till >= response['till']: - _logger.debug("change_number is greater or equal than 'till'") - break - - if 'splits' in response and len(response['splits']) > 0: - _logger.debug( - "Splits field in response. response = %s", - response - ) - - for split_change in response['splits']: - if Status(split_change['status']) != Status.ACTIVE: - split_cache.remove_split(split_change['name']) - removed_features.append(split_change['name']) - continue - - parsed_split = split_parser.parse(split_change) - if parsed_split is None: - _logger.warning( - 'We could not parse the split definition for %s. ' - 'Removing split to be safe.', split_change['name']) - split_cache.remove_split(split_change['name']) - removed_features.append(split_change['name']) - continue - - added_features.append(split_change['name']) - split_cache.add_split(split_change['name'], split_change) - - if len(added_features) > 0: - _logger.info('Updated features: %s', added_features) - - if len(removed_features) > 0: - _logger.info('Deleted features: %s', removed_features) - - till = response['till'] - split_cache.set_change_number(response['till']) - except: - _logger.error('Error updating split definitions') - split_cache.disable() - return [], [] - - return added_features, removed_features - - -def report_impressions(impressions_cache, sdk_api): - """ - If the reporting process is enabled (through the impressions cache), - this function collects the impressions from the cache and sends them to - Split through the events API. If the process fails, no exceptions are - raised (but they are logged) and the process is disabled. - """ - try: - if not impressions_cache.is_enabled(): - return - - impressions = impressions_cache.fetch_all_and_clear() - test_impressions_data = build_impressions_data(impressions) - - _logger.debug('Impressions to send: %s' % test_impressions_data) - - if len(test_impressions_data) > 0: - _logger.info( - 'Posting impressions for features: %s.', - ', '.join(impressions.keys()) - ) - sdk_api.test_impressions(test_impressions_data) - except: - _logger.error('Error reporting impressions. 
Disabling impressions log.')
-        impressions_cache.disable()
-
-
-def report_metrics(metrics_cache, sdk_api):
-    """
-    If the reporting process is enabled (through the metrics cache),
-    this function collects the time, count and gauge from the cache and sends
-    them to Split through the events API. If the process fails, no exceptions
-    are raised (but they are logged) and the process is disabled.
-    """
-    try:
-        if not metrics_cache.is_enabled():
-            return
-
-        time = metrics_cache.fetch_all_times_and_clear()
-        if len(time) > 0:
-            _logger.info('Sending times metrics...')
-            sdk_api.metrics_times(time)
-
-        metrics = metrics_cache.fetch_all_and_clear()
-        if 'count' in metrics and len(metrics['count']) > 0:
-            _logger.info('Sending counters metrics...')
-            sdk_api.metrics_counters(metrics['count'])
-
-        if 'gauge' in metrics and len(metrics['gauge']) > 0:
-            _logger.info('Sending gauge metrics...')
-            sdk_api.metrics_gauge(metrics['gauge'])
-    except:
-        _logger.error('Error reporting metrics')
-        metrics_cache.disable()
-
-
-class EventsSyncTask:
-    """
-    Events synchronization task uses an asynctask.AsyncTask to send events
-    periodically to the backend in a controlled way
-    """
-
-    def __init__(self, sdk_api, storage, period, bulk_size):
-        self._sdk_api = sdk_api
-        self._storage = storage
-        self._period = period
-        self._failed = queue.Queue()
-        self._bulk_size = bulk_size
-        self._task = asynctask.AsyncTask(
-            self._send_events,
-            self._period,
-            on_stop=self._send_events,
-        )
-
-    def _get_failed(self):
-        """
-        Return up to events stored in the failed eventes queue
-        """
-        events = []
-        n = 0
-        while n < self._bulk_size:
-            try:
-                events.append(self._failed.get(False))
-            except queue.Empty:
-                # If no more items in queue, break the loop
-                break
-        return events
-
-    def _add_to_failed_queue(self, events):
-        """
-        Add events that were about to be sent to a secondary queue for failed sends
-        """
-        for e in events:
-            self._failed.put(e, False)
-
-    def _send_events(self):
-        """
-        Grabs events from the failed queue (and new ones if the bulk size
-        is not met) and submits them to the backend
-        """
-
-        to_send = self._get_failed()
-        if len(to_send) < self._bulk_size:
-            # If the amount of previously failed items is less than the bulk
-            # size, try to complete with new events from storage
-            to_send.extend(self._storage.pop_many(self._bulk_size - len(to_send)))
-
-        if len(to_send) == 0:
-            return
-
-        try:
-            status_code = self._sdk_api.track_events(events.build_bulk(to_send))
-            if status_code >= 300:
-                _logger.error("Event reporting failed with status code {}".format(status_code))
-                self._add_to_failed_queue(to_send)
-        except Exception:
-            # Something went wrong
-            _logger.error("Exception raised while reporting events")
-            self._add_to_failed_queue(to_send)
-
-    def start(self):
-        """
-        Start executing the events synchronization task
-        """
-        self._task.start()
-
-    def stop(self):
-        """
-        Stop executing the events synchronization task
-        """
-        self._task.stop()
-
-    def flush(self):
-        """
-        Flush events in storage
-        """
-        self._task.force_execution()
diff --git a/splitio/tasks/__init__.py b/splitio/tasks/__init__.py
new file mode 100644
index 00000000..7d478a22
--- /dev/null
+++ b/splitio/tasks/__init__.py
@@ -0,0 +1,30 @@
+"""Split synchronization tasks module."""
+
+import abc
+
+
+class BaseSynchronizationTask(object):
+    """Synchronization task interface."""
+
+    __metaclass__ = abc.ABCMeta
+
+    @abc.abstractmethod
+    def start(self):
+        """Start the task."""
+        pass
+
+    @abc.abstractmethod
+    def stop(self, event=None):
+        """
+        Stop the task if running.
+
+        Optionally accept an event to be set when the task finally stops.
+
+        :param event: Event to be set as soon as the task finishes.
+        :type event: threading.Event
+        """
+        pass
+
+    @abc.abstractmethod
+    def is_running(self):
+        """Return true if the task is running, false otherwise."""
+        pass
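All concrete tasks in this package implement the interface above. A minimal, hypothetical subclass to illustrate the start/stop/is_running contract (the NoopTask name and body are illustrative only, not part of the SDK):

import threading

class NoopTask(BaseSynchronizationTask):
    """Toy task showing the contract; it does no real work."""

    def __init__(self):
        self._running = False

    def start(self):
        self._running = True

    def stop(self, event=None):
        self._running = False
        if event is not None:
            event.set()  # signal the caller that shutdown has completed

    def is_running(self):
        return self._running

shutdown = threading.Event()
task = NoopTask()
task.start()
assert task.is_running()
task.stop(shutdown)
shutdown.wait()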
diff --git a/splitio/tasks/events_sync.py b/splitio/tasks/events_sync.py
new file mode 100644
index 00000000..3dac401a
--- /dev/null
+++ b/splitio/tasks/events_sync.py
@@ -0,0 +1,103 @@
+"""Events synchronization task."""
+from __future__ import absolute_import, division, print_function, \
+    unicode_literals
+
+import logging
+
+from six.moves import queue
+from splitio.api import APIException
+from splitio.tasks import BaseSynchronizationTask
+from splitio.tasks.util.asynctask import AsyncTask
+
+
+class EventsSyncTask(BaseSynchronizationTask):
+    """Events synchronization task uses an asynctask.AsyncTask to send events."""
+
+    def __init__(self, events_api, storage, period, bulk_size):
+        """
+        Class constructor.
+
+        :param events_api: Events Api object to send data to the backend
+        :type events_api: splitio.api.events.EventsAPI
+        :param storage: Events Storage
+        :type storage: splitio.storage.EventStorage
+        :param period: How many seconds to wait between subsequent event pushes to the BE.
+        :type period: int
+        :param bulk_size: How many events to send per push.
+        :type bulk_size: int
+        """
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._events_api = events_api
+        self._storage = storage
+        self._period = period
+        self._failed = queue.Queue()
+        self._bulk_size = bulk_size
+        self._task = AsyncTask(self._send_events, self._period, on_stop=self._send_events)
+
+    def _get_failed(self):
+        """Return up to `bulk_size` events stored in the failed-events queue."""
+        events = []
+        count = 0
+        while count < self._bulk_size:
+            try:
+                events.append(self._failed.get(False))
+                count += 1
+            except queue.Empty:
+                # If no more items in queue, break the loop
+                break
+        return events
+
+    def _add_to_failed_queue(self, events):
+        """
+        Add events that were about to be sent to a secondary queue for failed sends.
+
+        :param events: List of events that failed to be pushed.
+        :type events: list
+        """
+        for event in events:
+            self._failed.put(event, False)
+
+    def _send_events(self):
+        """Send events from both the failed and new queues."""
+        to_send = self._get_failed()
+        if len(to_send) < self._bulk_size:
+            # If the amount of previously failed items is less than the bulk
+            # size, try to complete with new events from storage
+            to_send.extend(self._storage.pop_many(self._bulk_size - len(to_send)))
+
+        if not to_send:
+            return
+
+        try:
+            status_code = self._events_api.flush_events(to_send)
+            if status_code >= 300:
+                self._logger.error("Event reporting failed with status code %d", status_code)
+                self._add_to_failed_queue(to_send)
+        except APIException as exc:
+            self._logger.error(
+                'Exception raised while reporting events: %s -- %d',
+                exc.custom_message,
+                exc.status_code
+            )
+            self._add_to_failed_queue(to_send)
+
+    def start(self):
+        """Start executing the events synchronization task."""
+        self._task.start()
+
+    def stop(self, event=None):
+        """Stop executing the events synchronization task."""
+        self._task.stop(event)
+
+    def flush(self):
+        """Flush events in storage."""
+        self._task.force_execution()
+
+    def is_running(self):
+        """
+        Return whether the task is running or not.
+
+        :return: True if the task is running. False otherwise.
+        :rtype: bool
+        """
+        return self._task.running()
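For reference, a wiring sketch of the task. The API, storage and task constructors exist elsewhere in this patch, but the URLs, apikey, metadata dict and the in-memory storage's constructor argument are placeholders/assumptions:

import threading
from splitio.api.client import HttpClient
from splitio.api.events import EventsAPI
from splitio.storage.inmemmory import InMemoryEventStorage
from splitio.tasks.events_sync import EventsSyncTask

http_client = HttpClient(None, None)   # None assumed to select default endpoints
metadata = {}                          # normally built via splitio.client.util.get_metadata
events_api = EventsAPI(http_client, 'YOUR_APIKEY', metadata)
storage = InMemoryEventStorage(10000)  # queue size argument is an assumption

task = EventsSyncTask(events_api, storage, period=5, bulk_size=500)
task.start()
# ... the client pushes events into `storage` as track() is called ...
task.flush()                           # force an immediate push

stopped = threading.Event()
task.stop(stopped)                     # one final flush runs via the on_stop hook
stopped.wait()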
diff --git a/splitio/tasks/impressions_sync.py b/splitio/tasks/impressions_sync.py
new file mode 100644
index 00000000..cc3567f2
--- /dev/null
+++ b/splitio/tasks/impressions_sync.py
@@ -0,0 +1,104 @@
+"""Impressions synchronization task."""
+from __future__ import absolute_import, division, print_function, \
+    unicode_literals
+
+import logging
+
+from six.moves import queue
+
+from splitio.api import APIException
+from splitio.tasks import BaseSynchronizationTask
+from splitio.tasks.util.asynctask import AsyncTask
+
+
+class ImpressionsSyncTask(BaseSynchronizationTask):
+    """Impressions synchronization task uses an asynctask.AsyncTask to send impressions."""
+
+    def __init__(self, impressions_api, storage, period, bulk_size):
+        """
+        Class constructor.
+
+        :param impressions_api: Impressions Api object to send data to the backend
+        :type impressions_api: splitio.api.impressions.ImpressionsAPI
+        :param storage: Impressions Storage
+        :type storage: splitio.storage.ImpressionsStorage
+        :param period: How many seconds to wait between subsequent impressions pushes to the BE.
+        :type period: int
+        :param bulk_size: How many impressions to send per push.
+        :type bulk_size: int
+        """
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._impressions_api = impressions_api
+        self._storage = storage
+        self._period = period
+        self._failed = queue.Queue()
+        self._bulk_size = bulk_size
+        self._task = AsyncTask(self._send_impressions, self._period, on_stop=self._send_impressions)
+
+    def _get_failed(self):
+        """Return up to `bulk_size` impressions stored in the failed-impressions queue."""
+        imps = []
+        count = 0
+        while count < self._bulk_size:
+            try:
+                imps.append(self._failed.get(False))
+                count += 1
+            except queue.Empty:
+                # If no more items in queue, break the loop
+                break
+        return imps
+
+    def _add_to_failed_queue(self, imps):
+        """
+        Add impressions that were about to be sent to a secondary queue for failed sends.
+
+        :param imps: List of impressions that failed to be pushed.
+        :type imps: list
+        """
+        for impression in imps:
+            self._failed.put(impression, False)
+
+    def _send_impressions(self):
+        """Send impressions from both the failed and new queues."""
+        to_send = self._get_failed()
+        if len(to_send) < self._bulk_size:
+            # If the amount of previously failed items is less than the bulk
+            # size, try to complete with new impressions from storage
+            to_send.extend(self._storage.pop_many(self._bulk_size - len(to_send)))
+
+        if not to_send:
+            return
+
+        try:
+            status_code = self._impressions_api.flush_impressions(to_send)
+            if status_code >= 300:
+                self._logger.error("Impressions reporting failed with status code %s", status_code)
+                self._add_to_failed_queue(to_send)
+        except APIException as exc:
+            self._logger.error(
+                'Exception raised while reporting impressions: %s -- %d',
+                exc.custom_message,
+                exc.status_code
+            )
+            self._add_to_failed_queue(to_send)
+
+    def start(self):
+        """Start executing the impressions synchronization task."""
+        self._task.start()
+
+    def stop(self, event=None):
+        """Stop executing the impressions synchronization task."""
+        self._task.stop(event)
+
+    def is_running(self):
+        """
+        Return whether the task is running or not.
+
+        :return: True if the task is running. False otherwise.
+        :rtype: bool
+        """
+        return self._task.running()
+
+    def flush(self):
+        """Flush impressions in storage."""
+        self._task.force_execution()
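The events and impressions tasks share the same retry strategy: bulks that failed to post are re-queued locally and take precedence over fresh items on the next cycle. The refill step in isolation (standalone sketch; plain strings stand in for real impressions):

from six.moves import queue

BULK_SIZE = 5
failed = queue.Queue()
for item in ['i1', 'i2']:           # two impressions failed last cycle
    failed.put(item, False)

to_send = []
while len(to_send) < BULK_SIZE:     # drain previously-failed items first
    try:
        to_send.append(failed.get(False))
    except queue.Empty:
        break

fresh = ['i6', 'i7', 'i8']          # what storage.pop_many(...) would return
to_send.extend(fresh[:BULK_SIZE - len(to_send)])
print(to_send)                      # ['i1', 'i2', 'i6', 'i7', 'i8']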
diff --git a/splitio/tasks/segment_sync.py b/splitio/tasks/segment_sync.py
new file mode 100644
index 00000000..871af25f
--- /dev/null
+++ b/splitio/tasks/segment_sync.py
@@ -0,0 +1,106 @@
+"""Segment synchronization module."""
+
+import logging
+from splitio.api import APIException
+from splitio.tasks import BaseSynchronizationTask
+from splitio.tasks.util import asynctask, workerpool
+from splitio.models import segments
+
+
+class SegmentSynchronizationTask(BaseSynchronizationTask):  #pylint: disable=too-many-instance-attributes
+    """Segment synchronization class."""
+
+    def __init__(self, segment_api, segment_storage, split_storage, period, event):  #pylint: disable=too-many-arguments
+        """
+        Class constructor.
+
+        :param segment_api: API to retrieve segments from backend.
+        :type segment_api: splitio.api.SegmentApi
+
+        :param segment_storage: Segment storage reference.
+        :type segment_storage: splitio.storage.SegmentStorage
+
+        :param split_storage: Split storage, used to fetch the names of the segments in use.
+        :type split_storage: splitio.storage.SplitStorage
+
+        :param period: Period (in seconds) of the synchronization task.
+        :type period: int
+
+        :param event: Event to signal when all segments have finished initial sync.
+        :type event: threading.Event
+        """
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._worker_pool = workerpool.WorkerPool(20, self._ensure_segment_is_updated)
+        self._task = asynctask.AsyncTask(self._main, period, on_init=self._on_init)
+        self._segment_api = segment_api
+        self._segment_storage = segment_storage
+        self._split_storage = split_storage
+        self._event = event
+        self._pending_initialization = []
+
+    def _update_segment(self, segment_name):
+        """
+        Update a segment by hitting the split backend.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+
+        :return: True if the segment is up to date after this fetch. False otherwise.
+        :rtype: bool
+        """
+        since = self._segment_storage.get_change_number(segment_name)
+        if since is None:
+            since = -1
+
+        try:
+            segment_changes = self._segment_api.fetch_segment(segment_name, since)
+        except APIException:
+            self._logger.error('Error fetching segments')
+            return False
+
+        if since == -1:  # first time fetching the segment
+            new_segment = segments.from_raw(segment_changes)
+            self._segment_storage.put(new_segment)
+        else:
+            self._segment_storage.update(
+                segment_name,
+                segment_changes['added'],
+                segment_changes['removed'],
+                segment_changes['till']
+            )
+
+        return segment_changes['till'] == segment_changes['since']
+
+    def _main(self):
+        """Submit all current segments and wait for them to finish."""
+        segment_names = self._split_storage.get_segment_names()
+        for segment_name in segment_names:
+            self._worker_pool.submit_work(segment_name)
+
+    def _on_init(self):
+        """Submit all current segments and wait for them to finish, then set the ready flag."""
+        self._main()
+        self._worker_pool.wait_for_completion()
+        self._event.set()
+
+    def _ensure_segment_is_updated(self, segment_name):
+        """
+        Update a segment repeatedly until it catches up with the backend.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+        """
+        while True:
+            ready = self._update_segment(segment_name)
+            if ready:
+                break
+
+    def start(self):
+        """Start segment synchronization."""
+        self._worker_pool.start()
+        self._task.start()
+
+    def stop(self, event=None):
+        """Stop segment synchronization."""
+        self._task.stop()
+        self._worker_pool.stop(event)
+
+    def is_running(self):
+        """
+        Return whether the task is running or not.
+
+        :return: True if the task is running. False otherwise.
+        :rtype: bool
+        """
+        return self._task.running()
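The `till == since` comparison is what tells a worker it has caught up; until then it keeps re-fetching with the last `till` as the new starting point. The loop in isolation (a sketch with hypothetical `fetch`/`apply_changes` callables mirroring the segmentChanges payload):

def sync_segment(fetch, apply_changes, since=-1):
    """Fetch segment changes until the backend reports no further changes."""
    while True:
        changes = fetch(since)  # dict with 'added', 'removed', 'since', 'till'
        apply_changes(changes['added'], changes['removed'], changes['till'])
        if changes['till'] == changes['since']:
            return changes['till']  # caught up
        since = changes['till']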
diff --git a/splitio/tasks/split_sync.py b/splitio/tasks/split_sync.py
new file mode 100644
index 00000000..bcdab187
--- /dev/null
+++ b/splitio/tasks/split_sync.py
@@ -0,0 +1,82 @@
+"""Split Synchronization task."""
+
+import logging
+from splitio.models import splits
+from splitio.api import APIException
+from splitio.tasks import BaseSynchronizationTask
+from splitio.tasks.util.asynctask import AsyncTask
+
+
+class SplitSynchronizationTask(BaseSynchronizationTask):
+    """Split Synchronization task class."""
+
+    def __init__(self, split_api, split_storage, period, ready_flag):
+        """
+        Class constructor.
+
+        :param split_api: Split API Client.
+        :type split_api: splitio.api.splits.SplitsAPI
+        :param split_storage: Split Storage.
+        :type split_storage: splitio.storage.InMemorySplitStorage
+        :param period: Period (in seconds) of the synchronization task.
+        :type period: int
+        :param ready_flag: Flag to set when splits initial sync is complete.
+        :type ready_flag: threading.Event
+        """
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._api = split_api
+        self._ready_flag = ready_flag
+        self._period = period
+        self._split_storage = split_storage
+        self._task = AsyncTask(self._update_splits, period, self._on_start)
+
+    def _update_splits(self):
+        """
+        Hit the endpoint, update the storage and return True if the sync is complete.
+
+        :return: True if synchronization is complete.
+        :rtype: bool
+        """
+        till = self._split_storage.get_change_number()
+        if till is None:
+            till = -1
+
+        try:
+            split_changes = self._api.fetch_splits(till)
+        except APIException:
+            self._logger.error('Failed to fetch splits from servers')
+            return False
+
+        for split in split_changes.get('splits', []):
+            if split['status'] == splits.Status.ACTIVE.value:
+                self._split_storage.put(splits.from_raw(split))
+            else:
+                self._split_storage.remove(split['name'])
+
+        self._split_storage.set_change_number(split_changes['till'])
+        return split_changes['till'] == split_changes['since']
+
+    def _on_start(self):
+        """Keep syncing until splits are up to date, then set the ready flag."""
+        while True:
+            ready = self._update_splits()
+            if ready:
+                break
+
+        self._ready_flag.set()
+        return True
+
+    def start(self):
+        """Start the task."""
+        self._task.start()
+
+    def stop(self, event=None):
+        """Stop the task. Accept an optional event to set when the task has finished."""
+        self._task.stop(event)
+
+    def is_running(self):
+        """
+        Return whether the task is running.
+
+        :return: True if the task is running. False otherwise.
+        :rtype: bool
+        """
+        return self._task.running()
diff --git a/splitio/tasks/telemetry_sync.py b/splitio/tasks/telemetry_sync.py
new file mode 100644
index 00000000..96e084ef
--- /dev/null
+++ b/splitio/tasks/telemetry_sync.py
@@ -0,0 +1,70 @@
+"""Telemetry synchronization task."""
+
+import logging
+from splitio.api import APIException
+from splitio.tasks import BaseSynchronizationTask
+from splitio.tasks.util.asynctask import AsyncTask
+
+
+class TelemetrySynchronizationTask(BaseSynchronizationTask):
+    """Telemetry synchronization task class."""
+
+    def __init__(self, api, storage, period):
+        """
+        Class constructor.
+
+        :param api: Telemetry API Client.
+        :type api: splitio.api.telemetry.TelemetryAPI
+        :param storage: Telemetry Storage.
+        :type storage: splitio.storage.InMemoryTelemetryStorage
+        :param period: Period (in seconds) of the synchronization task.
+        :type period: int
+        """
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._api = api
+        self._period = period
+        self._storage = storage
+        self._task = AsyncTask(self._flush_telemetry, period)
+
+    def _flush_telemetry(self):
+        """
+        Send latencies, counters and gauges to the split BE.
+        """
+        try:
+            latencies = self._storage.pop_latencies()
+            if latencies:
+                self._api.flush_latencies(latencies)
+        except APIException:
+            self._logger.error('Failed to send telemetry/latencies to the split BE.')
+
+        try:
+            counters = self._storage.pop_counters()
+            if counters:
+                self._api.flush_counters(counters)
+        except APIException:
+            self._logger.error('Failed to send telemetry/counters to the split BE.')
+
+        try:
+            gauges = self._storage.pop_gauges()
+            if gauges:
+                self._api.flush_gauges(gauges)
+        except APIException:
+            self._logger.error('Failed to send telemetry/gauges to the split BE.')
+
+    def start(self):
+        """Start the task."""
+        self._task.start()
+
+    def stop(self, event=None):
+        """Stop the task. Accept an optional event to set when the task has finished."""
+        self._task.stop(event)
+
+    def is_running(self):
+        """
+        Return whether the task is running.
+
+        :return: True if the task is running. False otherwise.
+        :rtype: bool
+        """
+        return self._task.running()
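Each metric family is popped and flushed independently, so a failure on one endpoint does not block the others. How those pops pair with the UWSGI storage defined earlier in this patch (metric names and the bucket index are illustrative):

from splitio.storage.adapters.uwsgi_cache import get_uwsgi
from splitio.storage.uwsgi import UWSGITelemetryStorage

storage = UWSGITelemetryStorage(get_uwsgi())
storage.inc_latency('sdk.get_treatment', 3)   # one hit in bucket 3 of 22
storage.inc_latency('sdk.get_treatment', 3)
storage.inc_counter('splitChangeFetcher.status.200')
storage.put_gauge('segments.count', 12)

# Each pop returns the accumulated dict and clears the key under its lock,
# so every data point is sent at most once.
latencies = storage.pop_latencies()  # {'sdk.get_treatment': [0, 0, 0, 2, 0, ...]}
counters = storage.pop_counters()    # {'splitChangeFetcher.status.200': 1}
gauges = storage.pop_gauges()        # {'segments.count': 12}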
diff --git a/splitio/tasks/util/__init__.py b/splitio/tasks/util/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/splitio/asynctask.py b/splitio/tasks/util/asynctask.py
similarity index 54%
rename from splitio/asynctask.py
rename to splitio/tasks/util/asynctask.py
index 18545fe1..929202cc 100644
--- a/splitio/asynctask.py
+++ b/splitio/tasks/util/asynctask.py
@@ -1,6 +1,4 @@
-"""
-Asynchronous tasks that can be controlled
-"""
+"""Asynchronous tasks that can be controlled."""
 
 import threading
 import logging
@@ -15,7 +13,8 @@ def _safe_run(func):
     """
-    Executes a function wrapped in a try-except statement set
+    Execute a function wrapped in a try-except block.
+
     If anything goes wrong returns false instead of propagating the exception.
 
     :param func: Function to be executed, receives no arguments and it's return
@@ -24,16 +23,19 @@
     try:
         func()
         return True
-    except Exception as exc:
+    except Exception:  #pylint: disable=broad-except
         # Catch any exception that might happen to avoid the periodic task
         # from ending and allowing for a recovery, as well as preventing
        # an exception from propagating and breaking the main thread
-        _LOGGER.exception(exc)
+        _LOGGER.error('Something went wrong when running passed function.')
+        _LOGGER.debug('Original traceback:', exc_info=True)
         return False
 
 
-class AsyncTask(object):
+class AsyncTask(object):  #pylint: disable=too-many-instance-attributes
     """
+    Asynchronous controllable task class.
+
     This class is used to wrap around a function to treat it as a
     periodic task. This task can be stopped, it's execution can be forced, and
     it's status (whether it's running or not) can be obtained from the task
@@ -43,12 +45,16 @@ class AsyncTask(object):
 
     def __init__(self, main, period, on_init=None, on_stop=None):
         """
-        Class constructor
+        Class constructor.
+
         :param main: Main function to be executed periodically
+        :type main: callable
         :param period: How many seconds to wait between executions
-        :param onInit: Function to be executed ONCE before the main one
-        :Param onStop: Function to be executed ONCE after the task has finished
+        :type period: int
+        :param on_init: Function to be executed ONCE before the main one
+        :type on_init: callable
+        :param on_stop: Function to be executed ONCE after the task has finished
+        :type on_stop: callable
         """
         self._on_init = on_init
         self._main = main
@@ -57,10 +63,12 @@ def __init__(self, main, period, on_init=None, on_stop=None):
         self._messages = queue.Queue()
         self._running = False
         self._thread = None
+        self._stop_event = None
 
     def _execution_wrapper(self):
         """
-        This function will be run in a separate thread.
+        Execute the user-defined function in a separate thread.
+
         It will execute the "on init" hook if available. If an exception is
         raised it will abort execution, otherwise it will enter an infinite
         loop in which the main function is executed every <period> seconds.
         If the stop signal is received, the loop will end, and the (optional)
         "on stop" hook is executed.
         All custom functions are run within a _safe_run() function which
         prevents exceptions from being propagated.
         """
-        if self._on_init is not None:
-            if not _safe_run(self._on_init):
-                _LOGGER.error("Error running task initialization function, aborting execution")
-                return
-
-        while True:
-            try:
-                msg = self._messages.get(True, self._period)
-                if msg == __TASK_STOP__:
-                    _LOGGER.info("Stop signal received. finishing task execution")
-                    break
-                elif msg == __TASK_FORCE_RUN__:
-                    _LOGGER.info("Force execution signal received. Running now")
+        try:
+            if self._on_init is not None:
+                if not _safe_run(self._on_init):
+                    _LOGGER.error("Error running task initialization function, aborting execution")
+                    self._running = False
+                    return
+            self._running = True
+            while True:
+                try:
+                    msg = self._messages.get(True, self._period)
+                    if msg == __TASK_STOP__:
+                        _LOGGER.info("Stop signal received. finishing task execution")
+                        break
+                    elif msg == __TASK_FORCE_RUN__:
+                        _LOGGER.info("Force execution signal received. Running now")
+                        if not _safe_run(self._main):
+                            _LOGGER.error(
+                                "An error occurred when executing the task. "
+                                "Retrying after period expires"
+                            )
+                        continue
+                except queue.Empty:
+                    # If no message was received, the timeout has expired
+                    # and we're ready for a new execution
                     if not _safe_run(self._main):
                         _LOGGER.error(
                             "An error occurred when executing the task. "
                             "Retrying after period expires"
                         )
-                    continue
-            except queue.Empty:
-                # If no message was received, the timeout has expired
-                # and we're ready for a new execution
-                if not _safe_run(self._main):
-                    _LOGGER.error(
-                        "An error occurred when executing the task. "
-                        "Retrying after perio expires"
-                    )
+        finally:
+            self._cleanup()
 
+    def _cleanup(self):
+        """Execute on_stop callback, set event if needed, update status."""
         if self._on_stop is not None:
             if not _safe_run(self._on_stop):
                 _LOGGER.error("An error occurred when executing the task's OnStop hook. ")
 
         self._running = False
+        if self._stop_event is not None:
+            self._stop_event.set()
 
     def start(self):
-        """
-        Creates the thread that will execute the function periodically and
-        starts it.
-        """
+        """Start the async task."""
         if self._running:
             _LOGGER.warning("Task is already running. Ignoring .start() call")
             return
@@ -118,33 +132,38 @@ def start(self):
         self._thread = threading.Thread(target=self._execution_wrapper)
         self._thread.setDaemon(True)
         try:
             self._thread.start()
-            self._running = True
+
         except RuntimeError as exc:
             _LOGGER.error("Couldn't create new thread for async task")
             _LOGGER.exception(exc)
 
-    def stop(self):
+    def stop(self, event=None):
         """
-        Sends a signal to the thread in order to stop it.
-        If the task is not running it does nothing.
+        Send a signal to the thread in order to stop it. If the task is not running do nothing.
+
+        Optionally accept an event to be set upon task completion.
+
+        :param event: Event to set when the task completes.
+        :type event: threading.Event
         """
+        if event is not None:
+            self._stop_event = event
+
         if not self._running:
+            if self._stop_event is not None:
+                self._stop_event.set()
             return
+
         # Queue is of infinite size, should not raise an exception
         self._messages.put(__TASK_STOP__, False)
 
     def force_execution(self):
-        """
-        Forces an execution of the task withouth waiting for the period to end.
-        If the task is not running it does nothing.
-        """
+        """Force an execution of the task without waiting for the period to end."""
         if not self._running:
             return
         # Queue is of infinite size, should not raise an exception
         self._messages.put(__TASK_FORCE_RUN__, False)
 
     def running(self):
-        """
-        Returns whether the task is running or not
-        """
+        """Return whether the task is running or not."""
         return self._running
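The message queue doubles as timer and control channel: get(True, period) either returns a control message or times out, and the timeout is what triggers a periodic run. A usage sketch (the callbacks and delays are placeholders):

import threading
import time
from splitio.tasks.util.asynctask import AsyncTask

def sync():
    print('synchronizing...')

def warm_up():
    print('initializing')

task = AsyncTask(sync, 5, on_init=warm_up)
task.start()            # runs warm_up() once, then sync() every ~5 seconds
time.sleep(12)
task.force_execution()  # run sync() immediately, without waiting for the period

stopped = threading.Event()
task.stop(stopped)      # on_stop (if provided) runs before the event is set
stopped.wait()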
+ """ + while self._should_be_working[worker_number]: + try: + message = self._incoming.get(True, 0.5) + self._incoming.task_done() + + # For some reason message can be None in python2 implementation of queue. + # This method must be both ignored and acknowledged with .task_done() + # otherwise .join() will halt. + if message is None: + continue + + ok = self._safe_run(func, message) #pylint: disable=invalid-name + if not ok: + self._logger.error( + ("Something went wrong during the execution, " + "removing message \"%s\" from queue."), + message + ) + except queue.Empty: + # No message was fetched, just keep waiting. + pass + + # Set my flag indicating that i have finished + self._worker_events[worker_number].set() + + def submit_work(self, message): + """ + Add a new message to the work-queue. + + :param message: New message to add. + :type message: object. + """ + self._incoming.put(message) + + def wait_for_completion(self): + """Block until the work queue is empty.""" + self._incoming.join() + + def stop(self, event=None): + """Stop all worker nodes.""" + async_stop = Thread(target=self._wait_workers_shutdown, args=(event,)) + async_stop.setDaemon(True) + async_stop.start() + + def _wait_workers_shutdown(self, event): + """ + Wait until all workers have finished, and set the event. + + :param event: Event to set as soon as all the workers have shut down. + :type event: threading.Event + """ + self.wait_for_completion() + for index, _ in enumerate(self._should_be_working): + self._should_be_working[index] = False + + if event is not None: + for worker_event in self._worker_events: + worker_event.wait() + event.set() diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py new file mode 100644 index 00000000..90c40746 --- /dev/null +++ b/splitio/tasks/uwsgi_wrappers.py @@ -0,0 +1,186 @@ +"""Wrappers for tasks when using UWSGI Cache as a synchronization platform.""" + +import logging +import time +from splitio.client.config import DEFAULT_CONFIG +from splitio.client.util import get_metadata +from splitio.storage.adapters.uwsgi_cache import get_uwsgi +from splitio.storage.uwsgi import UWSGIEventStorage, UWSGIImpressionStorage, \ + UWSGISegmentStorage, UWSGISplitStorage, UWSGITelemetryStorage +from splitio.api.client import HttpClient +from splitio.api.splits import SplitsAPI +from splitio.api.segments import SegmentsAPI +from splitio.api.impressions import ImpressionsAPI +from splitio.api.telemetry import TelemetryAPI +from splitio.api.events import EventsAPI +from splitio.tasks.split_sync import SplitSynchronizationTask +from splitio.tasks.segment_sync import SegmentSynchronizationTask +from splitio.tasks.impressions_sync import ImpressionsSyncTask +from splitio.tasks.events_sync import EventsSyncTask +from splitio.tasks.telemetry_sync import TelemetrySynchronizationTask + + +_LOGGER = logging.getLogger(__name__) + + +def _get_config(user_config): + """ + Get sdk configuration using defaults + user overrides. + + :param user_config: User configuration. + :type user_config: dict + + :return: Calculated configuration. + :rtype: dict + """ + sdk_config = DEFAULT_CONFIG + sdk_config.update(user_config) + return sdk_config + + +def uwsgi_update_splits(user_config): + """ + Update splits task. + + :param user_config: User-provided configuration. 
diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py
new file mode 100644
index 00000000..90c40746
--- /dev/null
+++ b/splitio/tasks/uwsgi_wrappers.py
@@ -0,0 +1,186 @@
+"""Wrappers for tasks when using UWSGI Cache as a synchronization platform."""
+
+import logging
+import time
+from splitio.client.config import DEFAULT_CONFIG
+from splitio.client.util import get_metadata
+from splitio.storage.adapters.uwsgi_cache import get_uwsgi
+from splitio.storage.uwsgi import UWSGIEventStorage, UWSGIImpressionStorage, \
+    UWSGISegmentStorage, UWSGISplitStorage, UWSGITelemetryStorage
+from splitio.api.client import HttpClient
+from splitio.api.splits import SplitsAPI
+from splitio.api.segments import SegmentsAPI
+from splitio.api.impressions import ImpressionsAPI
+from splitio.api.telemetry import TelemetryAPI
+from splitio.api.events import EventsAPI
+from splitio.tasks.split_sync import SplitSynchronizationTask
+from splitio.tasks.segment_sync import SegmentSynchronizationTask
+from splitio.tasks.impressions_sync import ImpressionsSyncTask
+from splitio.tasks.events_sync import EventsSyncTask
+from splitio.tasks.telemetry_sync import TelemetrySynchronizationTask
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _get_config(user_config):
+    """
+    Get sdk configuration using defaults + user overrides.
+
+    :param user_config: User configuration.
+    :type user_config: dict
+
+    :return: Calculated configuration.
+    :rtype: dict
+    """
+    sdk_config = DEFAULT_CONFIG.copy()  # copy so the shared defaults are never mutated
+    sdk_config.update(user_config)
+    return sdk_config
+
+
+def uwsgi_update_splits(user_config):
+    """
+    Update splits task.
+
+    :param user_config: User-provided configuration.
+    :type user_config: dict
+    """
+    try:
+        config = _get_config(user_config)
+        seconds = config['featuresRefreshRate']
+        split_sync_task = SplitSynchronizationTask(
+            SplitsAPI(
+                HttpClient(config.get('sdk_url'), config.get('events_url')), config['apikey']
+            ),
+            UWSGISplitStorage(get_uwsgi()),
+            None,  # Period not needed since the task will be triggered manually.
+            None  # Ready flag not needed since it will never be set and consumed.
+        )
+
+        while True:
+            split_sync_task._update_splits()  #pylint: disable=protected-access
+            time.sleep(seconds)
+    except Exception:  #pylint: disable=broad-except
+        _LOGGER.error('Error updating splits')
+
+
+def uwsgi_update_segments(user_config):
+    """
+    Update segments task.
+
+    :param user_config: User-provided configuration.
+    :type user_config: dict
+    """
+    try:
+        config = _get_config(user_config)
+        seconds = config['segmentsRefreshRate']
+        segment_sync_task = SegmentSynchronizationTask(
+            SegmentsAPI(
+                HttpClient(config.get('sdk_url'), config.get('events_url')), config['apikey']
+            ),
+            UWSGISegmentStorage(get_uwsgi()),
+            None,  # Split storage not needed, segments provided manually.
+            None,  # Period not needed, task executed manually.
+            None  # Flag not needed, never consumed or set.
+        )
+        split_storage = UWSGISplitStorage(get_uwsgi())
+        while True:
+            segment_names = split_storage.get_segment_names()
+            for segment_name in segment_names:
+                segment_sync_task._update_segment(segment_name)  #pylint: disable=protected-access
+            time.sleep(seconds)
+    except Exception:  #pylint: disable=broad-except
+        _LOGGER.error('Error updating segments')
+
+
+def uwsgi_report_impressions(user_config):
+    """
+    Flush impressions task.
+
+    :param user_config: User-provided configuration.
+    :type user_config: dict
+    """
+    try:
+        config = _get_config(user_config)
+        metadata = get_metadata(config)
+        seconds = config['impressionsRefreshRate']
+        storage = UWSGIImpressionStorage(get_uwsgi())
+        impressions_sync_task = ImpressionsSyncTask(
+            ImpressionsAPI(
+                HttpClient(config.get('sdk_url'), config.get('events_url')),
+                config['apikey'],
+                metadata
+            ),
+            storage,
+            None,  # Period not needed. Task is being triggered manually.
+            5000  # TODO: Parametrize!
+        )
+
+        while True:
+            impressions_sync_task._send_impressions()  #pylint: disable=protected-access
+            for _ in range(0, seconds):
+                if storage.should_flush():
+                    storage.acknowledge_flush()
+                    break
+                time.sleep(1)
+    except Exception:  #pylint: disable=broad-except
+        _LOGGER.error('Error posting impressions')
+
+
+def uwsgi_report_events(user_config):
+    """
+    Flush events task.
+
+    :param user_config: User-provided configuration.
+    :type user_config: dict
+    """
+    try:
+        config = _get_config(user_config)
+        metadata = get_metadata(config)
+        seconds = config.get('eventsRefreshRate', 30)
+        storage = UWSGIEventStorage(get_uwsgi())
+        task = EventsSyncTask(
+            EventsAPI(
+                HttpClient(config.get('sdk_url'), config.get('events_url')),
+                config['apikey'],
+                metadata
+            ),
+            storage,
+            None,  # Period not needed. Task is being triggered manually.
+            5000  # TODO: Parametrize
+        )
+        while True:
+            task._send_events()  #pylint: disable=protected-access
+            for _ in range(0, seconds):
+                if storage.should_flush():
+                    storage.acknowledge_flush()
+                    break
+                time.sleep(1)
+    except Exception:  #pylint: disable=broad-except
+        _LOGGER.error('Error posting events')
+
+
+def uwsgi_report_telemetry(user_config):
+    """
+    Flush telemetry task.
+
+    :param user_config: User-provided configuration.
+ :type user_config: dict + """ + try: + config = _get_config(user_config) + metadata = get_metadata(config) + seconds = config.get('metricsRefreshRate', 30) + storage = UWSGITelemetryStorage(get_uwsgi()) + task = TelemetrySynchronizationTask( + TelemetryAPI( + HttpClient(config.get('sdk_url'), config.get('events_url')), + config['apikey'], + metadata + ), + storage, + None, # Period not needed. Task is being triggered manually. + ) + while True: + task._flush_telemetry() #pylint: disable=protected-access + time.sleep(seconds) + except Exception: #pylint: disable=broad-except + _LOGGER.error('Error posting metrics') diff --git a/splitio/tests/algoSplits.json b/splitio/tests/algoSplits.json deleted file mode 100644 index 67800171..00000000 --- a/splitio/tests/algoSplits.json +++ /dev/null @@ -1,264 +0,0 @@ -{ - "splits": [ - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "some_feature_1", - "changeNumber":1325599980, - "algo": 1, - "seed": -1222652054, - "status": "ACTIVE", - "killed": false, - "defaultTreatment": "off", - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "matcherType": "WHITELIST", - "negate": false, - "userDefinedSegmentMatcherData": null, - "whitelistMatcherData": { - "whitelist": [ - "whitelisted_user" - ] - } - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 100 - } - ] - }, - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "matcherType": "ALL_KEYS", - "negate": false, - "userDefinedSegmentMatcherData": null, - "whitelistMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 0 - }, - { - "treatment": "off", - "size": 100 - } - ] - } - ] - }, - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "some_feature_2", - "algo": 2, - "changeNumber":1325599980, - "seed": 1699838640, - "status": "ACTIVE", - "killed": false, - "defaultTreatment": "off", - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "matcherType": "ALL_KEYS", - "negate": false, - "userDefinedSegmentMatcherData": null, - "whitelistMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 100 - }, - { - "treatment": "off", - "size": 0 - } - ] - } - ] - }, - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "some_feature_3", - "algo": null, - "changeNumber":1325599980, - "seed": -480091424, - "status": "ACTIVE", - "killed": true, - "defaultTreatment": "defTreatment", - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "matcherType": "ALL_KEYS", - "negate": false, - "userDefinedSegmentMatcherData": null, - "whitelistMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "defTreatment", - "size": 100 - }, - { - "treatment": "off", - "size": 0 - } - ] - } - ] - }, - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "some_feature_4", - "seed": 1548363147, - "changeNumber":1325599980, - "status": "ACTIVE", - "killed": false, - "defaultTreatment": "off", - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "matcherType": "IN_SEGMENT", - "negate": false, - "userDefinedSegmentMatcherData": { - "segmentName": "employees" - }, - "whitelistMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 100 - } - ] - }, - { - "matcherGroup": { - "combiner": 
"AND", - "matchers": [ - { - "matcherType": "IN_SEGMENT", - "negate": false, - "userDefinedSegmentMatcherData": { - "segmentName": "human_beigns" - }, - "whitelistMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 30 - }, - { - "treatment": "off", - "size": 70 - } - ] - } - ] - }, - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "some_feature_5", - "algo": 38, - "seed": 1548363147, - "changeNumber":1325599980, - "status": "ACTIVE", - "killed": false, - "defaultTreatment": "off", - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "matcherType": "IN_SEGMENT", - "negate": false, - "userDefinedSegmentMatcherData": { - "segmentName": "employees" - }, - "whitelistMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 100 - } - ] - }, - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "matcherType": "IN_SEGMENT", - "negate": false, - "userDefinedSegmentMatcherData": { - "segmentName": "human_beigns" - }, - "whitelistMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 30 - }, - { - "treatment": "off", - "size": 70 - } - ] - } - ] - } - ], - "since": -1, - "till": 1457726098069 -} diff --git a/splitio/tests/segmentChanges.json b/splitio/tests/segmentChanges.json deleted file mode 100644 index eaa6f2fa..00000000 --- a/splitio/tests/segmentChanges.json +++ /dev/null @@ -1 +0,0 @@ -{"name":"demo","added":["fake_id_1","fake_id_2","fake_id_3"],"removed":[],"since":-1,"till":1461874240706} diff --git a/splitio/tests/splitChanges.json b/splitio/tests/splitChanges.json deleted file mode 100644 index 4721c931..00000000 --- a/splitio/tests/splitChanges.json +++ /dev/null @@ -1,1440 +0,0 @@ -{ - "splits":[ - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_multi_condition", - "seed":-1329591480, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":null - }, - "matcherType":"IN_SEGMENT", - "negate":false, - "userDefinedSegmentMatcherData":{ - "segmentName":"demo" - }, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - 
"unaryNumericMatcherData":{ - "dataType":"NUMBER", - "value":42 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_whitelist", - "seed":-1746200186, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "itsy", - "bitsy", - "spider" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_less_than_or_equal_to_datetime", - "seed":-223738659, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"LESS_THAN_OR_EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":{ - "dataType":"DATETIME", - "value":1461585600000 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_in_segment_multi_treatment", - "seed":-1733385168, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - 
"conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_some_treatment" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"some_treatment", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":null - }, - "matcherType":"IN_SEGMENT", - "negate":false, - "userDefinedSegmentMatcherData":{ - "segmentName":"demo" - }, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":25 - }, - { - "treatment":"off", - "size":25 - }, - { - "treatment":"some_treatment", - "size":50 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_greatr_than_or_equal_to_number", - "seed":790203375, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"GREATER_THAN_OR_EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":{ - "dataType":"NUMBER", - "value":42 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_greatr_than_or_equal_to_datetime", - "seed":-899943577, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", 
- "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"GREATER_THAN_OR_EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":{ - "dataType":"DATETIME", - "value":1461585600000 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_equal_to_datetime", - "seed":852356271, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":{ - "dataType":"DATETIME", - "value":1461542400000 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_killed", - "seed":-1268386077, - "status":"ACTIVE", - "killed":true, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - 
"userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":null - }, - "matcherType":"IN_SEGMENT", - "negate":false, - "userDefinedSegmentMatcherData":{ - "segmentName":"demo" - }, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_equal_to_number", - "seed":619910976, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":{ - "dataType":"NUMBER", - "value":50 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_in_segment", - "seed":-433206284, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":null - }, - "matcherType":"IN_SEGMENT", - "negate":false, - "userDefinedSegmentMatcherData":{ - "segmentName":"demo" - }, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - 
"betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_less_than_or_equal_to_number", - "seed":-978018717, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"LESS_THAN_OR_EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":{ - "dataType":"NUMBER", - "value":42 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_between_datetime", - "seed":158853511, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"BETWEEN", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - "betweenMatcherData":{ - "dataType":"DATETIME", - "start":1461585600000, - "end":1461673260000 - } - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_between_number", - "seed":-357084756, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { 
- "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"BETWEEN", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - "betweenMatcherData":{ - "dataType":"NUMBER", - "start":40, - "end":50 - } - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }, - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_in_segment_update", - "seed":264664699, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":null - }, - "matcherType":"IN_SEGMENT", - "negate":false, - "userDefinedSegmentMatcherData":{ - "segmentName":"test_segment_update" - }, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":40 - }, - { - "treatment":"off", - "size":60 - } - ] - } - ] - } - ], - "since":-1, - "till":1461957424937 -} \ No newline at end of file diff --git a/splitio/tests/splitChangesReadOnly.json b/splitio/tests/splitChangesReadOnly.json deleted file mode 100644 index 7b64d71b..00000000 --- a/splitio/tests/splitChangesReadOnly.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "splits": [ - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "test_read_only_1", - "seed": -1329591480, - "status": "ACTIVE", - "killed": false, - "defaultTreatment": "off", - "changeNumber": 1325599980, - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "keySelector": null, - "matcherType": "WHITELIST", - "negate": false, - "userDefinedSegmentMatcherData": null, - "whitelistMatcherData": { - "whitelist": [ - "valid" - ] - }, - "unaryNumericMatcherData": null, - 
"betweenMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 100 - } - ] - } - ] - } - ], - "since": -1, - "till": 1461957424937 - } \ No newline at end of file diff --git a/splitio/tests/splitCustomImpressionListener.json b/splitio/tests/splitCustomImpressionListener.json deleted file mode 100644 index 0dd47956..00000000 --- a/splitio/tests/splitCustomImpressionListener.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "splits": [ - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "iltest", - "seed": -1329591480, - "status": "ACTIVE", - "killed": false, - "defaultTreatment": "off", - "changeNumber": 1325599980, - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "keySelector": null, - "matcherType": "WHITELIST", - "negate": false, - "userDefinedSegmentMatcherData": null, - "whitelistMatcherData": { - "whitelist": [ - "valid" - ] - }, - "unaryNumericMatcherData": null, - "betweenMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 100 - } - ] - } - ] - } - ], - "since": -1, - "till": 1461957424937 - } \ No newline at end of file diff --git a/splitio/tests/splitGetTreatments.json b/splitio/tests/splitGetTreatments.json deleted file mode 100644 index be637e6c..00000000 --- a/splitio/tests/splitGetTreatments.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "splits": [ - { - "orgId": null, - "environment": null, - "trafficTypeId": null, - "trafficTypeName": null, - "name": "get_treatments_test", - "seed": -1329591480, - "status": "ACTIVE", - "killed": false, - "defaultTreatment": "off", - "changeNumber": 1325599980, - "conditions": [ - { - "matcherGroup": { - "combiner": "AND", - "matchers": [ - { - "keySelector": null, - "matcherType": "WHITELIST", - "negate": false, - "userDefinedSegmentMatcherData": null, - "whitelistMatcherData": { - "whitelist": [ - "valid" - ] - }, - "unaryNumericMatcherData": null, - "betweenMatcherData": null - } - ] - }, - "partitions": [ - { - "treatment": "on", - "size": 100 - } - ] - } - ] - } - ], - "since": -1, - "till": 1461957424937 - } \ No newline at end of file diff --git a/splitio/tests/test_api.py b/splitio/tests/test_api.py deleted file mode 100644 index bb3fd76d..00000000 --- a/splitio/tests/test_api.py +++ /dev/null @@ -1,492 +0,0 @@ -"""Unit tests for the api module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from requests.exceptions import RequestException, HTTPError -from unittest import TestCase - -from splitio.api import (SdkApi, _SEGMENT_CHANGES_URL_TEMPLATE, _SPLIT_CHANGES_URL_TEMPLATE, - _TEST_IMPRESSIONS_URL_TEMPLATE, _METRICS_URL_TEMPLATE) -from splitio.config import SDK_VERSION, SDK_API_BASE_URL, EVENTS_API_BASE_URL -from splitio.tests.utils import MockUtilsMixin - - -class SdkApiBuildHeadersTests(TestCase): - def setUp(self): - super(SdkApiBuildHeadersTests, self).setUp() - - self.some_api_key = 'some_api_key' - - self.api = SdkApi(self.some_api_key) - - def test_always_returns_mandatory_headers(self): - """Tests that the mandatory headers are always included and have the proper values""" - headers = self.api._build_headers() - - self.assertEqual('Bearer some_api_key', headers.get('Authorization')) - self.assertEqual(SDK_VERSION, headers.get('SplitSDKVersion')) - self.assertEqual('gzip', headers.get('Accept-Encoding')) - - def test_optional_headers_not_included_if_not_set(self): - 
"""Tests that the optional headers are not included if they haven't been set""" - headers = self.api._build_headers() - - self.assertNotIn('SplitSDKMachineName', headers) - self.assertNotIn('SplitSDKMachineIP', headers) - - def test_split_sdk_machine_name_included_if_set_as_literal(self): - """Tests that the optional header SplitSDKMachineName is included if set as a literal""" - some_split_sdk_machine_name = mock.NonCallableMagicMock() - self.api._split_sdk_machine_name = some_split_sdk_machine_name - - headers = self.api._build_headers() - - self.assertNotIn(some_split_sdk_machine_name, headers.get('SplitSDKMachineName')) - - def test_split_sdk_machine_name_included_if_set_as_callable(self): - """Tests that the optional header SplitSDKMachineName is included if set as a callable and - its value is the result of calling the function""" - some_split_sdk_machine_name = mock.MagicMock() - self.api._split_sdk_machine_name = some_split_sdk_machine_name - - headers = self.api._build_headers() - - self.assertNotIn(some_split_sdk_machine_name.return_value, - headers.get('SplitSDKMachineName')) - - def test_split_sdk_machine_ip_included_if_set_as_literal(self): - """Tests that the optional header SplitSDKMachineIP is included if set as a literal""" - some_split_sdk_machine_ip = mock.NonCallableMagicMock() - self.api._split_sdk_machine_ip = some_split_sdk_machine_ip - - headers = self.api._build_headers() - - self.assertNotIn(some_split_sdk_machine_ip, headers.get('SplitSDKMachineIP')) - - def test_split_sdk_machine_ip_included_if_set_as_callable(self): - """Tests that the optional header SplitSDKMachineIP is included if set as a callable and - its value is the result of calling the function""" - some_split_sdk_machine_ip = mock.MagicMock() - self.api._split_sdk_machine_ip = some_split_sdk_machine_ip - - headers = self.api._build_headers() - - self.assertNotIn(some_split_sdk_machine_ip.return_value, - headers.get('SplitSDKMachineIP')) - - -class SdkApiGetTests(TestCase, MockUtilsMixin): - def setUp(self): - super(SdkApiGetTests, self).setUp() - - self.requests_get_mock = self.patch('splitio.api.requests').get - - self.some_api_key = mock.MagicMock() - self.some_url = mock.MagicMock() - self.some_params = mock.MagicMock() - - self.api = SdkApi(self.some_api_key) - - self.build_headers_mock = self.patch_object(self.api, '_build_headers') - - def test_proper_headers_are_used(self): - """Tests that the request is made with the proper headers""" - self.api._get(self.some_url, self.some_params) - - self.requests_get_mock.assert_called_once_with(mock.ANY, params=mock.ANY, - headers=self.build_headers_mock.return_value, - timeout=mock.ANY) - - def test_url_parameter_is_used(self): - """Tests that the request is made with the supplied url""" - self.api._get(self.some_url, self.some_params) - - self.requests_get_mock.assert_called_once_with(self.some_url, params=mock.ANY, - headers=mock.ANY, timeout=mock.ANY) - - def test_params_parameter_is_used(self): - """Tests that the request is made with the supplied parameters""" - self.api._get(self.some_url, self.some_params) - - self.requests_get_mock.assert_called_once_with(mock.ANY, params=self.some_params, - headers=mock.ANY, timeout=mock.ANY) - - def test_proper_timeout_is_used(self): - """Tests that the request is made with the proper value for timeout""" - some_timeout = mock.MagicMock() - self.api._timeout = some_timeout - - self.api._get(self.some_url, self.some_params) - - self.requests_get_mock.assert_called_once_with(mock.ANY, params=mock.ANY, 
headers=mock.ANY, - timeout=some_timeout) - - def test_json_is_returned(self): - """Tests that the function returns the result of calling json() on the requests response""" - result = self.api._get(self.some_url, self.some_params) - - self.assertEqual(self.requests_get_mock.return_value.json.return_value, result) - - def test_request_exceptions_are_raised(self): - """Tests that if requests raises an exception, it is not handled within the call""" - self.requests_get_mock.side_effect = RequestException() - - with self.assertRaises(RequestException): - self.api._get(self.some_url, self.some_params) - - def test_request_status_exceptions_are_not_raised(self): - """Tests that if requests succeeds but its status is not 200 (Ok) an exception is not raised""" - self.requests_get_mock.return_value.raise_for_status.side_effect = HTTPError() - - try: - self.api._get(self.some_url, self.some_params) - except: - self.assertTrue(False) - - def test_json_exceptions_are_raised(self): - """Tests that if requests succeeds but its payload is not JSON, an exception is raised and - it isn't handled within the call""" - self.requests_get_mock.return_value.json.side_effect = ValueError() - - with self.assertRaises(ValueError): - self.api._get(self.some_url, self.some_params) - - -class SdkApiPostTests(TestCase, MockUtilsMixin): - def setUp(self): - super(SdkApiPostTests, self).setUp() - - self.requests_post_mock = self.patch('splitio.api.requests').post - - self.some_api_key = mock.MagicMock() - self.some_url = mock.MagicMock() - self.some_data = mock.MagicMock() - - self.api = SdkApi(self.some_api_key) - - self.build_headers_mock = self.patch_object(self.api, '_build_headers') - - def test_proper_headers_are_used(self): - """Tests that the request is made with the proper headers""" - self.api._post(self.some_url, self.some_data) - - self.requests_post_mock.assert_called_once_with( - mock.ANY, json=mock.ANY, headers=self.build_headers_mock.return_value, timeout=mock.ANY) - - def test_url_parameter_is_used(self): - """Tests that the request is made with the supplied url""" - self.api._post(self.some_url, self.some_data) - - self.requests_post_mock.assert_called_once_with(self.some_url, json=mock.ANY, - headers=mock.ANY, timeout=mock.ANY) - - def test_data_parameter_is_used(self): - """Tests that the request is made with the supplied data as json parameter""" - self.api._post(self.some_url, self.some_data) - - self.requests_post_mock.assert_called_once_with(mock.ANY, json=self.some_data, - headers=mock.ANY, timeout=mock.ANY) - - def test_proper_timeout_is_used(self): - """Tests that the request is made with the proper value for timeout""" - some_timeout = mock.MagicMock() - self.api._timeout = some_timeout - - self.api._post(self.some_url, self.some_data) - - self.requests_post_mock.assert_called_once_with(mock.ANY, json=mock.ANY, headers=mock.ANY, - timeout=some_timeout) - - def test_status_is_returned(self): - """Tests that the function returns the status code of the response""" - result = self.api._post(self.some_url, self.some_data) - - self.assertEqual(self.requests_post_mock.return_value.status_code, result) - - def test_request_exceptions_are_raised(self): - """Tests that if requests raises an exception, it is not handled within the call""" - self.requests_post_mock.side_effect = RequestException() - - with self.assertRaises(RequestException): - self.api._post(self.some_url, self.some_data) - - def test_request_status_exceptions_are_not_raised(self): - """Tests that if requests succeeds but its status
is not 200 (Ok) an exception is not raised""" - self.requests_post_mock.return_value.raise_for_status.side_effect = HTTPError() - - try: - self.api._post(self.some_url, self.some_data) - except: - self.assertTrue(False) - - -class SdkApiSplitChangesTest(TestCase, MockUtilsMixin): - def setUp(self): - super(SdkApiSplitChangesTest, self).setUp() - - self.some_api_key = mock.MagicMock() - self.some_since = mock.MagicMock() - - self.api = SdkApi(self.some_api_key) - - self.get_mock = self.patch_object(self.api, '_get') - - def test_default_split_changes_url_is_used(self): - """Tests that the default split changes endpoint url is used if sdk_api_base_url hasn't - been set""" - self.api.split_changes(self.some_since) - - expected_url = _SPLIT_CHANGES_URL_TEMPLATE.format( - base_url=SDK_API_BASE_URL - ) - - self.get_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_rebased_split_changes_url_is_used(self): - """Tests that if sdk_api_base_url has been set, it is used as the base for the url of the - request""" - - some_sdk_api_url_base = 'some_sdk_api_url_base' - self.api._sdk_api_url_base = some_sdk_api_url_base - self.api.split_changes(self.some_since) - - expected_url = _SPLIT_CHANGES_URL_TEMPLATE.format( - base_url=some_sdk_api_url_base - ) - - self.get_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_proper_params_are_used(self): - """Tests that the request to the split changes endpoint is made with the proper - parameters""" - self.api.split_changes(self.some_since) - - self.get_mock.assert_called_once_with(mock.ANY, {'since': self.some_since}) - - def test_exceptions_from_get_are_raised(self): - """Tests that any exceptions raised from calling _get are not handled with the call""" - self.get_mock.side_effect = Exception() - - with self.assertRaises(Exception): - self.api.split_changes(self.some_since) - - def test_returns_get_result(self): - """Tests that the method returns the result of calling get""" - self.assertEqual(self.get_mock.return_value, self.api.split_changes(self.some_since)) - - -class SdkApiSegmentChangesTests(TestCase, MockUtilsMixin): - def setUp(self): - super(SdkApiSegmentChangesTests, self).setUp() - - self.some_api_key = mock.MagicMock() - self.some_name = 'some_name' - self.some_since = mock.MagicMock() - - self.api = SdkApi(self.some_api_key) - - self.get_mock = self.patch_object(self.api, '_get') - - def test_default_segment_changes_url_is_used(self): - """Tests that the default segment changes endpoint url is used if sdk_api_base_url hasn't - been set""" - self.api.segment_changes(self.some_name, self.some_since) - - expected_url = _SEGMENT_CHANGES_URL_TEMPLATE.format( - base_url=SDK_API_BASE_URL, - segment_name=self.some_name - ) - - self.get_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_rebased_segment_changes_url_is_used(self): - """Tests that if sdk_api_base_url has been set, it is used as the base for the url of the - request""" - - some_sdk_api_url_base = 'some_sdk_api_url_base' - self.api._sdk_api_url_base = some_sdk_api_url_base - self.api.segment_changes(self.some_name, self.some_since) - - expected_url = _SEGMENT_CHANGES_URL_TEMPLATE.format( - base_url=some_sdk_api_url_base, - segment_name=self.some_name - ) - - self.get_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_proper_params_are_used(self): - """Tests that the request to the segment changes endpoint is made with the proper - parameters""" - self.api.segment_changes(self.some_name, self.some_since) - - 
self.get_mock.assert_called_once_with(mock.ANY, {'since': self.some_since}) - - def test_exceptions_from_get_are_raised(self): - """Tests that any exceptions raised from calling _get are not handled with the call""" - self.get_mock.side_effect = Exception() - - with self.assertRaises(Exception): - self.api.segment_changes(self.some_name, self.some_since) - - def test_returns_get_result(self): - """Tests that the method returns the result of calling get""" - self.assertEqual(self.get_mock.return_value, self.api.segment_changes(self.some_name, - self.some_since)) - - -class SdkApiTestImpressionsTest(TestCase, MockUtilsMixin): - def setUp(self): - super(SdkApiTestImpressionsTest, self).setUp() - - self.some_api_key = mock.MagicMock() - self.some_test_impressions_data = mock.MagicMock() - - self.api = SdkApi(self.some_api_key) - - self.post_mock = self.patch_object(self.api, '_post') - - def test_default_test_impressions_url_is_used(self): - """Tests that the default test impressions endpoint url is used if sdk_api_base_url hasn't - been set""" - self.api.test_impressions(self.some_test_impressions_data) - - expected_url = _TEST_IMPRESSIONS_URL_TEMPLATE.format( - base_url=EVENTS_API_BASE_URL - ) - - self.post_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_rebased_test_impressions_url_is_used(self): - """Tests that if sdk_api_base_url has been set, it is used as the base for the url of the - request""" - - some_events_api_url_base = 'some_events_api_url_base' - self.api._events_api_url_base = some_events_api_url_base - self.api.test_impressions(self.some_test_impressions_data) - - expected_url = _TEST_IMPRESSIONS_URL_TEMPLATE.format( - base_url=some_events_api_url_base - ) - - self.post_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_proper_params_are_used(self): - """Tests that the request to the test impressions endpoint is made with the proper - parameters""" - self.api.test_impressions(self.some_test_impressions_data) - - self.post_mock.assert_called_once_with(mock.ANY, self.some_test_impressions_data) - - def test_exceptions_from_get_are_raised(self): - """Tests that any exceptions raised from calling _post are not handled with the call""" - self.post_mock.side_effect = Exception() - - with self.assertRaises(Exception): - self.api.test_impressions(self.some_test_impressions_data) - - def test_returns_post_result(self): - """Tests that the method returns the result of calling post""" - self.assertEqual(self.post_mock.return_value, - self.api.test_impressions(self.some_test_impressions_data)) - - -class SdkApiMetricsTest(MockUtilsMixin): - def setUp(self): - super(SdkApiMetricsTest, self).setUp() - - self.some_api_key = mock.MagicMock() - self.some_data = mock.MagicMock() - - self.api = SdkApi(self.some_api_key) - - self.post_mock = self.patch_object(self.api, '_post') - - def _get_endpoint(self): - raise NotImplementedError() - - def _call_method(self, *args, **kwargs): - raise NotImplementedError() - - def test_default_metrics_url_is_used(self): - """Tests that the default metrics endpoint url is used if sdk_api_base_url hasn't been - set""" - self._call_method(self.some_data) - - expected_url = _METRICS_URL_TEMPLATE.format( - base_url=EVENTS_API_BASE_URL, - endpoint=self._get_endpoint() - ) - - self.post_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_rebased_test_impressions_url_is_used(self): - """Tests that if sdk_api_base_url has been set, it is used as the base for the url of the - request""" - - some_events_api_url_base = 
'some_events_api_url_base' - self.api._events_api_url_base = some_events_api_url_base - self._call_method(self.some_data) - - expected_url = _METRICS_URL_TEMPLATE.format( - base_url=some_events_api_url_base, - endpoint=self._get_endpoint() - ) - - self.post_mock.assert_called_once_with(expected_url, mock.ANY) - - def test_proper_params_are_used(self): - """Tests that the request to the metrics times endpoint is made with the proper - parameters""" - self._call_method(self.some_data) - - self.post_mock.assert_called_once_with(mock.ANY, self.some_data) - - def test_exceptions_from_get_are_raised(self): - """Tests that any exceptions raised from calling _post are not handled with the call""" - self.post_mock.side_effect = Exception() - - with self.assertRaises(Exception): - self._call_method(self.some_data) - - def test_returns_get_result(self): - """Tests that the method returns the result of calling post""" - self.assertEqual(self.post_mock.return_value, - self._call_method(self.some_data)) - - -class SdkApiMetricsTimesTest(SdkApiMetricsTest, TestCase): - def setUp(self): - super(SdkApiMetricsTimesTest, self).setUp() - - def _get_endpoint(self): - return 'times' - - def _call_method(self, *args, **kwargs): - return self.api.metrics_times(*args, **kwargs) - - -class SdkApiMetricsCountersTest(SdkApiMetricsTest, TestCase): - def setUp(self): - super(SdkApiMetricsCountersTest, self).setUp() - - def _get_endpoint(self): - return 'counters' - - def _call_method(self, *args, **kwargs): - return self.api.metrics_counters(*args, **kwargs) - - -class SdkApiMetricsGaugeTest(SdkApiMetricsTest, TestCase): - def setUp(self): - super(SdkApiMetricsGaugeTest, self).setUp() - - def _get_endpoint(self): - return 'gauge' - - def _call_method(self, *args, **kwargs): - return self.api.metrics_gauge(*args, **kwargs) diff --git a/splitio/tests/test_cache.py b/splitio/tests/test_cache.py deleted file mode 100644 index 808e5d2e..00000000 --- a/splitio/tests/test_cache.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Unit tests for the cache module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase - -from splitio.cache import (InMemorySplitCache, InMemorySegmentCache, InMemoryImpressionsCache) -from splitio.tests.utils import MockUtilsMixin - - -class InMemorySegmentCacheTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_segment_name = mock.MagicMock() - self.some_change_number = mock.MagicMock() - self.segment_cache = InMemorySegmentCache() - self.entries_mock = self.patch_object(self.segment_cache, '_entries') - self.existing_entries = {'key_set': {'some_key_1', 'some_key_2'}, 'change_number': -1} - self.entries_mock.__getitem__.return_value = self.existing_entries - - def test_add_keys_to_segment_sets_keys_to_union(self): - """Tests that add_keys_to_segment sets keys to union of old and new keys""" - self.segment_cache.add_keys_to_segment(self.some_segment_name, {'some_key_2', 'some_key_3'}) - self.assertSetEqual({'some_key_1', 'some_key_2', 'some_key_3'}, - self.existing_entries['key_set']) - - def test_remove_keys_from_segment_set_keys_to_difference(self): - """Tests that remove_from_segment sets keys to difference of old and new keys""" - self.segment_cache.remove_keys_from_segment(self.some_segment_name, - {'some_key_2', 'some_key_3'}) - self.assertSetEqual({'some_key_1'}, self.existing_entries['key_set']) - - def 
test_is_in_segment_calls_in_on_entries(self): - """Tests that is_in_segment checks if key in internal set""" - self.assertTrue(self.segment_cache.is_in_segment(self.some_segment_name, 'some_key_1')) - self.assertFalse(self.segment_cache.is_in_segment(self.some_segment_name, 'some_key_3')) - - def test_set_change_number_sets_change_number_for_segment(self): - """Tests that set_change_number sets the change number for the segment""" - self.segment_cache.set_change_number(self.some_segment_name, self.some_change_number) - self.assertIn('change_number', self.existing_entries) - self.assertEqual(self.some_change_number, self.existing_entries['change_number']) - - def test_get_change_number_returns_existing_change_number(self): - """Tests that get_change_number returns the current change number for the segment""" - self.assertEqual(self.existing_entries['change_number'], - self.segment_cache.get_change_number(self.some_segment_name)) - - -class InMemorySplitCacheTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_split_name = mock.MagicMock() - self.some_split = mock.MagicMock() - self.split_cache = InMemorySplitCache() - self.some_change_number = mock.MagicMock() - self.entries_mock = self.patch_object(self.split_cache, '_entries') - - def test_add_split_calls_entries_setitem(self): - """Tests that add_split calls __setitem__ on entries""" - self.split_cache.add_split(self.some_split_name, self.some_split) - self.entries_mock.__setitem__.assert_called_once_with(self.some_split_name, - self.some_split) - - def test_remove_split_calls_entries_pop(self): - """Tests that remove_split calls pop on entries""" - self.split_cache.remove_split(self.some_split_name) - self.entries_mock.pop.assert_called_once_with(self.some_split_name, None) - - def test_get_split_calls_get(self): - """Tests that get_split calls get on entries""" - self.split_cache.get_split(self.some_split_name) - self.entries_mock.get.assert_called_once_with(self.some_split_name, None) - - def test_get_split_returns_get_result(self): - """Tests that get_split returns the result of calling get on entries""" - self.assertEqual(self.entries_mock.get.return_value, - self.split_cache.get_split(self.some_split_name)) - - def test_set_change_number_sets_change_number(self): - """Test that set_change_number sets the change number""" - self.split_cache.set_change_number(self.some_change_number) - self.assertEqual(self.some_change_number, self.split_cache._change_number) - - def test_get_change_number_returns_change_number(self): - """Test that get_change_number returns the change number""" - self.split_cache.set_change_number(self.some_change_number) - self.split_cache._change_number = self.some_change_number - self.assertEqual(self.some_change_number, self.split_cache.get_change_number()) - - -class InMemoryImpressionsCacheInitTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_impressions = mock.MagicMock() - self.impressions_mock = mock.MagicMock() - self.defaultdict_mock = self.patch('splitio.cache.defaultdict', - return_value=self.impressions_mock) - - def test_init_initializes_impressions_with_impressions_parameter(self): - """Tests that __init__ updates the _impressions field with the impressions parameter""" - InMemoryImpressionsCache(impressions=self.some_impressions) - self.impressions_mock.update.assert_called_once_with(self.some_impressions) - - -class InMemoryImpressionsCacheTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_impression = mock.MagicMock() - self.rlock_mock =
self.patch('splitio.cache.RLock') - self.deepcopy_mock = self.patch('splitio.cache.deepcopy') - self.impressions_mock = mock.MagicMock() - self.defaultdict_mock = self.patch('splitio.cache.defaultdict', - return_value=self.impressions_mock) - self.impressions_cache = InMemoryImpressionsCache() - self.defaultdict_mock.reset_mock() - - def test_add_impression_appends_impression(self): - """Test that add_impression appends impression to existing list""" - self.impressions_cache.add_impression(self.some_impression) - self.impressions_mock.__getitem__.assert_called_once_with(self.some_impression.feature) - self.impressions_mock.__getitem__.return_value.append.assert_called_once_with( - self.some_impression) - - def test_fetch_all_returns_impressions_copy(self): - """Test that fetch all returns a copy of the impressions""" - result = self.impressions_cache.fetch_all() - self.deepcopy_mock.assert_called_once_with(self.impressions_mock) - self.assertEqual(self.deepcopy_mock.return_value, result) - - def test_clear_resets_impressions(self): - """Test that clear resets impressions""" - self.impressions_cache.clear() - self.assertEqual(self.defaultdict_mock.return_value, self.impressions_cache._impressions) - - def test_fetch_all_and_clear_returns_impressions_copy(self): - """Test that fetch_all_and_clear returns impressions copy""" - result = self.impressions_cache.fetch_all_and_clear() - self.deepcopy_mock.assert_called_once_with(self.impressions_mock) - self.assertEqual(self.deepcopy_mock.return_value, result) - - def test_fetch_all_and_clear_clears_impressions(self): - """Test that fetch_all_and_clear clears impressions""" - result = self.impressions_cache.fetch_all_and_clear() - self.deepcopy_mock.assert_called_once_with(self.impressions_mock) - self.assertEqual(self.deepcopy_mock.return_value, result) diff --git a/splitio/tests/test_clients.py b/splitio/tests/test_clients.py deleted file mode 100644 index 5342fa9e..00000000 --- a/splitio/tests/test_clients.py +++ /dev/null @@ -1,1284 +0,0 @@ -"""Unit tests for the matchers module""" -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -import tempfile -import arrow -import os.path - -from unittest import TestCase -from time import sleep - -from splitio import get_factory -from splitio.clients import Client -from splitio.brokers import JSONFileBroker, RedisBroker, LocalhostBroker, \ - UWSGIBroker, randomize_interval, SelfRefreshingBroker -from splitio.exceptions import TimeoutException -from splitio.config import DEFAULT_CONFIG, MAX_INTERVAL, SDK_API_BASE_URL, \ - EVENTS_API_BASE_URL -from splitio.treatments import CONTROL -from splitio.tests.utils import MockUtilsMixin -from splitio.managers import SelfRefreshingSplitManager, UWSGISplitManager, RedisSplitManager - - -class RandomizeIntervalTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_value = mock.MagicMock() - self.max_mock = self.patch_builtin('max') - self.randint_mock = self.patch('splitio.brokers.random.randint') - - def test_returns_callable(self): - """ - Tests that randomize_interval returns a callable - """ - - self.assertTrue(hasattr(randomize_interval(self.some_value), '__call__')) - - def test_returned_function_calls_randint(self): - """ - Tests that the function returned by randomize_interval calls randint with the proper - parameters - """ - randomize_interval(self.some_value)() - self.some_value.__floordiv__.assert_called_once_with(2) - 
self.randint_mock.assert_called_once_with(self.some_value.__floordiv__.return_value, - self.some_value) - - def test_returned_function_calls_max(self): - """ - Tests that the function returned by randomize_interval calls max with the proper - parameters - """ - randomize_interval(self.some_value)() - self.max_mock.assert_called_once_with(5, self.randint_mock.return_value) - - def test_returned_function_returns_max_result(self): - """ - Tests that the function returned by randomize_interval returns the result of calling max - """ - self.assertEqual(self.max_mock.return_value, randomize_interval(self.some_value)()) - - -class SelfRefreshingBrokerInitTests(TestCase, MockUtilsMixin): - - def setUp(self): - self.build_sdk_api_mock = self.patch('splitio.brokers.SelfRefreshingBroker._build_sdk_api') - self.build_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_split_fetcher') - self.build_treatment_log_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_treatment_log') - self.build_metrics_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_metrics') - self.start_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._start') - - self.some_api_key = mock.MagicMock() - self.some_config = mock.MagicMock() - - def test_sets_api_key(self): - """Test that __init__ sets api key to the given value""" - broker = SelfRefreshingBroker(self.some_api_key) - self.assertEqual(self.some_api_key, broker._api_key) - - def test_calls_build_sdk_api(self): - """Test that __init__ calls _build_sdk_api""" - client = SelfRefreshingBroker(self.some_api_key) - self.build_sdk_api_mock.assert_called_once_with() - self.assertEqual(self.build_sdk_api_mock.return_value, client._sdk_api) - - def test_calls_build_split_fetcher(self): - """Test that __init__ calls _build_split_fetcher""" - client = SelfRefreshingBroker(self.some_api_key) - self.build_split_fetcher_mock.assert_called_once_with() - self.assertEqual(self.build_split_fetcher_mock.return_value, client._split_fetcher) - - def test_calls_build_build_treatment_log(self): - """Test that __init__ calls _build_treatment_log""" - client = SelfRefreshingBroker(self.some_api_key) - self.build_treatment_log_mock.assert_called_once_with() - self.assertEqual(self.build_treatment_log_mock.return_value, client._treatment_log) - - def test_calls_build_treatment_log(self): - """Test that __init__ calls _build_treatment_log""" - client = SelfRefreshingBroker(self.some_api_key) - self.build_treatment_log_mock.assert_called_once_with() - self.assertEqual(self.build_treatment_log_mock.return_value, client._treatment_log) - - def test_calls_build_metrics(self): - """Test that __init__ calls _build_metrics""" - client = SelfRefreshingBroker(self.some_api_key) - self.build_metrics_mock.assert_called_once_with() - self.assertEqual(self.build_metrics_mock.return_value, client._metrics) - - def test_calls_start(self): - """Test that __init__ calls _start""" - SelfRefreshingBroker(self.some_api_key) - self.start_mock.assert_called_once_with() - - -class SelfRefreshingBrokerStartTests(TestCase, MockUtilsMixin): - def setUp(self): - self.event_mock = self.patch('splitio.brokers.threading.Event') - self.event_mock.return_value.wait.return_value = True - self.thread_mock = self.patch('splitio.brokers.threading.Thread') - self.build_sdk_api_mock = self.patch('splitio.brokers.SelfRefreshingBroker._build_sdk_api') - self.build_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_split_fetcher') - 
self.build_treatment_log_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_treatment_log') - self.build_metrics_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_metrics') - self.fetch_splits_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._fetch_splits') - - self.some_api_key = mock.MagicMock() - - def test_calls_start_on_treatment_log_delegate(self): - """Test that _start calls start on the treatment log delegate""" - SelfRefreshingBroker(self.some_api_key, config={'ready': 0}) - self.build_treatment_log_mock.return_value.delegate.start.assert_called_once_with() - - def test_calls_start_on_treatment_log_delegate_with_timeout(self): - """Test that _start calls start on the treatment log delegate when a timeout is given""" - SelfRefreshingBroker(self.some_api_key, config={'ready': 10}) - self.build_treatment_log_mock.return_value.delegate.start.assert_called_once_with() - -# TODO: Remove This! This test is no longer valid for the new asynctasks introduced. -# When all tasks are migrated to the new model, this should be removed -# def test_no_event_or_thread_created_if_timeout_is_zero(self): -# """Test that if timeout is zero, no threads or events are created""" -# SelfRefreshingBroker(self.some_api_key, config={'ready': 0}) -# self.event_mock.assert_not_called() -# self.thread_mock.assert_not_called() - - def test_split_fetcher_start_called_if_timeout_is_zero(self): - """Test that if timeout is zero, start is called on the split fetcher""" - SelfRefreshingBroker(self.some_api_key, config={'ready': 0}) - self.build_split_fetcher_mock.assert_called_once_with() - - def test_event_created_if_timeout_is_non_zero(self): - """Test that if timeout is non-zero, an event is created""" - SelfRefreshingBroker(self.some_api_key, config={'ready': 10}) - self.event_mock.assert_called_once_with() - - def test_wait_is_called_on_event_if_timeout_is_non_zero(self): - """Test that if timeout is non-zero, wait is called on the event""" - SelfRefreshingBroker(self.some_api_key, config={'ready': 10}) - self.event_mock.return_value.wait.assert_called_once_with(10) - -# TODO: Remove This! This test is no longer valid for the new asynctasks introduced.
-# When all tasks are migrated to the new model, this should be removed -# def test_thread_created_if_timeout_is_non_zero(self): -# """Test that if timeout is non-zero, a thread with target _fetch_splits is created""" -# SelfRefreshingBroker(self.some_api_key, config={'ready': 10}) -# self.thread_mock.assert_called_once_with(target=self.fetch_splits_mock, -# args=(self.event_mock.return_value,)) -# self.thread_mock.return_value.start.assert_called_once_with() - - def test_if_event_flag_is_not_set_an_exception_is_raised(self): - """Test that if the event flag is not set, a TimeoutException is raised""" - self.event_mock.return_value.wait.return_value = False - with self.assertRaises(TimeoutException): - SelfRefreshingBroker(self.some_api_key, config={'ready': 10}) - - def test_if_event_flag_is_set_an_exception_is_not_raised(self): - """Test that if the event flag is set, a TimeoutException is not raised""" - try: - SelfRefreshingBroker(self.some_api_key, config={'ready': 10}) - except Exception: - self.fail('An unexpected exception was raised') - - -class SelfRefreshingBrokerFetchSplitsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_event = mock.MagicMock() - self.build_sdk_api_mock = self.patch('splitio.brokers.SelfRefreshingBroker._build_sdk_api') - self.build_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_split_fetcher') - self.build_treatment_log_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_treatment_log') - self.build_metrics_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_metrics') - - self.some_api_key = mock.MagicMock() - self.client = SelfRefreshingBroker(self.some_api_key, config={'ready': 10}) - self.build_split_fetcher_mock.reset_mock() - - def test_calls_refresh_splits_on_split_fetcher(self): - """Test that _fetch_splits calls refresh_splits on split_fetcher""" - self.client._fetch_splits(self.some_event) - self.build_split_fetcher_mock.return_value.refresh_splits.assert_called_once_with( - block_until_ready=True) - - def test_calls_start_on_split_fetcher(self): - """Test that _fetch_splits calls start on split_fetcher""" - self.client._fetch_splits(self.some_event) - self.build_split_fetcher_mock.return_value.start.assert_called_once_with( - delayed_update=True) - - def test_calls_set_on_event(self): - """Test that _fetch_splits calls set on event""" - self.client._fetch_splits(self.some_event) - self.some_event.set.assert_called_once_with() - - -class SelfRefreshingBrokerInitConfigTests(TestCase, MockUtilsMixin): - def setUp(self): - self.build_sdk_api_mock = self.patch('splitio.brokers.SelfRefreshingBroker._build_sdk_api') - self.build_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_split_fetcher') - self.build_treatment_log_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_treatment_log') - self.build_metrics_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_metrics') - self.start_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._start') - self.some_api_key = mock.MagicMock() - self.randomize_interval_side_effect = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()] - self.randomize_interval_mock = self.patch( - 'splitio.brokers.randomize_interval', side_effect=self.randomize_interval_side_effect) - - self.some_config = { - 'connectionTimeout': mock.MagicMock(), - 'readTimeout': mock.MagicMock(), - 'featuresRefreshRate': 31, - 'segmentsRefreshRate': 32, - 'metricsRefreshRate': 33, - 'impressionsRefreshRate': 34,
- 'randomizeIntervals': False, - 'maxImpressionsLogSize': -1, - 'maxMetricsCallsBeforeFlush': -1, - 'ready': 10, - 'sdkApiBaseUrl': SDK_API_BASE_URL, - 'eventsApiBaseUrl': EVENTS_API_BASE_URL, - 'splitSdkMachineName': None, - 'splitSdkMachineIp': None, - 'redisHost': 'localhost', - 'redisPort': 6379, - 'redisDb': 0, - 'redisPassword': None, - 'redisSocketTimeout': None, - 'redisSocketConnectTimeout': None, - 'redisSocketKeepalive': None, - 'redisSocketKeepaliveOptions': None, - 'redisConnectionPool': None, - 'redisUnixSocketPath': None, - 'redisEncoding': 'utf-8', - 'redisEncodingErrors': 'strict', - 'redisCharset': None, - 'redisErrors': None, - 'redisDecodeResponses': False, - 'redisRetryOnTimeout': False, - 'redisSsl': False, - 'redisSslKeyfile': None, - 'redisSslCertfile': None, - 'redisSslCertReqs': None, - 'redisSslCaCerts': None, - 'redisMaxConnections': None, - 'eventsPushRate': 60, - 'eventsQueueSize': 500, - } - - self.client = SelfRefreshingBroker(self.some_api_key) - - def test_if_config_is_none_uses_default(self): - """Test that if config is None _init_config uses the defaults""" - self.client._init_config(config=None) - self.assertDictEqual(DEFAULT_CONFIG, self.client._config) - - def test_it_uses_supplied_config(self): - """Test that if config is not None, it uses the supplied config""" - self.client._init_config(config=self.some_config) - - print('!1', self.some_config) - print('!2', self.client._config) - - self.assertDictEqual(self.some_config, self.client._config) - - def test_forces_interval_max_on_intervals(self): - """ - Tests that __init__ forces default maximum on intervals - """ - self.some_config.update({ - 'featuresRefreshRate': MAX_INTERVAL + 10, - 'segmentsRefreshRate': MAX_INTERVAL + 20, - 'metricsRefreshRate': MAX_INTERVAL + 30, - 'impressionsRefreshRate': MAX_INTERVAL + 40 - }) - self.client._init_config(config=self.some_config) - self.assertEqual(MAX_INTERVAL, self.client._split_fetcher_interval) - self.assertEqual(MAX_INTERVAL, self.client._segment_fetcher_interval) - self.assertEqual(MAX_INTERVAL, self.client._impressions_interval) - - def test_randomizes_intervales_if_randomize_intervals_is_true(self): - """ - Tests that __init__ calls randomize_interval on intervals if randomizeIntervals is True - """ - self.some_config['randomizeIntervals'] = True - self.client._init_config(config=self.some_config) - self.assertListEqual([mock.call(self.some_config['segmentsRefreshRate']), - mock.call(self.some_config['featuresRefreshRate']), - mock.call(self.some_config['impressionsRefreshRate'])], - self.randomize_interval_mock.call_args_list) - self.assertEqual(self.randomize_interval_side_effect[0], - self.client._segment_fetcher_interval) - self.assertEqual(self.randomize_interval_side_effect[1], - self.client._split_fetcher_interval) - self.assertEqual(self.randomize_interval_side_effect[2], - self.client._impressions_interval) - - -class SelfRefreshingBrokerBuildSdkApiTests(TestCase, MockUtilsMixin): - def setUp(self): - self.sdk_api_mock = self.patch('splitio.brokers.SdkApi') - self.build_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_split_fetcher') - self.build_treatment_log_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_treatment_log') - self.build_metrics_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_metrics') - self.start_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._start') - self.some_api_key = mock.MagicMock() - self.client = SelfRefreshingBroker(self.some_api_key) - - 
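The interval handling exercised by RandomizeIntervalTests and SelfRefreshingBrokerInitConfigTests above reduces to two small pieces of logic. A minimal sketch of the behavior those deleted tests assert, assuming a value for the pre-refactor MAX_INTERVAL constant (it lived in splitio.config) and using init_interval as a hypothetical helper that is not part of the SDK:

import random

MAX_INTERVAL = 180  # assumed value; the real cap was defined in splitio.config

def randomize_interval(value):
    # Return a callable that yields a fresh interval on every call: a random
    # value between half the configured interval and the interval itself,
    # never below 5, matching what RandomizeIntervalTests asserts.
    def random_interval():
        return max(5, random.randint(value // 2, value))
    return random_interval

def init_interval(rate, randomize=False):
    # Hypothetical helper: cap a configured refresh rate at MAX_INTERVAL and,
    # when randomization is requested, wrap it with randomize_interval, as
    # SelfRefreshingBrokerInitConfigTests asserts of _init_config.
    capped = min(rate, MAX_INTERVAL)
    return randomize_interval(capped) if randomize else capped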
def test_calls_sdk_api_constructor(self): - """Test that _build_sdk_api calls SdkApi constructor""" - self.sdk_api_mock.assert_called_once_with( - self.some_api_key, sdk_api_base_url=self.client._sdk_api_base_url, - events_api_base_url=self.client._events_api_base_url, - connect_timeout=self.client._connection_timeout, read_timeout=self.client._read_timeout - ) - - -class SelfRefreshingBrokerBuildSplitFetcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self.build_sdk_api_mock = self.patch('splitio.brokers.SelfRefreshingBroker._build_sdk_api') - self.build_treatment_log_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_treatment_log') - self.build_metrics_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_metrics') - self.start_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._start') - self.some_api_key = mock.MagicMock() - - self.api_segment_change_fetcher_mock = self.patch('splitio.brokers.ApiSegmentChangeFetcher') - self.self_refreshing_segment_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingSegmentFetcher') - self.api_split_change_fetcher_mock = self.patch('splitio.brokers.ApiSplitChangeFetcher') - self.split_parser_mock = self.patch('splitio.brokers.SplitParser') - self.self_refreshing_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingSplitFetcher') - - self.some_api_key = mock.MagicMock() - self.client = SelfRefreshingBroker(self.some_api_key) - - def test_builds_segment_change_fetcher(self): - """Tests that _build_split_fetcher calls the ApiSegmentChangeFetcher constructor""" - self.api_segment_change_fetcher_mock.assert_called_once_with( - self.build_sdk_api_mock.return_value) - - def test_builds_segment_fetcher(self): - """Tests that _build_split_fetcher calls the SelfRefreshingSegmentFetcher constructor""" - self.self_refreshing_segment_fetcher_mock.assert_called_once_with( - self.api_segment_change_fetcher_mock.return_value, - interval=self.client._segment_fetcher_interval) - - def test_builds_split_change_fetcher(self): - """Tests that _build_split_fetcher calls the ApiSplitChangeFetcher constructor""" - self.api_split_change_fetcher_mock.assert_called_once_with( - self.build_sdk_api_mock.return_value) - - def test_builds_split_parser(self): - """Tests that _build_split_fetcher calls the SplitParser constructor""" - self.split_parser_mock.assert_called_once_with( - self.self_refreshing_segment_fetcher_mock.return_value) - - def test_builds_split_fetcher(self): - """Tests that _build_split_fetcher calls the SelfRefreshingSplitFetcher constructor""" - self.self_refreshing_split_fetcher_mock.assert_called_once_with( - self.api_split_change_fetcher_mock.return_value, self.split_parser_mock.return_value, - interval=self.client._split_fetcher_interval) - - def test_returns_split_fetcher(self): - """Tests that _build_split_fetcher returns the result of calling the - SelfRefreshingSplitFetcher constructor""" - self.assertEqual(self.self_refreshing_split_fetcher_mock.return_value, - self.client._build_split_fetcher()) - - -class SelfRefreshingBrokerBuildTreatmentLogTests(TestCase, MockUtilsMixin): - def setUp(self): - self.build_sdk_api_mock = self.patch('splitio.brokers.SelfRefreshingBroker._build_sdk_api') - self.build_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_split_fetcher') - self.build_metrics_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_metrics') - self.start_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._start') - self.some_api_key =
mock.MagicMock() - - self.self_updating_treatment_log_mock = self.patch( - 'splitio.brokers.SelfUpdatingTreatmentLog') - self.aync_treatment_log_mock = self.patch( - 'splitio.brokers.AsyncTreatmentLog') - self.some_api_key = mock.MagicMock() - self.client = SelfRefreshingBroker(self.some_api_key) - - def test_calls_self_updating_treatment_log_constructor(self): - """Tests that _build_treatment_log calls SelfUpdatingTreatmentLog constructor""" - self.self_updating_treatment_log_mock.assert_called_once_with( - self.client._sdk_api, - max_count=self.client._max_impressions_log_size, - interval=self.client._impressions_interval - ) - - def test_calls_async_treatment_log_constructor(self): - """Tests that _build_treatment_log calls AsyncTreatmentLog constructor""" - self.aync_treatment_log_mock.assert_called_once_with( - self.self_updating_treatment_log_mock.return_value) - - def test_returns_async_treatment_log(self): - """Tests that _build_treatment_log returns an AsyncTreatmentLog""" - self.assertEqual(self.aync_treatment_log_mock.return_value, - self.client._build_treatment_log()) - - -class SelfRefreshingBrokerBuildMetricsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.build_sdk_api_mock = self.patch('splitio.brokers.SelfRefreshingBroker._build_sdk_api') - self.build_split_fetcher_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_split_fetcher') - self.build_treatment_log_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._build_treatment_log') - self.start_mock = self.patch( - 'splitio.brokers.SelfRefreshingBroker._start') - self.some_api_key = mock.MagicMock() - - self.api_metrics_mock = self.patch( - 'splitio.brokers.ApiMetrics') - self.aync_metrics_mock = self.patch( - 'splitio.brokers.AsyncMetrics') - self.some_api_key = mock.MagicMock() - self.client = SelfRefreshingBroker(self.some_api_key) - - def test_calls_api_metrics_constructor(self): - """Tests that _build_metrics calls ApiMetrics constructor""" - self.api_metrics_mock.assert_called_once_with( - self.client._sdk_api, max_call_count=self.client._metrics_max_call_count, - max_time_between_calls=self.client._metrics_max_time_between_calls) - - def test_calls_async_metrics_constructor(self): - """Tests that _build_metrics calls AsyncMetrics constructor""" - self.aync_metrics_mock.assert_called_once_with( - self.api_metrics_mock.return_value) - - def test_returns_async_treatment_log(self): - """Tests that _build_metrics returns an AsyncMetrics""" - self.assertEqual(self.aync_metrics_mock.return_value, self.client._build_metrics()) - - -class JSONFileBrokerIntegrationTests(TestCase): - @classmethod - def setUpClass(cls): - cls.some_config = mock.MagicMock() - cls.segment_changes_file_name = os.path.join( - os.path.dirname(__file__), - 'segmentChanges.json' - ) - cls.split_changes_file_name = os.path.join( - os.path.dirname(__file__), - 'splitChanges.json' - ) - cls.client = Client(JSONFileBroker(cls.some_config, cls.segment_changes_file_name, - cls.split_changes_file_name)) - cls.on_treatment = 'on' - cls.off_treatment = 'off' - cls.some_key = 'some_key' - cls.fake_id_in_segment = 'fake_id_1' - cls.fake_id_not_in_segment = 'foobar' - cls.fake_id_on_key = 'fake_id_on' - cls.fake_id_off_key = 'fake_id_off' - cls.fake_id_some_treatment_key = 'fake_id_some_treatment' - cls.attribute_name = 'some_attribute' - cls.unknown_feature_name = 'foobar' - cls.in_between_datetime = arrow.get(2016, 4, 25, 16, 0).timestamp - cls.not_in_between_datetime = arrow.get(2015, 4, 25, 16, 0).timestamp - 
cls.in_between_number = 42 - cls.not_in_between_number = 85 - cls.equal_to_datetime = arrow.get(2016, 4, 25, 16, 0).timestamp - cls.not_equal_to_datetime = arrow.get(2015, 4, 25, 16, 0).timestamp - cls.equal_to_number = 50 - cls.not_equal_to_number = 85 - cls.greater_than_or_equal_to_datetime = arrow.get(2016, 4, 25, 16, 0).timestamp - cls.not_greater_than_or_equal_to_datetime = arrow.get(2015, 4, 25, 16, 0).timestamp - cls.greater_than_or_equal_to_number = 50 - cls.not_greater_than_or_equal_to_number = 32 - cls.less_than_or_equal_to_datetime = arrow.get(2015, 4, 25, 16, 0).timestamp - cls.not_less_than_or_equal_to_datetime = arrow.get(2016, 4, 25, 16, 0).timestamp - cls.less_than_or_equal_to_number = 32 - cls.not_less_than_or_equal_to_number = 50 - cls.multi_condition_equal_to_number = 42 - cls.multi_condition_not_equal_to_number = 85 - cls.in_whitelist = 'bitsy' - cls.not_in_whitelist = 'foobar' - - # - # basic tests - # - - def test_no_key_returns_control(self): - """ - Tests that get_treatment returns control treatment if the key is None - """ - self.assertEqual(CONTROL, self.client.get_treatment( - None, 'test_in_segment')) - - def test_unknown_feature_returns_control(self): - """ - Tests that get_treatment returns control treatment if feature is unknown - """ - self.assertEqual(CONTROL, self.client.get_treatment( - self.some_key, self.unknown_feature_name)) - - # - # test_between_datetime tests - # - - def test_test_between_datetime_include_on_user(self): - """ - Test that get_treatment returns on for the test_between_datetime feature using the user key - included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_between_datetime', - {self.attribute_name: self.in_between_datetime})) - - def test_test_between_datetime_include_on_user_no_attribute_match(self): - """ - Test that get_treatment returns on for the test_between_datetime feature using the user key - included for on treatment even while there is no attribute match - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_between_datetime', - {self.attribute_name: self.not_in_between_datetime})) - - def test_test_between_datetime_include_off_user(self): - """ - Test that get_treatment returns off for the test_between_datetime feature using the user key - included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_between_datetime', - {self.attribute_name: self.in_between_datetime})) - - def test_test_between_datetime_some_key_attribute_match(self): - """ - Test that get_treatment returns on for the test_between_datetime feature using the some key - while the attribute matches (100% for on treatment) - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.some_key, 'test_between_datetime', - {self.attribute_name: self.in_between_datetime})) - - def test_test_between_datetime_some_key_no_attribute_match(self): - """ - Test that get_treatment returns off for the test_between_datetime feature using the some key - while the attribute doesn't match (100% for on treatment) - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_between_datetime', - {self.attribute_name: self.not_in_between_datetime})) - - def test_test_between_datetime_some_key_no_attributes(self): - """ - Test that get_treatment returns off for the test_between_datetime feature using the some key - and no attributes - """ - 
self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_between_datetime')) - - # - # test_between_number tests - # - - def test_test_between_number_include_on_user(self): - """ - Test that get_treatment returns on for the test_between_number feature using the user key - included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_between_number', - {self.attribute_name: self.in_between_number})) - - def test_test_between_number_include_on_user_no_attribute_match(self): - """ - Test that get_treatment returns on for the test_between_number feature using the user key - included for on treatment even while there is no attribute match - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_between_number', - {self.attribute_name: self.not_in_between_number})) - - def test_test_between_number_include_off_user(self): - """ - Test that get_treatment returns off for the test_between_number feature using the user key - included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_between_number', - {self.attribute_name: self.in_between_number})) - - def test_test_between_number_some_key_attribute_match(self): - """ - Test that get_treatment returns on for the test_between_number feature using the some key - while the attribute matches (100% for on treatment) - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.some_key, 'test_between_number', - {self.attribute_name: self.in_between_number})) - - def test_test_between_number_some_key_no_attribute_match(self): - """ - Test that get_treatment returns off for the test_between_number feature using the some key - while the attribute doesn't match (100% for on treatment) - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_between_number', - {self.attribute_name: self.not_in_between_number})) - - def test_test_between_number_some_key_no_attributes(self): - """ - Test that get_treatment returns off for the test_between_number feature using the some key - and no attributes - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_between_number')) - - # - # test_equal_to_datetime tests - # - - def test_test_equal_to_datetime_include_on_user(self): - """ - Test that get_treatment returns on for the test_equal_to_datetime feature using the user key - included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_equal_to_datetime', - {self.attribute_name: self.equal_to_datetime})) - - def test_test_equal_to_datetime_include_on_user_no_attribute_match(self): - """ - Test that get_treatment returns on for the test_equal_to_datetime feature using the user key - included for on treatment even while there is no attribute match - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_equal_to_datetime', - {self.attribute_name: self.not_equal_to_datetime})) - - def test_test_equal_to_datetime_include_off_user(self): - """ - Test that get_treatment returns off for the test_equal_to_datetime feature using the user - key included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_equal_to_datetime', - {self.attribute_name: self.equal_to_datetime})) - - def 
test_test_equal_to_datetime_some_key_attribute_match(self): - """ - Test that get_treatment returns on for the test_equal_to_datetime feature using the some key - while the attribute matches (100% for on treatment) - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.some_key, 'test_equal_to_datetime', - {self.attribute_name: self.equal_to_datetime})) - - def test_test_equal_to_datetime_some_key_no_attribute_match(self): - """ - Test that get_treatment returns off for the test_equal_to_datetime feature using the some - key while the attribute doesn't match (100% for on treatment) - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_equal_to_datetime', - {self.attribute_name: self.not_equal_to_datetime})) - - def test_test_equal_to_datetime_some_key_no_attributes(self): - """ - Test that get_treatment returns off for the test_equal_to_datetime feature using the some - key and no attributes - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_equal_to_datetime')) - - # - # test_equal_to_number tests - # - - def test_test_equal_to_number_include_on_user(self): - """ - Test that get_treatment returns on for the test_equal_to_number feature using the user key - included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_equal_to_number', - {self.attribute_name: self.equal_to_number})) - - def test_test_equal_to_number_include_on_user_no_attribute_match(self): - """ - Test that get_treatment returns on for the test_equal_to_number feature using the user key - included for on treatment even while there is no attribute match - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_equal_to_number', - {self.attribute_name: self.not_equal_to_number})) - - def test_test_equal_to_number_include_off_user(self): - """ - Test that get_treatment returns off for the test_equal_to_number feature using the user key - included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_equal_to_number', - {self.attribute_name: self.equal_to_number})) - - def test_test_equal_to_number_some_key_attribute_match(self): - """ - Test that get_treatment returns on for the test_equal_to_number feature using the some key - while the attribute matches (100% for on treatment) - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.some_key, 'test_equal_to_number', - {self.attribute_name: self.equal_to_number})) - - def test_test_equal_to_number_some_key_no_attribute_match(self): - """ - Test that get_treatment returns off for the test_equal_to_number feature using the some key - while the attribute doesn't match (100% for on treatment) - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_equal_to_number', - {self.attribute_name: self.not_equal_to_number})) - - def test_test_equal_to_number_some_key_no_attributes(self): - """ - Test that get_treatment returns off for the test_equal_to_number feature using the some key - and no attributes - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_equal_to_number')) - - # - # test_greater_than_or_equal_to_datetime tests - # - - def test_test_greater_than_or_equal_to_datetime_include_on_user(self): - """ - Test that get_treatment returns on for the test_greater_than_or_equal_to_datetime feature - using 
the user key included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_greatr_than_or_equal_to_datetime', - {self.attribute_name: self.greater_than_or_equal_to_datetime})) - - def test_test_greater_than_or_equal_to_datetime_include_on_user_no_attribute_match(self): - """ - Test that get_treatment returns on for the test_greater_than_or_equal_to_datetime feature - using the user key included for on treatment even while there is no attribute match - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_greatr_than_or_equal_to_datetime', - {self.attribute_name: self.not_greater_than_or_equal_to_datetime})) - - def test_test_greater_than_or_equal_to_datetime_include_off_user(self): - """ - Test that get_treatment returns off for the test_greater_than_or_equal_to_datetime feature - using the user key included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_greatr_than_or_equal_to_datetime', - {self.attribute_name: self.greater_than_or_equal_to_datetime})) - - def test_test_greater_than_or_equal_to_datetime_some_key_attribute_match(self): - """ - Test that get_treatment returns on for the test_greater_than_or_equal_to_datetime feature - using the some key while the attribute matches (100% for on treatment) - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.some_key, 'test_greatr_than_or_equal_to_datetime', - {self.attribute_name: self.greater_than_or_equal_to_datetime})) - - def test_test_greater_than_or_equal_to_datetime_some_key_no_attribute_match(self): - """ - Test that get_treatment returns off for the test_greater_than_or_equal_to_datetime feature - using the some key while the attribute doesn't match (100% for on treatment) - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_greatr_than_or_equal_to_datetime', - {self.attribute_name: self.not_greater_than_or_equal_to_datetime})) - - def test_test_greater_than_or_equal_to_datetime_some_key_no_attributes(self): - """ - Test that get_treatment returns off for the test_greater_than_or_equal_to_datetime feature - using the some key and no attributes - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_greatr_than_or_equal_to_datetime')) - - # - # test_greater_than_or_equal_to_number tests - # - - def test_test_greater_than_or_equal_to_number_include_on_user(self): - """ - Test that get_treatment returns on for the test_greater_than_or_equal_to_number feature - using the user key included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_greatr_than_or_equal_to_number', - {self.attribute_name: self.greater_than_or_equal_to_number})) - - def test_test_greater_than_or_equal_to_number_include_on_user_no_attribute_match(self): - """ - Test that get_treatment returns on for the test_greater_than_or_equal_to_number feature - using the user key included for on treatment even while there is no attribute match - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_greatr_than_or_equal_to_number', - {self.attribute_name: self.not_greater_than_or_equal_to_number})) - - def test_test_greater_than_or_equal_to_number_include_off_user(self): - """ - Test that get_treatment returns off for the test_greater_than_or_equal_to_number feature - using the user key 
included for off treatment
-        """
-        self.assertEqual(self.off_treatment, self.client.get_treatment(
-            self.fake_id_off_key, 'test_greatr_than_or_equal_to_number',
-            {self.attribute_name: self.greater_than_or_equal_to_number}))
-
-    def test_test_greater_than_or_equal_to_number_some_key_attribute_match(self):
-        """
-        Test that get_treatment returns on for the test_greater_than_or_equal_to_number feature
-        using the some key while the attribute matches (100% for on treatment)
-        """
-        self.assertEqual(self.on_treatment, self.client.get_treatment(
-            self.some_key, 'test_greatr_than_or_equal_to_number',
-            {self.attribute_name: self.greater_than_or_equal_to_number}))
-
-    def test_test_greater_than_or_equal_to_number_some_key_no_attribute_match(self):
-        """
-        Test that get_treatment returns off for the test_greater_than_or_equal_to_number feature
-        using the some key while the attribute doesn't match (100% for on treatment)
-        """
-        self.assertEqual(self.off_treatment, self.client.get_treatment(
-            self.some_key, 'test_greatr_than_or_equal_to_number',
-            {self.attribute_name: self.not_greater_than_or_equal_to_number}))
-
-    def test_test_greater_than_or_equal_to_number_some_key_no_attributes(self):
-        """
-        Test that get_treatment returns off for the test_greater_than_or_equal_to_number feature
-        using the some key and no attributes
-        """
-        self.assertEqual(self.off_treatment, self.client.get_treatment(
-            self.some_key, 'test_greatr_than_or_equal_to_number'))
-
-    #
-    # test_less_than_or_equal_to_datetime tests
-    #
-
-    def test_test_less_than_or_equal_to_datetime_include_on_user(self):
-        """
-        Test that get_treatment returns on for the test_less_than_or_equal_to_datetime feature
-        using the user key included for on treatment
-        """
-        self.assertEqual(self.on_treatment, self.client.get_treatment(
-            self.fake_id_on_key, 'test_less_than_or_equal_to_datetime',
-            {self.attribute_name: self.less_than_or_equal_to_datetime}))
-
-    def test_test_less_than_or_equal_to_datetime_include_on_user_no_attribute_match(self):
-        """
-        Test that get_treatment returns on for the test_less_than_or_equal_to_datetime feature
-        using the user key included for on treatment even while there is no attribute match
-        """
-        self.assertEqual(self.on_treatment, self.client.get_treatment(
-            self.fake_id_on_key, 'test_less_than_or_equal_to_datetime',
-            {self.attribute_name: self.not_less_than_or_equal_to_datetime}))
-
-    def test_test_less_than_or_equal_to_datetime_include_off_user(self):
-        """
-        Test that get_treatment returns off for the test_less_than_or_equal_to_datetime feature
-        using the user key included for off treatment
-        """
-        self.assertEqual(self.off_treatment, self.client.get_treatment(
-            self.fake_id_off_key, 'test_less_than_or_equal_to_datetime',
-            {self.attribute_name: self.less_than_or_equal_to_datetime}))
-
-    def test_test_less_than_or_equal_to_datetime_some_key_attribute_match(self):
-        """
-        Test that get_treatment returns on for the test_less_than_or_equal_to_datetime feature
-        using the some key while the attribute matches (100% for on treatment)
-        """
-        self.assertEqual(self.on_treatment, self.client.get_treatment(
-            self.some_key, 'test_less_than_or_equal_to_datetime',
-            {self.attribute_name: self.less_than_or_equal_to_datetime}))
-
-    def test_test_less_than_or_equal_to_datetime_some_key_no_attribute_match(self):
-        """
-        Test that get_treatment returns off for the test_less_than_or_equal_to_datetime feature
-        using the some key while the attribute doesn't match (100% for on treatment)
-        """
-
self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_less_than_or_equal_to_datetime', - {self.attribute_name: self.not_less_than_or_equal_to_datetime})) - - def test_test_less_than_or_equal_to_datetime_some_key_no_attributes(self): - """ - Test that get_treatment returns off for the test_less_than_or_equal_to_datetime feature - using the some key and no attributes - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.some_key, 'test_less_than_or_equal_to_datetime')) - - # - # test_in_segment tests - # - - def test_test_in_segment_include_on_user(self): - """ - Test that get_treatment returns on for the test_in_segment feature using the user key - included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_in_segment')) - - def test_test_in_segment_include_off_user(self): - """ - Test that get_treatment returns off for the test_in_segment feature using the user key - included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_in_segment')) - - def test_test_in_segment_in_segment_key(self): - """ - Test that get_treatment returns on for the test_in_segment feature using a key in the - segment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_in_segment, 'test_in_segment')) - - def test_test_in_segment_not_in_segment_key(self): - """ - Test that get_treatment returns off for the test_in_segment feature using a key not in the - segment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_not_in_segment, 'test_in_segment')) - - # - # test_in_segment_multi_treatment tests - # - - def test_test_in_segment_multi_treatment_include_on_user(self): - """ - Test that get_treatment returns on for the test_in_segment_multi_treatment feature using - the user key included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_in_segment_multi_treatment')) - - def test_test_in_segment_multi_treatment_include_off_user(self): - """ - Test that get_treatment returns off for the test_in_segment_multi_treatment feature using - the user key included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_in_segment_multi_treatment')) - - def test_test_in_segment_multi_treatment_include_some_treatment_user(self): - """ - Test that get_treatment returns on for the test_in_segment_multi_treatment feature using - the user key included for some_treatment treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_in_segment_multi_treatment')) - - def test_test_in_segment_multi_treatment_in_segment_key(self): - """ - Test that get_treatment returns on for the test_in_segment_multi_treatment feature using a - key in the segment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_in_segment, 'test_in_segment_multi_treatment')) - - def test_test_in_segment_multi_treatment_not_in_segment_key(self): - """ - Test that get_treatment returns off for the test_in_segment_multi_treatment feature using a - key not in the segment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_not_in_segment, 'test_in_segment_multi_treatment')) - - # - # test_multi_condition tests - # - - def test_test_multi_condition_include_on_user(self): - 
""" - Test that get_treatment returns on for the test_multi_condition feature using the user key - included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_multi_condition', - {self.attribute_name: self.multi_condition_not_equal_to_number})) - - def test_test_multi_condition_include_off_user(self): - """ - Test that get_treatment returns off for the test_multi_condition feature using the user key - included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_multi_condition', - {self.attribute_name: self.multi_condition_equal_to_number})) - - def test_test_multi_condition_in_segment_key_no_attribute_match(self): - """ - Test that get_treatment returns on for the test_multi_condition feature using a key in the - segment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_in_segment, 'test_multi_condition', - {self.attribute_name: self.multi_condition_not_equal_to_number})) - - def test_test_multi_condition_not_in_segment_key_attribute_match(self): - """ - Test that get_treatment returns on for the test_multi_condition feature using a key not in - the segment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_not_in_segment, 'test_multi_condition', - {self.attribute_name: self.multi_condition_equal_to_number})) - - def test_test_multi_condition_not_in_segment_key_no_attribute_match(self): - """ - Test that get_treatment returns off for the test_multi_condition feature using a key not in - the segment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_not_in_segment, 'test_multi_condition', - {self.attribute_name: self.multi_condition_not_equal_to_number})) - - # - # test_whitelist tests - # - - def test_test_whitelist_include_on_user(self): - """ - Test that get_treatment returns on for the test_whitelist feature using the user key - included for on treatment - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_whitelist', - {self.attribute_name: self.not_in_whitelist})) - - def test_test_whitelist_include_off_user(self): - """ - Test that get_treatment returns off for the test_whitelist feature using the user key - included for off treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_off_key, 'test_whitelist', - {self.attribute_name: self.not_in_whitelist})) - - def test_test_whitelist_in_whitelist(self): - """ - Test that get_treatment returns on for the test_whitelist feature using an attribute - in the whitelist - """ - self.assertEqual(self.on_treatment, self.client.get_treatment( - self.fake_id_in_segment, 'test_whitelist', - {self.attribute_name: self.in_whitelist})) - - def test_test_whitelist_not_in_whitelist(self): - """ - Test that get_treatment returns off for the test_whitelist feature using an attribute not - in the whitelist - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_not_in_segment, 'test_whitelist', - {self.attribute_name: self.not_in_whitelist})) - - # - # test_killed tests - # - - def test_test_killed_include_on_user(self): - """ - Test that get_treatment returns off for the test_killed feature using the user key - included for on treatment - """ - self.assertEqual(self.off_treatment, self.client.get_treatment( - self.fake_id_on_key, 'test_killed')) - - def test_test_killed_include_off_user(self): - """ 
-        Test that get_treatment returns off for the test_killed feature using the user key
-        included for off treatment
-        """
-        self.assertEqual(self.off_treatment, self.client.get_treatment(
-            self.fake_id_off_key, 'test_killed'))
-
-    def test_test_killed_in_segment_key(self):
-        """
-        Test that get_treatment returns off for the test_killed feature using a key in the
-        segment
-        """
-        self.assertEqual(self.off_treatment, self.client.get_treatment(
-            self.fake_id_in_segment, 'test_killed'))
-
-    def test_test_killed_not_in_segment_key(self):
-        """
-        Test that get_treatment returns off for the test_killed feature using a key not in the
-        segment
-        """
-        self.assertEqual(self.off_treatment, self.client.get_treatment(
-            self.fake_id_not_in_segment, 'test_killed'))
-
-
-class LocalhostEnvironmentClientParseSplitFileTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_file_name = mock.MagicMock()
-        self.all_keys_split_side_effect = [mock.MagicMock(), mock.MagicMock()]
-        self.all_keys_split_mock = self.patch('splitio.brokers.AllKeysSplit',
-                                              side_effect=self.all_keys_split_side_effect)
-        self.build_split_fetcher_mock = self.patch(
-            'splitio.tests.test_clients.LocalhostBroker._build_split_fetcher')
-
-        self.open_mock = self.patch_builtin('open')
-        self.some_config = mock.MagicMock()
-        self.threading_mock = self.patch('threading.Thread')
-        self.broker = LocalhostBroker(self.some_config)
-
-    def test_skips_comment_lines(self):
-        """Test that _parse_split_file skips comment lines"""
-        self.open_mock.return_value.__enter__.return_value.__iter__.return_value = [
-            '#feature treatment']
-        self.broker._parse_split_file(self.some_file_name)
-        self.all_keys_split_mock.assert_not_called()
-
-    def test_skips_illegal_lines(self):
-        """Test that _parse_split_file skips illegal lines"""
-        self.open_mock.return_value.__enter__.return_value.__iter__.return_value = [
-            '!feature treat$ment']
-        self.broker._parse_split_file(self.some_file_name)
-        self.all_keys_split_mock.assert_not_called()
-
-    def test_parses_definition_lines(self):
-        """Test that _parse_split_file parses definition lines"""
-        self.open_mock.return_value.__enter__.return_value.__iter__.return_value = [
-            'feature1 treatment1', 'feature-2 treatment-2']
-        self.broker._parse_split_file(self.some_file_name)
-        self.assertListEqual([mock.call('feature1', 'treatment1'),
-                              mock.call('feature-2', 'treatment-2')],
-                             self.all_keys_split_mock.call_args_list)
-
-    def test_returns_dict_with_parsed_splits(self):
-        """Test that _parse_split_file returns a dict with the parsed splits"""
-        self.open_mock.return_value.__enter__.return_value.__iter__.return_value = [
-            'feature1 treatment1', 'feature2 treatment2']
-        self.assertDictEqual({'feature1': self.all_keys_split_side_effect[0],
-                              'feature2': self.all_keys_split_side_effect[1]},
-                             self.broker._parse_split_file(self.some_file_name))
-
-    def test_raises_value_error_if_ioerror_is_raised(self):
-        """Raises a ValueError if an IOError is raised"""
-        self.open_mock.side_effect = IOError()
-        with self.assertRaises(ValueError):
-            self.broker._parse_split_file(self.some_file_name)
-
-
-class LocalhostBrokerOffTheGrid(TestCase):
-    """
-    Tests for LocalhostEnvironmentClient.
Auto update config behaviour - """ - def test_auto_update_splits(self): - """ - Verifies that the split file is automatically re-parsed as soon as it's - modified - """ - with tempfile.NamedTemporaryFile(mode='w') as split_file: - split_file.write('a_test_split off\n') - split_file.flush() - - factory = get_factory("localhost", split_definition_file_name=split_file.name) - client = factory.client() - self.assertEqual(client.get_treatment('x', 'a_test_split'), 'off') - - split_file.truncate() - split_file.write('a_test_split on\n') - split_file.flush() - sleep(5) - - self.assertEqual(client.get_treatment('x', 'a_test_split'), 'on') - client.destroy() - - -class TestClientDestroy(TestCase): - """ - """ - - def setUp(self): - self.some_api_key = mock.MagicMock() - self.some_config = mock.MagicMock() - - def test_self_refreshing_destroy(self): - broker = SelfRefreshingBroker(self.some_api_key) - client = Client(broker) - manager = SelfRefreshingSplitManager(broker) - manager._logger.error = mock.MagicMock() - logger_error = manager._logger.error - client.destroy() - self.assertEqual(client.get_treatment('asd', 'asd'), CONTROL) - self.assertEqual(manager.splits(), []) - result = client.get_treatments('asd', [None, 'asd']) - self.assertEqual(len(result.keys()), 1) - self.assertEqual(result["asd"], CONTROL) - logger_error \ - .assert_called_with("Client has already been destroyed - no calls possible.") - - def test_redis_destroy(self): - broker = RedisBroker(self.some_api_key, self.some_config) - client = Client(broker) - manager = RedisSplitManager(broker) - manager._logger.error = mock.MagicMock() - logger_error = manager._logger.error - client.destroy() - self.assertEqual(client.get_treatment('asd', 'asd'), CONTROL) - self.assertEqual(manager.splits(), []) - result = client.get_treatments('asd', [True, 'asd', None]) - self.assertEqual(len(result.keys()), 1) - self.assertEqual(result["asd"], CONTROL) - logger_error \ - .assert_called_with("Client has already been destroyed - no calls possible.") - - def test_uwsgi_destroy(self): - broker = UWSGIBroker(self.some_api_key, {'eventsQueueSize': 30}) - client = Client(broker) - manager = UWSGISplitManager(broker) - manager._logger.error = mock.MagicMock() - logger_error = manager._logger.error - client.destroy() - self.assertEqual(client.get_treatment('asd', 'asd'), CONTROL) - self.assertEqual(manager.splits(), []) - result = client.get_treatments('asd', ['asd', None]) - self.assertEqual(result["asd"], CONTROL) - self.assertEqual(len(result.keys()), 1) - logger_error \ - .assert_called_with("Client has already been destroyed - no calls possible.") diff --git a/splitio/tests/test_events.py b/splitio/tests/test_events.py deleted file mode 100644 index 83383e78..00000000 --- a/splitio/tests/test_events.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Unit tests for the api module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase -from splitio.events import InMemoryEventStorage - - -a = 1 - - -class EventsInmemoryStorageTests(TestCase): - def test_hook_is_called_when_queue_is_full(self): - global a - def ftemp(): - global a - a += 1 - storage = InMemoryEventStorage(5) - storage.set_queue_full_hook(ftemp) - storage.log_event("a") - storage.log_event("a") - storage.log_event("a") - storage.log_event("a") - storage.log_event("a") - storage.log_event("a") - self.assertEqual(a, 2) diff --git 
a/splitio/tests/test_factories.py b/splitio/tests/test_factories.py deleted file mode 100644 index 6bcdae90..00000000 --- a/splitio/tests/test_factories.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase - -from splitio import get_factory -from splitio.redis_support import SentinelConfigurationException - - -class RedisSentinelFactory(TestCase): - def test_redis_factory_with_empty_sentinels_array(self): - config = { - 'redisDb': 0, - 'redisPrefix': 'test', - 'redisSentinels': [], - 'redisMasterService': 'mymaster', - 'redisSocketTimeout': 3 - } - - with self.assertRaises(SentinelConfigurationException): - get_factory('abc', config=config) - - def test_redis_factory_with_wrong_type_in_sentinels_array(self): - config = { - 'redisDb': 0, - 'redisPrefix': 'test', - 'redisSentinels': 'abc', - 'redisMasterService': 'mymaster', - 'redisSocketTimeout': 3 - } - - with self.assertRaises(SentinelConfigurationException): - get_factory('abc', config=config) - - def test_redis_factory_with_wrong_data_in_sentinels_array(self): - config = { - 'redisDb': 0, - 'redisPrefix': 'test', - 'redisSentinels': ['asdasd'], - 'redisMasterService': 'mymaster', - 'redisSocketTimeout': 3 - } - - with self.assertRaises(SentinelConfigurationException): - get_factory('abc', config=config) - - def test_redis_factory_with_without_master_service(self): - config = { - 'redisDb': 0, - 'redisPrefix': 'test', - 'redisSentinels': [('test', 1234)], - 'redisSocketTimeout': 3 - } - - with self.assertRaises(SentinelConfigurationException): - get_factory('abc', config=config) diff --git a/splitio/tests/test_get_treatments.py b/splitio/tests/test_get_treatments.py deleted file mode 100644 index 0f4a67fb..00000000 --- a/splitio/tests/test_get_treatments.py +++ /dev/null @@ -1,105 +0,0 @@ -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from os.path import dirname, join -from json import load -from unittest import TestCase - -from splitio.clients import Client -from splitio.redis_support import (RedisSplitCache, get_redis) -from splitio.brokers import RedisBroker -from splitio import get_factory - - -class GetTreatmentsTest(TestCase): - def setUp(self): - self._some_config = mock.MagicMock() - self._split_changes_file_name = join(dirname(__file__), - 'splitGetTreatments.json') - - with open(self._split_changes_file_name) as f: - self._json = load(f) - split_definition = self._json['splits'][0] - split_name = split_definition['name'] - - self._redis = get_redis({'redisPrefix': 'getTreatmentsTest'}) - - self._redis_split_cache = RedisSplitCache(self._redis) - self._redis_split_cache.add_split(split_name, split_definition) - self._client = Client(RedisBroker(self._redis, self._some_config)) - - self._config = { - 'ready': 180000, - 'redisDb': 0, - 'redisHost': 'localhost', - 'redisPosrt': 6379, - 'redisPrefix': 'getTreatmentsTest' - } - self._factory = get_factory('asdqwe123456', config=self._config) - self._split = self._factory.client() - - def test_clien_with_distinct_features(self): - results = self._split.get_treatments('some_key', ['some_feature', 'some_feature_2']) - self.assertIn('some_feature', results) - self.assertIn('some_feature_2', results) - self.assertDictEqual(results, { - 'some_feature': 'control', - 'some_feature_2': 'control' - }) - - def test_clien_with_repeated_features(self): - results = 
self._split.get_treatments('some_key', ['some_feature', 'some_feature_2', - 'some_feature', 'some_feature']) - self.assertIn('some_feature', results) - self.assertIn('some_feature_2', results) - self.assertDictEqual(results, { - 'some_feature': 'control', - 'some_feature_2': 'control' - }) - - def test_clien_with_none_and_repeated_features(self): - results = self._split.get_treatments('some_key', ['some_feature', None, 'some_feature_2', - 'some_feature', 'some_feature', None]) - self.assertIn('some_feature', results) - self.assertIn('some_feature_2', results) - self.assertDictEqual(results, { - 'some_feature': 'control', - 'some_feature_2': 'control' - }) - - def test_client_with_valid_none_and_repeated_features_and_invalid_key(self): - features = ['some_feature', 'get_treatments_test', 'some_feature_2', - 'some_feature', 'get_treatments_test', None, 'valid'] - results = self._split.get_treatments('some_key', features) - self.assertIn('some_feature', results) - self.assertIn('some_feature_2', results) - self.assertIn('get_treatments_test', results) - self.assertEqual(results['some_feature'], 'control') - self.assertEqual(results['some_feature_2'], 'control') - self.assertEqual(results['get_treatments_test'], 'off') - - def test_client_with_valid_none_and_repeated_features_and_valid_key(self): - features = ['some_feature', 'get_treatments_test', 'some_feature_2', - 'some_feature', 'get_treatments_test', None, 'valid'] - results = self._split.get_treatments('valid', features) - self.assertIn('some_feature', results) - self.assertIn('some_feature_2', results) - self.assertIn('get_treatments_test', results) - self.assertEqual(results['some_feature'], 'control') - self.assertEqual(results['some_feature_2'], 'control') - self.assertEqual(results['get_treatments_test'], 'on') - - def test_client_with_valid_none_invalid_and_repeated_features_and_valid_key(self): - features = ['some_feature', 'get_treatments_test', 'some_feature_2', - 'some_feature', 'get_treatments_test', None, 'valid', - True, [], True] - results = self._split.get_treatments('valid', features) - self.assertIn('some_feature', results) - self.assertIn('some_feature_2', results) - self.assertIn('get_treatments_test', results) - self.assertEqual(results['some_feature'], 'control') - self.assertEqual(results['some_feature_2'], 'control') - self.assertEqual(results['get_treatments_test'], 'on') diff --git a/splitio/tests/test_impression_listener.py b/splitio/tests/test_impression_listener.py deleted file mode 100644 index 703cd0ae..00000000 --- a/splitio/tests/test_impression_listener.py +++ /dev/null @@ -1,196 +0,0 @@ -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from splitio.config import SDK_VERSION, GLOBAL_KEY_PARAMETERS - -from os.path import dirname, join -from json import load -from unittest import TestCase - -from splitio.clients import Client -from splitio.redis_support import (RedisSplitCache, get_redis) -from splitio.brokers import RedisBroker -from splitio.impressions import (Impression, ImpressionListener, ImpressionListenerWrapper, - ImpressionListenerException) -from splitio import get_factory - - -class ImpressionListenerClient(ImpressionListener): - def log_impression(self, data): - self._data_logged = data - - def get_impression(self): - return self._data_logged - - -class ImpressionListenerClientWithException(ImpressionListener): - def log_impression(self, data): - raise Exception('Simulate exception.') - - -class ImpressionListenerClientEmpty: - pass - - -class 
CustomImpressionListenerTestOnRedis(TestCase): - def setUp(self): - self._some_config = mock.MagicMock() - self._split_changes_file_name = join(dirname(__file__), - 'splitCustomImpressionListener.json') - - with open(self._split_changes_file_name) as f: - self._json = load(f) - split_definition = self._json['splits'][0] - split_name = split_definition['name'] - - self._redis = get_redis({'redisPrefix': 'customImpressionListenerTest'}) - - self._redis_split_cache = RedisSplitCache(self._redis) - self._redis_split_cache.add_split(split_name, split_definition) - self._client = Client(RedisBroker(self._redis, self._some_config)) - - self.some_feature = 'feature_0' - self.some_impression_0 = Impression(matching_key=mock.MagicMock(), - feature_name=self.some_feature, - treatment=mock.MagicMock(), - label=mock.MagicMock(), - change_number=mock.MagicMock(), - bucketing_key=mock.MagicMock(), - time=mock.MagicMock()) - - def test_client_raise_attribute_error(self): - client_1 = Client(RedisBroker(self._redis, self._some_config), - True, ImpressionListenerClientEmpty()) - - with self.assertRaises(AttributeError): - client_1._impression_listener.log_impression(self.some_impression_0) - - def test_send_data_to_client(self): - impression_client = ImpressionListenerClient() - impression_wrapper = ImpressionListenerWrapper(impression_client) - - impression_wrapper.log_impression(self.some_impression_0) - - self.assertIn('impression', impression_client._data_logged) - impression_logged = impression_client._data_logged['impression'] - self.assertIsInstance(impression_logged, Impression) - self.assertDictEqual({ - 'impression': { - 'keyName': self.some_impression_0.matching_key, - 'treatment': self.some_impression_0.treatment, - 'time': self.some_impression_0.time, - 'changeNumber': self.some_impression_0.change_number, - 'label': self.some_impression_0.label, - 'bucketingKey': self.some_impression_0.bucketing_key - } - }, { - 'impression': { - 'keyName': impression_logged.matching_key, - 'treatment': impression_logged.treatment, - 'time': impression_logged.time, - 'changeNumber': impression_logged.change_number, - 'label': impression_logged.label, - 'bucketingKey': impression_logged.bucketing_key - } - }) - - self.assertIn('instance-id', impression_client._data_logged) - self.assertEqual(impression_client._data_logged['instance-id'], - GLOBAL_KEY_PARAMETERS['instance-id']) - - self.assertIn('sdk-language-version', impression_client._data_logged) - self.assertEqual(impression_client._data_logged['sdk-language-version'], SDK_VERSION) - - self.assertIn('attributes', impression_client._data_logged) - - def test_client_throwing_exception_in_listener(self): - impressionListenerClient = ImpressionListenerClientWithException() - - config = { - 'ready': 180000, - 'impressionListener': impressionListenerClient, - 'redisDb': 0, - 'redisHost': 'localhost', - 'redisPosrt': 6379, - 'redisPrefix': 'customImpressionListenerTest' - } - factory = get_factory('asdqwe123456', config=config) - split = factory.client() - - self.assertEqual(split.get_treatment('valid', 'iltest'), 'on') - - def test_client(self): - impressionListenerClient = ImpressionListenerClient() - - config = { - 'ready': 180000, - 'impressionListener': impressionListenerClient, - 'redisDb': 0, - 'redisHost': 'localhost', - 'redisPosrt': 6379, - 'redisPrefix': 'customImpressionListenerTest' - } - factory = get_factory('asdqwe123456', config=config) - split = factory.client() - - self.assertEqual(split.get_treatment('valid', 'iltest'), 'on') - 
self.assertEqual(split.get_treatment('invalid', 'iltest'), 'off') - self.assertEqual(split.get_treatment('valid', 'iltest_invalid'), 'control') - - def test_client_without_impression_listener(self): - config = { - 'ready': 180000, - 'redisDb': 0, - 'redisHost': 'localhost', - 'redisPosrt': 6379, - 'redisPrefix': 'customImpressionListenerTest' - } - factory = get_factory('asdqwe123456', config=config) - split = factory.client() - - self.assertEqual(split.get_treatment('valid', 'iltest'), 'on') - self.assertEqual(split.get_treatment('invalid', 'iltest'), 'off') - self.assertEqual(split.get_treatment('valid', 'iltest_invalid'), 'control') - - def test_client_when_impression_listener_is_none(self): - config = { - 'ready': 180000, - 'redisDb': 0, - 'impressionListener': None, - 'redisHost': 'localhost', - 'redisPosrt': 6379, - 'redisPrefix': 'customImpressionListenerTest' - } - factory = get_factory('asdqwe123456', config=config) - split = factory.client() - - self.assertEqual(split.get_treatment('valid', 'iltest'), 'on') - self.assertEqual(split.get_treatment('invalid', 'iltest'), 'off') - self.assertEqual(split.get_treatment('valid', 'iltest_invalid'), 'control') - - def test_client_with_empty_impression_listener(self): - config = { - 'ready': 180000, - 'redisDb': 0, - 'impressionListener': ImpressionListenerClientEmpty(), - 'redisHost': 'localhost', - 'redisPosrt': 6379, - 'redisPrefix': 'customImpressionListenerTest' - } - factory = get_factory('asdqwe123456', config=config) - split = factory.client() - - self.assertEqual(split.get_treatment('valid', 'iltest'), 'on') - self.assertEqual(split.get_treatment('invalid', 'iltest'), 'off') - self.assertEqual(split.get_treatment('valid', 'iltest_invalid'), 'control') - - def test_throwing_exception_in_listener(self): - impression_exception = ImpressionListenerClientWithException() - - impression_wrapper = ImpressionListenerWrapper(impression_exception) - - with self.assertRaises(ImpressionListenerException): - impression_wrapper.log_impression(self.some_impression_0) diff --git a/splitio/tests/test_impressions.py b/splitio/tests/test_impressions.py deleted file mode 100644 index 8957641a..00000000 --- a/splitio/tests/test_impressions.py +++ /dev/null @@ -1,691 +0,0 @@ -"""Unit tests for the impressions module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase -from itertools import groupby -from jsonpickle import decode - -from splitio.config import SDK_VERSION -from splitio.impressions import (Impression, build_impressions_data, TreatmentLog, - LoggerBasedTreatmentLog, InMemoryTreatmentLog, - CacheBasedTreatmentLog, SelfUpdatingTreatmentLog, - AsyncTreatmentLog) -from splitio.tests.utils import MockUtilsMixin - -from splitio.redis_support import (get_redis, IMPRESSIONS_QUEUE_KEY, - RedisImpressionsCache, IMPRESSION_KEY_DEFAULT_TTL) -from splitio import get_factory -from splitio.config import GLOBAL_KEY_PARAMETERS - - -class BuildImpressionsDataTests(TestCase): - def setUp(self): - self.some_feature = 'feature_0' - self.some_other_feature = 'feature_1' - self.some_impression_0 = Impression(matching_key=mock.MagicMock(), - feature_name=self.some_feature, - treatment=mock.MagicMock(), - label=mock.MagicMock(), - change_number=mock.MagicMock(), - bucketing_key=mock.MagicMock(), - time=mock.MagicMock()) - self.some_impression_1 = Impression(matching_key=mock.MagicMock(), - 
feature_name=self.some_other_feature, - treatment=mock.MagicMock(), - label=mock.MagicMock(), - change_number=mock.MagicMock(), - bucketing_key=mock.MagicMock(), - time=mock.MagicMock()) - self.some_impression_2 = Impression(matching_key=mock.MagicMock(), - feature_name=self.some_other_feature, - treatment=mock.MagicMock(), - label=mock.MagicMock(), - change_number=mock.MagicMock(), - bucketing_key=mock.MagicMock(), - time=mock.MagicMock()) - - def test_build_impressions_data_works(self): - """Tests that build_impressions_data works""" - - impressions = [self.some_impression_0, self.some_impression_1, self.some_impression_2] - grouped_impressions = groupby(impressions, key=lambda impression: impression.feature_name) - - impression_dict = dict((feature_name, list(group)) for feature_name, group - in grouped_impressions) - - result = build_impressions_data(impression_dict) - - self.assertIsInstance(result, list) - self.assertEqual(2, len(result)) - - result = sorted(result, key=lambda d: d['testName']) - - self.assertDictEqual({ - 'testName': self.some_feature, - 'keyImpressions': [ - { - 'keyName': self.some_impression_0.matching_key, - 'treatment': self.some_impression_0.treatment, - 'time': self.some_impression_0.time, - 'changeNumber': self.some_impression_0.change_number, - 'label': self.some_impression_0.label, - 'bucketingKey': self.some_impression_0.bucketing_key - } - ] - }, result[0]) - self.assertDictEqual({ - 'testName': self.some_other_feature, - 'keyImpressions': [ - { - 'keyName': self.some_impression_1.matching_key, - 'treatment': self.some_impression_1.treatment, - 'time': self.some_impression_1.time, - 'changeNumber': self.some_impression_1.change_number, - 'label': self.some_impression_1.label, - 'bucketingKey': self.some_impression_1.bucketing_key - }, - { - 'keyName': self.some_impression_2.matching_key, - 'treatment': self.some_impression_2.treatment, - 'time': self.some_impression_2.time, - 'changeNumber': self.some_impression_2.change_number, - 'label': self.some_impression_2.label, - 'bucketingKey': self.some_impression_2.bucketing_key - } - ] - }, result[1]) - - def test_build_impressions_data_skipts_features_with_no_impressions(self): - """Tests that build_impressions_data skips features with no impressions""" - - grouped_impressions = groupby([self.some_impression_0], - key=lambda impression: impression.feature_name) - - impression_dict = dict((feature_name, list(group)) for feature_name, group - in grouped_impressions) - - result = build_impressions_data(impression_dict) - - self.assertIsInstance(result, list) - self.assertEqual(1, len(result)) - - self.assertDictEqual( - { - 'testName': self.some_impression_0.feature_name, - 'keyImpressions': [ - { - 'keyName': self.some_impression_0.matching_key, - 'treatment': self.some_impression_0.treatment, - 'time': self.some_impression_0.time, - 'changeNumber': self.some_impression_0.change_number, - 'label': self.some_impression_0.label, - 'bucketingKey': self.some_impression_0.bucketing_key - } - ] - }, result[0]) - - -class TreatmentLogTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_key = mock.MagicMock() - self.some_feature_name = mock.MagicMock() - self.some_treatment = mock.MagicMock() - self.some_time = 123456 - self.treatment_log = TreatmentLog() - self.log_mock = self.patch_object(self.treatment_log, '_log') - - self.some_label = mock.MagicMock() - self.some_change_number = mock.MagicMock() - self.some_impression = Impression(matching_key=self.some_key, - feature_name=self.some_feature_name, - 
treatment=self.some_treatment, - label=self.some_label, - change_number=self.some_change_number, - bucketing_key=self.some_key, - time=self.some_time) - - def test_log_doesnt_call_internal_log_if_key_is_none(self): - """Tests that log doesn't call _log if key is None""" - impression = Impression(matching_key=None, feature_name=self.some_feature_name, - treatment=self.some_treatment, label=self.some_label, - change_number=self.some_change_number, bucketing_key=self.some_key, - time=self.some_time) - - self.treatment_log.log(impression) - self.log_mock.assert_not_called() - - def test_log_doesnt_call_internal_log_if_feature_name_is_none(self): - """Tests that log doesn't call _log if feature name is None""" - impression = Impression(matching_key=self.some_key, feature_name=None, - treatment=self.some_treatment, label=self.some_label, - change_number=self.some_change_number, bucketing_key=self.some_key, - time=self.some_time) - self.treatment_log.log(impression) - self.log_mock.assert_not_called() - - def test_log_doesnt_call_internal_log_if_treatment_is_none(self): - """Tests that log doesn't call _log if treatment is None""" - impression = Impression(matching_key=self.some_key, feature_name=self.some_feature_name, - treatment=None, label=self.some_label, - change_number=self.some_change_number, bucketing_key=self.some_key, - time=self.some_time) - - self.treatment_log.log(impression) - self.log_mock.assert_not_called() - - def test_log_doesnt_call_internal_log_if_time_is_none(self): - """Tests that log doesn't call _log if time is None""" - impression = Impression(matching_key=self.some_key, feature_name=self.some_feature_name, - treatment=self.some_treatment, label=self.some_label, - change_number=self.some_change_number, bucketing_key=self.some_key, - time=None) - - self.treatment_log.log(impression) - self.log_mock.assert_not_called() - - def test_log_doesnt_call_internal_log_if_time_is_lt_0(self): - """Tests that log doesn't call _log if time is less than 0""" - impression = Impression(matching_key=self.some_key, feature_name=self.some_feature_name, - treatment=self.some_treatment, label=self.some_label, - change_number=self.some_change_number, bucketing_key=self.some_key, - time=-1) - - self.treatment_log.log(impression) - self.log_mock.assert_not_called() - - -class LoggerBasedTreatmentLogTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_key = mock.MagicMock() - self.some_feature_name = mock.MagicMock() - self.some_treatment = mock.MagicMock() - self.some_time = mock.MagicMock() - self.logger_mock = self.patch('splitio.impressions.logging.getLogger').return_value - self.treatment_log = LoggerBasedTreatmentLog() - - self.some_label = mock.MagicMock() - self.some_change_number = mock.MagicMock() - self.some_impression = Impression(matching_key=self.some_key, - feature_name=self.some_feature_name, - treatment=self.some_treatment, - label=self.some_label, - change_number=self.some_change_number, - bucketing_key=self.some_key, - time=self.some_time) - - def test_log_calls_logger_info(self): - """Tests that log calls logger info""" - self.treatment_log._log(self.some_impression) - self.logger_mock.info.assert_called_once_with(mock.ANY, self.some_feature_name, - self.some_key, self.some_treatment, - self.some_time, self.some_label, - self.some_change_number, - self.some_key) - - -class InMemoryTreatmentLogTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_key = mock.MagicMock() - self.some_feature_name = mock.MagicMock() - self.some_treatment = mock.MagicMock() - 
self.some_time = mock.MagicMock() - self.deepcopy_mock = self.patch('splitio.impressions.deepcopy') - self.defaultdict_mock_side_effect = [ - mock.MagicMock(), # during __init__ - mock.MagicMock() # afterwards - ] - self.defaultdict_mock = self.patch('splitio.impressions.defaultdict', - side_effect=self.defaultdict_mock_side_effect) - self.rlock_mock = self.patch('splitio.impressions.RLock') - self.treatment_log = InMemoryTreatmentLog() - self.notify_eviction_mock = self.patch_object(self.treatment_log, '_notify_eviction') - - self.some_label = mock.MagicMock() - self.some_change_number = mock.MagicMock() - self.some_impression = Impression(matching_key=self.some_key, - feature_name=self.some_feature_name, - treatment=self.some_treatment, - label=self.some_label, - change_number=self.some_change_number, - bucketing_key=self.some_key, - time=self.some_time) - - def test_impressions_is_defaultdict(self): - """Tests that impressions is a defaultdict""" - self.assertEqual(self.defaultdict_mock_side_effect[0], self.treatment_log._impressions) - - def test_fetch_all_and_clear_calls_deepcopy(self): - """Tests that fetch all and clear calls deepcopy on impressions""" - self.treatment_log.fetch_all_and_clear() - self.deepcopy_mock.assert_called_once_with(self.defaultdict_mock_side_effect[0]) - - def test_fetch_all_and_clear_returns_deepcopy_of_impressions(self): - """Tests that fetch all and clear returns deepcopy of impressions""" - self.assertEqual(self.deepcopy_mock.return_value, self.treatment_log.fetch_all_and_clear()) - - def test_fetch_all_and_clear_clears_impressions(self): - """Tests that fetch all and clear clears impressions""" - self.treatment_log.fetch_all_and_clear() - self.assertEqual(self.defaultdict_mock_side_effect[1], self.treatment_log._impressions) - - def test_log_calls_appends_impression_to_feature_entry(self): - """Tests that _log appends an impression to the feature name entry in the impressions - dictionary""" - self.treatment_log._log(self.some_impression) - impressions = self.treatment_log._impressions - impressions.__getitem__.assert_called_once_with(self.some_feature_name) - impressions.__getitem__.return_value.append.assert_called_once_with(self.some_impression) - - def test_log_resets_impressions_if_max_count_reached(self): - """Tests that _log resets impressions if max_count is reached""" - self.treatment_log._max_count = 5 - impressions = self.treatment_log._impressions - impressions.__getitem__.return_value.__len__.return_value = 10 - self.treatment_log._log(self.some_impression) - impressions.__setitem__.assert_called_once_with( - self.some_feature_name, [self.some_impression]) - - def test_log_calls__notify_eviction_if_max_count_reached(self): - """Tests that _log calls _notify_eviction if max_count is reached""" - self.treatment_log._max_count = 5 - impressions = self.treatment_log._impressions - impressions.__getitem__.return_value.__len__.return_value = 10 - self.treatment_log._log(self.some_impression) - self.notify_eviction_mock.assert_called_once_with(self.some_feature_name, - impressions.__getitem__.return_value) - - -class CacheBasedTreatmentLogTests(TestCase): - def setUp(self): - self.some_key = mock.MagicMock() - self.some_feature_name = mock.MagicMock() - self.some_treatment = mock.MagicMock() - self.some_time = mock.MagicMock() - self.some_impressions_cache = mock.MagicMock() - self.treatment_log = CacheBasedTreatmentLog(self.some_impressions_cache) - self.some_label = mock.MagicMock() - self.some_change_number = mock.MagicMock() - - 
self.some_impression = Impression(matching_key=self.some_key,
-                                          feature_name=self.some_feature_name,
-                                          treatment=self.some_treatment,
-                                          label=self.some_label,
-                                          change_number=self.some_change_number,
-                                          bucketing_key=self.some_key,
-                                          time=self.some_time)
-
-    def test_log_calls_cache_add_impression(self):
-        """Tests that _log calls add_impression on cache"""
-        self.treatment_log._log(self.some_impression)
-        self.some_impressions_cache.add_impression.assert_called_once_with(self.some_impression)
-
-
-class SelfUpdatingTreatmentLogTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_api = mock.MagicMock()
-        self.some_interval = mock.MagicMock()
-        self.treatment_log = SelfUpdatingTreatmentLog(self.some_api, interval=self.some_interval)
-        self.timer_refresh_mock = self.patch_object(self.treatment_log, '_timer_refresh')
-
-    def test_start_calls_timer_refresh_if_stopped_true(self):
-        """Test that start calls _timer_refresh if stopped is True"""
-        self.treatment_log.stopped = True
-        self.treatment_log.start()
-        self.timer_refresh_mock.assert_called_once_with()
-
-    def test_start_sets_stopped_to_false_if_stopped_true(self):
-        """Test that start sets stopped to False if stopped is True before"""
-        self.treatment_log.stopped = True
-        self.treatment_log.start()
-        self.assertFalse(self.treatment_log.stopped)
-
-    def test_start_doesnt_call_timer_refresh_if_stopped_false(self):
-        """Test that start doesn't call _timer_refresh if stopped is False"""
-        self.treatment_log.stopped = False
-        self.treatment_log.start()
-        self.timer_refresh_mock.assert_not_called()
-
-
-class SelfUpdatingTreatmentLogTimerRefreshTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_api = mock.MagicMock()
-        self.some_interval = mock.MagicMock()
-        self.timer_mock = self.patch('splitio.impressions.Timer')
-        self.thread_pool_executor = self.patch('splitio.impressions.ThreadPoolExecutor')
-        self.treatment_log = SelfUpdatingTreatmentLog(self.some_api, interval=self.some_interval)
-        self.treatment_log.stopped = False
-
-    def test_calls_submit(self):
-        """Test that _timer_refresh calls submit on the executor pool if it is not stopped"""
-        self.treatment_log._timer_refresh()
-        self.thread_pool_executor.return_value.submit.assert_called_once_with(
-            self.treatment_log._update_impressions)
-
-    def test_creates_timer_with_fixed_interval(self):
-        """Test that _timer_refresh creates a timer with a fixed interval if the interval isn't
-        callable and it is not stopped"""
-        self.treatment_log._interval = mock.NonCallableMagicMock()
-        self.treatment_log._timer_refresh()
-        self.timer_mock.assert_called_once_with(self.treatment_log._interval,
-                                                self.treatment_log._timer_refresh)
-
-    def test_creates_timer_with_randomized_interval(self):
-        """Test that _timer_refresh creates a timer with the interval return value if it is
-        callable and it is not stopped"""
-        self.treatment_log._timer_refresh()
-        self.timer_mock.assert_called_once_with(self.treatment_log._interval.return_value,
-                                                self.treatment_log._timer_refresh)
-
-    def test_creates_timer_even_if_worker_thread_raises_exception(self):
-        """Test that _timer_refresh creates a timer even if an exception is raised submitting to
-        the executor pool"""
-        self.thread_pool_executor.return_value.submit.side_effect = Exception()
-        self.treatment_log._timer_refresh()
-        self.timer_mock.assert_called_once_with(self.treatment_log._interval.return_value,
-                                                self.treatment_log._timer_refresh)
-
-    def test_starts_timer(self):
-        """Test that _timer_refresh starts the timer if it is not stopped"""
-        self.treatment_log._timer_refresh()
-
self.timer_mock.return_value.start.assert_called_once_with() - - def test_stopped_if_timer_raises_exception(self): - """Test that _timer_refresh stops the refresh if an exception is raise setting up the timer - """ - self.timer_mock.side_effect = Exception - self.treatment_log._timer_refresh() - self.assertTrue(self.treatment_log.stopped) - - -class SelfUpdatingTreatmentLogNotifyEvictionTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_api = mock.MagicMock() - self.some_feature_name = mock.MagicMock() - self.some_feature_impressions = [mock.MagicMock()] - self.thread_pool_executor_mock = self.patch('splitio.impressions.ThreadPoolExecutor') - self.treatment_log = SelfUpdatingTreatmentLog(self.some_api) - - def test_doesnt_call_submit_if_feature_name_is_none(self): - """Test that _notify_eviction doesn't call the executor submit if feature_name is None""" - self.treatment_log._notify_eviction(None, self.some_feature_impressions) - self.thread_pool_executor_mock.return_value.submit.assert_not_called() - - def test_doesnt_call_submit_if_feature_impressions_is_none(self): - """Test that _notify_eviction doesn't call the executor submit if feature_impressions is - None""" - self.treatment_log._notify_eviction(self.some_feature_name, None) - self.thread_pool_executor_mock.return_value.submit.assert_not_called() - - def test_doesnt_call_submit_if_feature_impressions_is_empty(self): - """Test that _notify_eviction doesn't call the executor submit if feature_impressions is - empty""" - self.treatment_log._notify_eviction(self.some_feature_name, []) - self.thread_pool_executor_mock.return_value.submit.assert_not_called() - - def test_calls_submit(self): - """Test that _notify_eviction calls submit on the executor""" - self.treatment_log._notify_eviction(self.some_feature_name, self.some_feature_impressions) - self.thread_pool_executor_mock.return_value.submit.assert_called_once_with( - self.treatment_log._update_evictions, self.some_feature_name, - self.some_feature_impressions) - - -class SelfUpdatingTreatmentLogUpdateImpressionsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_api = mock.MagicMock() - self.some_interval = mock.MagicMock() - self.build_impressions_data_mock = self.patch( - 'splitio.impressions.build_impressions_data', - return_value=[mock.MagicMock(), mock.MagicMock()]) - self.treatment_log = SelfUpdatingTreatmentLog(self.some_api, interval=self.some_interval) - self.fetch_all_and_clear_mock = self.patch_object( - self.treatment_log, 'fetch_all_and_clear') - - def test_calls_fetch_all_and_clear(self): - """Test that _update_impressions call fetch_all_and_clear""" - self.treatment_log._update_impressions() - self.fetch_all_and_clear_mock.assert_called_once_with() - - def test_calls_build_impressions_data(self): - """Test that _update_impressions call build_impressions_data""" - self.treatment_log._update_impressions() - self.build_impressions_data_mock.assert_called_once_with( - self.fetch_all_and_clear_mock.return_value) - - def test_calls_test_impressions(self): - """Test that _update_impressions call test_impressions on the api""" - self.treatment_log._update_impressions() - self.some_api.test_impressions.assert_called_once_with( - self.build_impressions_data_mock.return_value) - - def test_doesnt_call_test_impressions_with_empty_data(self): - """Test that _update_impressions doesn't call test_impressions on the api""" - self.build_impressions_data_mock.return_value = [] - self.treatment_log._update_impressions() - 
-
-
-class SelfUpdatingTreatmentLogUpdateEvictionsTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_api = mock.MagicMock()
-        self.some_interval = mock.MagicMock()
-        self.some_feature_name = mock.MagicMock()
-        self.some_feature_impressions = [mock.MagicMock()]
-        self.build_impressions_data_mock = self.patch(
-            'splitio.impressions.build_impressions_data',
-            return_value=[mock.MagicMock(), mock.MagicMock()])
-        self.treatment_log = SelfUpdatingTreatmentLog(self.some_api, interval=self.some_interval)
-
-    def test_calls_build_impressions_data(self):
-        """Test that _update_evictions calls build_impressions_data"""
-        self.treatment_log._update_evictions(self.some_feature_name, self.some_feature_impressions)
-        self.build_impressions_data_mock.assert_called_once_with(
-            {self.some_feature_name: self.some_feature_impressions})
-
-    def test_calls_test_impressions(self):
-        """Test that _update_evictions calls test_impressions on the API client"""
-        self.treatment_log._update_evictions(self.some_feature_name, self.some_feature_impressions)
-        self.some_api.test_impressions.assert_called_once_with(
-            self.build_impressions_data_mock.return_value)
-
-    def test_doesnt_call_test_impressions_if_data_is_empty(self):
-        """Test that _update_evictions doesn't call test_impressions on the API client if the
-        impressions data is empty"""
-        self.build_impressions_data_mock.return_value = []
-        self.treatment_log._update_evictions(self.some_feature_name, self.some_feature_impressions)
-        self.some_api.test_impressions.assert_not_called()
-
-    def test_doesnt_raise_exceptions(self):
-        """Test that _update_evictions doesn't raise exceptions when the API client does"""
-        self.some_api.test_impressions.side_effect = Exception()
-        try:
-            self.treatment_log._update_evictions(self.some_feature_name,
-                                                 self.some_feature_impressions)
-        except Exception:
-            self.fail('Unexpected exception raised')
-
-
-class AsyncTreatmentLogTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_feature_name = mock.MagicMock()
-        self.some_treatment = mock.MagicMock()
-        self.some_label = mock.MagicMock()
-        self.some_change_number = mock.MagicMock()
-        self.some_time = mock.MagicMock()
-        self.some_max_workers = mock.MagicMock()
-        self.some_delegate_treatment_log = mock.MagicMock()
-        self.thread_pool_executor_mock = self.patch('splitio.impressions.ThreadPoolExecutor')
-        self.treatment_log = AsyncTreatmentLog(self.some_delegate_treatment_log,
-                                               max_workers=self.some_max_workers)
-
-        self.some_impression = Impression(matching_key=self.some_key,
-                                          feature_name=self.some_feature_name,
-                                          treatment=self.some_treatment,
-                                          label=self.some_label,
-                                          change_number=self.some_change_number,
-                                          bucketing_key=self.some_key,
-                                          time=self.some_time)
-
-    def test_log_calls_thread_pool_executor_submit(self):
-        """Tests that log calls submit on the thread pool executor"""
-        self.treatment_log.log(self.some_impression)
-        self.thread_pool_executor_mock.return_value.submit.assert_called_once_with(
-            self.some_delegate_treatment_log.log, self.some_impression)
-
-    def test_log_doesnt_raise_exceptions_if_submit_does(self):
-        """Tests that log doesn't raise exceptions when submit does"""
-        self.thread_pool_executor_mock.return_value.submit.side_effect = Exception()
-        self.treatment_log.log(self.some_impression)
-
-
-class ImpressionAsQueueTests(TestCase, MockUtilsMixin):
-    def test_impression_added(self):
-        self._some_config = mock.MagicMock()
-
-        self._redis = get_redis({'redisPrefix': 'singleQueueTests'})
-        self._redis.delete(IMPRESSIONS_QUEUE_KEY)
-
-        self._config = {
-            'ready': 180000,
-            'redisDb': 0,
-            'redisHost': 'localhost',
-            'redisPort': 6379,
-            'redisPrefix': 'singleQueueTests'
-        }
-
-        self._factory = get_factory('asdqwe123456', config=self._config)
-
-        impression = Impression(matching_key='some_matching_key',
-                                feature_name='some_feature_name',
-                                treatment='on',
-                                label='label1',
-                                change_number=123456,
-                                bucketing_key='some_bucketing_key',
-                                time=123456)
-
-        self.an_impressions_cache = RedisImpressionsCache(self._redis)
-
-        self.an_impressions_cache._logger.debug = mock.MagicMock()
-        logger_debug = self.an_impressions_cache._logger.debug
-
-        self.an_impressions_cache.add_impressions([impression])
-
-        logger_debug \
-            .assert_called_once_with("SET EXPIRE KEY FOR QUEUE")
-
-        # Assert that the TTL is within a 10-second range (between it was set and retrieved).
-        ttl = self._redis.ttl(IMPRESSIONS_QUEUE_KEY)
-
-        self.assertLessEqual(int(ttl), IMPRESSION_KEY_DEFAULT_TTL)
-        self.assertGreaterEqual(int(ttl), IMPRESSION_KEY_DEFAULT_TTL - 10)
-
-        impression_stored = decode(self._redis.rpop(IMPRESSIONS_QUEUE_KEY))
-        self.assertEqual(impression_stored['m']['i'], GLOBAL_KEY_PARAMETERS['ip-address'])
-        self.assertEqual(impression_stored['m']['s'], SDK_VERSION)
-        self.assertEqual(impression_stored['m']['n'], GLOBAL_KEY_PARAMETERS['instance-id'])
-        self.assertEqual(impression_stored['i']['k'], 'some_matching_key')
-        self.assertEqual(impression_stored['i']['f'], 'some_feature_name')
-        self.assertEqual(impression_stored['i']['t'], 'on')
-        self.assertEqual(impression_stored['i']['r'], 'label1')
-        self.assertEqual(impression_stored['i']['c'], 123456)
-        self.assertEqual(impression_stored['i']['b'], 'some_bucketing_key')
-        self.assertEqual(impression_stored['i']['m'], 123456)
-
-    def test_ttl_called_once(self):
-        self._some_config = mock.MagicMock()
-
-        self._redis = get_redis({'redisPrefix': 'singleQueueTests'})
-        self._redis.delete(IMPRESSIONS_QUEUE_KEY)
-
-        self._config = {
-            'ready': 180000,
-            'redisDb': 0,
-            'redisHost': 'localhost',
-            'redisPort': 6379,
-            'redisPrefix': 'singleQueueTests'
-        }
-
-        self._factory = get_factory('asdqwe123456', config=self._config)
-
-        impression = Impression(matching_key='some_matching_key',
-                                feature_name='some_feature_name',
-                                treatment='on',
-                                label='label1',
-                                change_number=123456,
-                                bucketing_key='some_bucketing_key',
-                                time=123456)
-
-        self.an_impressions_cache = RedisImpressionsCache(self._redis)
-
-        self.an_impressions_cache._logger.debug = mock.MagicMock()
-        logger_debug = self.an_impressions_cache._logger.debug
-
-        self.an_impressions_cache.add_impressions([impression])
-        self.an_impressions_cache.add_impressions([impression])
-
-        logger_debug \
-            .assert_called_once_with("SET EXPIRE KEY FOR QUEUE")
-
-    def test_impression_added_custom_machine_name_and_ip(self):
-        self._some_config = mock.MagicMock()
-
-        self._redis = get_redis({'redisPrefix': 'singleQueueTests2'})
-        self._redis.delete(IMPRESSIONS_QUEUE_KEY)
-
-        self._config = {
-            'ready': 180000,
-            'redisDb': 0,
-            'redisHost': 'localhost',
-            'redisPort': 6379,
-            'redisPrefix': 'singleQueueTests',
-            'splitSdkMachineName': 'CustomMachineName',
-            'splitSdkMachineIp': '1.2.3.4'
-        }
-
-        self._factory = get_factory('asdqwe123456', config=self._config)
-
-        impression = Impression(matching_key='some_matching_key',
-                                feature_name='some_feature_name',
-                                treatment='on',
-                                label='label1',
-                                change_number=123456,
bucketing_key='some_bucketing_key', - time=123456) - - self.an_impressions_cache = RedisImpressionsCache(self._redis) - - self.an_impressions_cache._logger.debug = mock.MagicMock() - logger_debug = self.an_impressions_cache._logger.debug - - self.an_impressions_cache.add_impressions([impression]) - - logger_debug \ - .assert_called_once_with("SET EXPIRE KEY FOR QUEUE") - - # Assert that the TTL is within a 10-second range (between it was set and retrieved). - ttl = self._redis.ttl(IMPRESSIONS_QUEUE_KEY) - - self.assertLessEqual(int(ttl), IMPRESSION_KEY_DEFAULT_TTL) - self.assertGreaterEqual(int(ttl), IMPRESSION_KEY_DEFAULT_TTL - 10) - - impression_stored = decode(self._redis.rpop(IMPRESSIONS_QUEUE_KEY)) - self.assertEqual(impression_stored['m']['i'], '1.2.3.4') - self.assertEqual(impression_stored['m']['s'], SDK_VERSION) - self.assertEqual(impression_stored['m']['n'], 'CustomMachineName') - self.assertEqual(impression_stored['i']['k'], 'some_matching_key') - self.assertEqual(impression_stored['i']['f'], 'some_feature_name') - self.assertEqual(impression_stored['i']['t'], 'on') - self.assertEqual(impression_stored['i']['r'], 'label1') - self.assertEqual(impression_stored['i']['c'], 123456) - self.assertEqual(impression_stored['i']['b'], 'some_bucketing_key') - self.assertEqual(impression_stored['i']['m'], 123456) diff --git a/splitio/tests/test_input_validator.py b/splitio/tests/test_input_validator.py deleted file mode 100644 index 1ee2acf7..00000000 --- a/splitio/tests/test_input_validator.py +++ /dev/null @@ -1,776 +0,0 @@ -"""Unit tests for the input_validator module""" -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase -from splitio.brokers import RedisBroker, SelfRefreshingBroker, UWSGIBroker -from splitio.clients import Client -from splitio.treatments import CONTROL -from splitio.redis_support import get_redis -from splitio.splits import Split -from splitio import input_validator -from splitio.managers import RedisSplitManager, SelfRefreshingSplitManager, UWSGISplitManager -from splitio.key import Key -from splitio.uwsgi import UWSGICacheEmulator -from splitio import get_factory - - -class TestInputSanitizationGetTreatment(TestCase): - - def setUp(self): - self.some_config = mock.MagicMock() - self.some_api_key = mock.MagicMock() - self.redis = get_redis({'redisPrefix': 'test'}) - self.client = Client(RedisBroker(self.redis, self.some_config)) - self.client._broker.fetch_feature = mock.MagicMock(return_value=Split( - "some_feature", - 0, - False, - "default_treatment", - "user", - "ACTIVE", - 123 - )) - - input_validator._LOGGER.error = mock.MagicMock() - self.logger_error = input_validator._LOGGER.error - input_validator._LOGGER.warning = mock.MagicMock() - self.logger_warning = input_validator._LOGGER.warning - - def test_get_treatment_with_null_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - None, "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed a null key, key must be a" + - " non-empty string.") - - def test_get_treatment_with_empty_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - "", "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an empty key, key must be a" + - " non-empty string.") - - def test_get_treatment_with_length_key(self): - key = "" - for x in range(0, 255): - key = key + "a" - 
self.assertEqual(CONTROL, self.client.get_treatment(key, "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: key too long - must be 250 characters or " + - "less.") - - def test_get_treatment_with_number_key(self): - self.assertEqual("default_treatment", self.client.get_treatment( - 12345, "some_feature")) - self.logger_warning \ - .assert_called_once_with("get_treatment: key 12345 is not of type string, converting.") - - def test_get_treatment_with_nan_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - float("nan"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_get_treatment_with_inf_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - float("inf"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_get_treatment_with_bool_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - True, "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_get_treatment_with_array_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - [], "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_get_treatment_with_null_feature_name(self): - self.assertEqual(CONTROL, self.client.get_treatment( - "some_key", None)) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed a null feature_name, " + - "feature_name must be a non-empty string.") - - def test_get_treatment_with_numeric_feature_name(self): - self.assertEqual(CONTROL, self.client.get_treatment( - "some_key", 12345)) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid feature_name, " + - "feature_name must be a non-empty string.") - - def test_get_treatment_with_bool_feature_name(self): - self.assertEqual(CONTROL, self.client.get_treatment( - "some_key", True)) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid feature_name, " + - "feature_name must be a non-empty string.") - - def test_get_treatment_with_array_feature_name(self): - self.assertEqual(CONTROL, self.client.get_treatment( - "some_key", [])) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid feature_name, " + - "feature_name must be a non-empty string.") - - def test_get_treatment_with_empty_feature_name(self): - self.assertEqual(CONTROL, self.client.get_treatment( - "some_key", "")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an empty feature_name, " + - "feature_name must be a non-empty string.") - - def test_get_treatment_with_valid_inputs(self): - self.assertEqual("default_treatment", self.client.get_treatment( - "some_key", "some_feature")) - self.logger_error.assert_not_called() - self.logger_warning.assert_not_called() - - def test_get_treatment_with_null_matching_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key(None, "bucketing_key"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed a null matching_key, " + - "matching_key must be a non-empty string.") - - def test_get_treatment_with_empty_matching_key(self): - self.assertEqual(CONTROL, 
self.client.get_treatment( - Key("", "bucketing_key"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an empty matching_key, " + - "matching_key must be a non-empty string.") - - def test_get_treatment_with_nan_matching_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key(float("nan"), "bucketing_key"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid matching_key, " + - "matching_key must be a non-empty string.") - - def test_get_treatment_with_inf_matching_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key(float("inf"), "bucketing_key"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid matching_key, " + - "matching_key must be a non-empty string.") - - def test_get_treatment_with_bool_matching_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key(True, "bucketing_key"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid matching_key, " + - "matching_key must be a non-empty string.") - - def test_get_treatment_with_array_matching_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key([], "bucketing_key"), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid matching_key, " + - "matching_key must be a non-empty string.") - - def test_get_treatment_with_numeric_matching_key(self): - self.assertEqual("default_treatment", self.client.get_treatment( - Key(12345, "bucketing_key"), "some_feature")) - self.logger_warning \ - .assert_called_once_with("get_treatment: matching_key 12345 is not of type string, " - "converting.") - - def test_get_treatment_with_length_matching_key(self): - key = "" - for x in range(0, 255): - key = key + "a" - self.assertEqual(CONTROL, self.client.get_treatment(Key(key, "bucketing_key"), - "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: matching_key too long - must be 250 " + - "characters or less.") - - def test_get_treatment_with_null_bucketing_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key("matching_key", None), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed a null bucketing_key, " + - "bucketing_key must be a non-empty string.") - - def test_get_treatment_with_bool_bucketing_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key("matching_key", True), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid bucketing_key, " + - "bucketing_key must be a non-empty string.") - - def test_get_treatment_with_array_bucketing_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key("matching_key", []), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an invalid bucketing_key, " + - "bucketing_key must be a non-empty string.") - - def test_get_treatment_with_empty_bucketing_key(self): - self.assertEqual(CONTROL, self.client.get_treatment( - Key("matching_key", ""), "some_feature")) - self.logger_error \ - .assert_called_once_with("get_treatment: you passed an empty bucketing_key, " + - "bucketing_key must be a non-empty string.") - - def test_get_treatment_with_numeric_bucketing_key(self): - self.assertEqual("default_treatment", self.client.get_treatment( - Key("matching_key", 12345), "some_feature")) - 
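# The Key cases above and below exercise the two-part key: matching_key decides
# which user the evaluation is for, while bucketing_key only feeds the hashing
# that picks a treatment bucket. Usage sketch (positional arguments as used in
# this module; the class comes from splitio.key, values are illustrative):
#
#     key = Key("user-42", "session-7")  # matching_key, bucketing_key
#     treatment = client.get_treatment(key, "some_feature")
#
# As asserted in these tests, numeric keys are converted to strings with a
# warning, while None, NaN, inf, bool and list keys are rejected and CONTROL
# is returned.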
self.logger_warning \ - .assert_called_once_with("get_treatment: bucketing_key 12345 is not of type string, " - "converting.") - - def test_get_treatment_with_invalid_attributes(self): - self.assertEqual(CONTROL, self.client.get_treatment( - "some_key", "some_feature", True)) - self.logger_error \ - .assert_called_once_with("get_treatment: attributes must be of type dictionary.") - - def test_get_treatment_with_valid_attributes(self): - attributes = { - "test": "test" - } - self.assertEqual("default_treatment", self.client.get_treatment( - "some_key", "some_feature", attributes)) - - def test_get_treatment_with_none_attributes(self): - self.assertEqual("default_treatment", self.client.get_treatment( - "some_key", "some_feature", None)) - - def test_get_treatment_with_whitespaces(self): - self.assertEqual("default_treatment", self.client.get_treatment( - "some_key", " some_feature ")) - self.logger_warning \ - .assert_called_once_with("get_treatment: feature_name ' some_feature ' has extra" + - " whitespace, trimming.") - - def test_get_treatment_with_whitespaces_2(self): - self.assertEqual("default_treatment", self.client.get_treatment( - "some_key", "some_feature ")) - self.logger_warning \ - .assert_called_once_with("get_treatment: feature_name 'some_feature ' has extra" + - " whitespace, trimming.") - - -class TestInputSanitizationTrack(TestCase): - - def setUp(self): - self.some_config = mock.MagicMock() - self.some_api_key = mock.MagicMock() - self.redis = get_redis({'redisPrefix': 'test'}) - self.client = Client(RedisBroker(self.redis, self.some_config)) - - input_validator._LOGGER.error = mock.MagicMock() - self.logger_error = input_validator._LOGGER.error - input_validator._LOGGER.warning = mock.MagicMock() - self.logger_warning = input_validator._LOGGER.warning - - def test_track_with_null_key(self): - self.assertEqual(False, self.client.track( - None, "traffic_type", "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed a null key, key must be a" + - " non-empty string.") - - def test_track_with_empty_key(self): - self.assertEqual(False, self.client.track( - "", "traffic_type", "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an empty key, key must be a" + - " non-empty string.") - - def test_track_with_numeric_key(self): - self.assertEqual(True, self.client.track( - 12345, "traffic_type", "event_type", 1)) - self.logger_warning \ - .assert_called_once_with("track: key 12345 is not of type string," - " converting.") - - def test_track_with_bool_key(self): - self.assertEqual(False, self.client.track( - True, "traffic_type", "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_track_with_array_key(self): - self.assertEqual(False, self.client.track( - [], "traffic_type", "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_track_with_length_key(self): - key = "" - for x in range(0, 255): - key = key + "a" - self.assertEqual(False, self.client.track( - key, "traffic_type", "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: key too long - must be 250 characters or " + - "less.") - - def test_track_with_null_traffic_type(self): - self.assertEqual(False, self.client.track( - "some_key", None, "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed a null 
traffic_type, traffic_type" + - " must be a non-empty string.") - - def test_track_with_bool_traffic_type(self): - self.assertEqual(False, self.client.track( - "some_key", True, "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid traffic_type, traffic_type" + - " must be a non-empty string.") - - def test_track_with_array_traffic_type(self): - self.assertEqual(False, self.client.track( - "some_key", [], "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid traffic_type, traffic_type" + - " must be a non-empty string.") - - def test_track_with_numeric_traffic_type(self): - self.assertEqual(False, self.client.track( - "some_key", 12345, "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid traffic_type, traffic_type" + - " must be a non-empty string.") - - def test_track_with_empty_traffic_type(self): - self.assertEqual(False, self.client.track( - "some_key", "", "event_type", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an empty traffic_type, traffic_type" + - " must be a non-empty string.") - - def test_track_with_lowercase_traffic_type(self): - self.assertEqual(True, self.client.track( - "some_key", "TRAFFIC_type", "event_type", 1)) - self.logger_warning \ - .assert_called_once_with("track: TRAFFIC_type should be all lowercase -" + - " converting string to lowercase.") - - def test_track_with_null_event_type(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", None, 1)) - self.logger_error \ - .assert_called_once_with("track: you passed a null event_type, event_type" + - " must be a non-empty string.") - - def test_track_with_empty_event_type(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", "", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an empty event_type, event_type" + - " must be a non-empty string.") - - def test_track_with_bool_event_type(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", True, 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid event_type, event_type" + - " must be a non-empty string.") - - def test_track_with_array_event_type(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", [], 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid event_type, event_type" + - " must be a non-empty string.") - - def test_track_with_numeric_event_type(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", 12345, 1)) - self.logger_error \ - .assert_called_once_with("track: you passed an invalid event_type, event_type" + - " must be a non-empty string.") - - def test_track_with_event_type_does_not_conform_reg_exp(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", "@@", 1)) - self.logger_error \ - .assert_called_once_with("track: you passed @@, event_type must adhere to the regular " - "expression ^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$. 
This means " - "an event name must be alphanumeric, cannot be more than 80 " - "characters long, and can only include a dash, underscore, " - "period, or colon as separators of alphanumeric characters.") - - def test_track_with_null_value(self): - self.assertEqual(True, self.client.track( - "some_key", "traffic_type", "event_type", None)) - self.logger_error.assert_not_called() - self.logger_warning.assert_not_called() - - def test_track_with_string_value(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", "event_type", "test")) - self.logger_error \ - .assert_called_once_with("track: value must be a number.") - - def test_track_with_bool_value(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", "event_type", True)) - self.logger_error \ - .assert_called_once_with("track: value must be a number.") - - def test_track_with_array_value(self): - self.assertEqual(False, self.client.track( - "some_key", "traffic_type", "event_type", [])) - self.logger_error \ - .assert_called_once_with("track: value must be a number.") - - def test_track_with_int_value(self): - self.assertEqual(True, self.client.track( - "some_key", "traffic_type", "event_type", 1)) - self.logger_error.assert_not_called() - self.logger_warning.assert_not_called() - - def test_track_with_float_value(self): - self.assertEqual(True, self.client.track( - "some_key", "traffic_type", "event_type", 1.3)) - self.logger_error.assert_not_called() - self.logger_warning.assert_not_called() - - -class TestInputSanitizationRedisManager(TestCase): - - def setUp(self): - self.some_config = mock.MagicMock() - self.some_api_key = mock.MagicMock() - self.redis = get_redis({'redisPrefix': 'test'}) - self.client = Client(RedisBroker(self.redis, self.some_config)) - - self.manager = RedisSplitManager(self.client._broker) - - input_validator._LOGGER.error = mock.MagicMock() - self.logger_error = input_validator._LOGGER.error - - def test_manager_with_null_feature_name(self): - self.assertEqual(None, self.manager.split(None)) - self.logger_error \ - .assert_called_once_with("split: you passed a null feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_empty_feature_name(self): - self.assertEqual(None, self.manager.split("")) - self.logger_error \ - .assert_called_once_with("split: you passed an empty feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_bool_feature_name(self): - self.assertEqual(None, self.manager.split(True)) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_array_feature_name(self): - self.assertEqual(None, self.manager.split([])) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_numeric_feature_name(self): - self.assertEqual(None, self.manager.split(12345)) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_valid_feature_name(self): - self.assertEqual(None, self.manager.split("valid_feature_name")) - self.logger_error.assert_not_called() - - -class TestInputSanitizationSelfRefreshingManager(TestCase): - - def setUp(self): - self.some_api_key = mock.MagicMock() - self.broker = SelfRefreshingBroker(self.some_api_key) - self.client = Client(self.broker) - 
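# The three manager suites in this module (Redis, self-refreshing and uWSGI)
# assert byte-identical log lines for bad feature names, consistent with a
# single shared guard in input_validator rather than per-manager checks. A
# hedged sketch of such a guard (illustrative, not the module's actual code):
#
#     def valid_feature_name(feature_name):
#         if feature_name is None:
#             _LOGGER.error("split: you passed a null feature_name, feature_name"
#                           " must be a non-empty string.")
#             return False
#         if isinstance(feature_name, bool) or not isinstance(feature_name, str):
#             _LOGGER.error("split: you passed an invalid feature_name, feature_name"
#                           " must be a non-empty string.")
#             return False
#         if feature_name == "":
#             _LOGGER.error("split: you passed an empty feature_name, feature_name"
#                           " must be a non-empty string.")
#             return False
#         return True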
self.manager = SelfRefreshingSplitManager(self.broker) - - input_validator._LOGGER.error = mock.MagicMock() - self.logger_error = input_validator._LOGGER.error - - def test_manager_with_null_feature_name(self): - self.assertEqual(None, self.manager.split(None)) - self.logger_error \ - .assert_called_once_with("split: you passed a null feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_empty_feature_name(self): - self.assertEqual(None, self.manager.split("")) - self.logger_error \ - .assert_called_once_with("split: you passed an empty feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_bool_feature_name(self): - self.assertEqual(None, self.manager.split(True)) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_array_feature_name(self): - self.assertEqual(None, self.manager.split([])) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_numeric_feature_name(self): - self.assertEqual(None, self.manager.split(12345)) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_valid_feature_name(self): - self.assertEqual(None, self.manager.split("valid_feature_name")) - self.logger_error.assert_not_called() - - -class TestInputSanitizationUWSGIManager(TestCase): - - def setUp(self): - self.some_api_key = mock.MagicMock() - self.uwsgi = UWSGICacheEmulator() - self.broker = UWSGIBroker(self.uwsgi, {'eventsQueueSize': 30}) - self.client = Client(self.broker) - self.manager = UWSGISplitManager(self.broker) - - input_validator._LOGGER.error = mock.MagicMock() - self.logger_error = input_validator._LOGGER.error - - def test_manager_with_null_feature_name(self): - self.assertEqual(None, self.manager.split(None)) - self.logger_error \ - .assert_called_once_with("split: you passed a null feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_empty_feature_name(self): - self.assertEqual(None, self.manager.split("")) - self.logger_error \ - .assert_called_once_with("split: you passed an empty feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_bool_feature_name(self): - self.assertEqual(None, self.manager.split(True)) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_array_feature_name(self): - self.assertEqual(None, self.manager.split([])) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_numeric_feature_name(self): - self.assertEqual(None, self.manager.split(12345)) - self.logger_error \ - .assert_called_once_with("split: you passed an invalid feature_name, feature_name" + - " must be a non-empty string.") - - def test_manager_with_valid_feature_name(self): - self.assertEqual(None, self.manager.split("valid_feature_name")) - self.logger_error.assert_not_called() - - -class TestInputSanitizationGetTreatments(TestCase): - - def setUp(self): - self.some_config = mock.MagicMock() - self.some_api_key = mock.MagicMock() - self.redis = get_redis({'redisPrefix': 'test'}) - self.client = 
Client(RedisBroker(self.redis, self.some_config)) - - input_validator._LOGGER.error = mock.MagicMock() - self.logger_error = input_validator._LOGGER.error - input_validator._LOGGER.warning = mock.MagicMock() - self.logger_warning = input_validator._LOGGER.warning - - def test_get_treatments_with_null_key(self): - expected = { - "some_feature": "control" - } - self.assertEqual(expected, self.client.get_treatments( - None, ["some_feature"])) - self.logger_error \ - .assert_called_once_with("get_treatments: you passed a null key, key must be a" + - " non-empty string.") - - def test_get_treatments_with_empty_key(self): - expected = { - "some_feature": "control" - } - self.assertEqual(expected, self.client.get_treatments( - "", ["some_feature"])) - self.logger_error \ - .assert_called_once_with("get_treatments: you passed an empty key, key must be a" + - " non-empty string.") - - def test_get_treatments_with_length_key(self): - key = "" - for x in range(0, 255): - key = key + "a" - expected = { - "some_feature": "control" - } - self.assertEqual(expected, self.client.get_treatments(key, ["some_feature"])) - self.logger_error \ - .assert_called_once_with("get_treatments: key too long - must be 250 characters or " + - "less.") - - def test_get_treatments_with_number_key(self): - self.assertEqual({"some_feature": "control"}, self.client.get_treatments( - 12345, ["some_feature"])) - self.logger_warning \ - .assert_called_once_with("get_treatments: key 12345 is not of type string, converting.") - - def test_get_treatments_with_bool_key(self): - expected = { - "some_feature": "control" - } - self.assertEqual(expected, self.client.get_treatments( - True, ["some_feature"])) - self.logger_error \ - .assert_called_once_with("get_treatments: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_get_treatments_with_array_key(self): - expected = { - "some_feature": "control" - } - self.assertEqual(expected, self.client.get_treatments( - [], ["some_feature"])) - self.logger_error \ - .assert_called_once_with("get_treatments: you passed an invalid key, key must be a" + - " non-empty string.") - - def test_get_treatments_with_null_features(self): - self.assertEqual({}, self.client.get_treatments("some_key", None)) - self.logger_error \ - .assert_called_once_with("get_treatments: feature_names must be a non-empty array.") - - def test_get_treatments_with_bool_type_of_features(self): - self.assertEqual({}, self.client.get_treatments("some_key", True)) - self.logger_error \ - .assert_called_once_with("get_treatments: feature_names must be a non-empty array.") - - def test_get_treatments_with_string_type_of_features(self): - self.assertEqual({}, self.client.get_treatments("some_key", "some_string")) - self.logger_error \ - .assert_called_once_with("get_treatments: feature_names must be a non-empty array.") - - def test_get_treatments_with_empty_features(self): - self.assertEqual({}, self.client.get_treatments("some_key", [])) - self.logger_error \ - .assert_called_once_with("get_treatments: feature_names must be a non-empty array.") - - def test_get_treatments_with_none_features(self): - self.assertEqual({}, self.client.get_treatments("some_key", [None, None])) - self.logger_error \ - .assert_called_once_with("get_treatments: feature_names must be a non-empty array.") - - def test_get_treatments_with_invalid_type_of_features(self): - self.assertEqual({}, self.client.get_treatments("some_key", [True])) - self.logger_error \ - .assert_called_with("get_treatments: feature_names must be a 
non-empty array.") - - def test_get_treatments_with_empty_features_array(self): - self.assertEqual({}, self.client.get_treatments("some_key", ["", ""])) - self.logger_error \ - .assert_called_with("get_treatments: feature_names must be a non-empty array.") - - def test_get_treatments_with_whitespaces(self): - expected = { - "some": "control" - } - self.assertEqual(expected, self.client.get_treatments("some_key", [" some"])) - self.logger_warning \ - .assert_called_once_with("get_treatments: feature_name ' some' has extra whitespace," - + " trimming.") - - def test_get_treatments_with_whitespaces_2(self): - expected = { - "some": "control", - "another": "control" - } - self.assertEqual(expected, self.client.get_treatments("some_key", [" some ", "another"])) - self.logger_warning \ - .assert_called_once_with("get_treatments: feature_name ' some ' has extra whitespace," - + " trimming.") - - -class TestInputSanitizationFactory(TestCase): - - def setUp(self): - input_validator._LOGGER.error = mock.MagicMock() - self.logger_error = input_validator._LOGGER.error - - def test_factory_with_null_apikey(self): - self.assertEqual(None, get_factory(None)) - self.logger_error \ - .assert_called_once_with("factory_instantiation: you passed a null apikey, apikey" + - " must be a non-empty string.") - - def test_factory_with_empty_apikey(self): - self.assertEqual(None, get_factory('')) - self.logger_error \ - .assert_called_once_with("factory_instantiation: you passed an empty apikey, apikey" + - " must be a non-empty string.") - - def test_factory_with_invalid_apikey(self): - self.assertEqual(None, get_factory(True)) - self.logger_error \ - .assert_called_once_with("factory_instantiation: you passed an invalid apikey, apikey" + - " must be a non-empty string.") - - def test_factory_with_invalid_apikey_redis(self): - config = { - 'redisDb': 0, - 'redisHost': 'localhost' - } - self.assertNotEqual(None, get_factory(True, config=config)) - self.logger_error.assert_not_called() - - def test_factory_with_invalid_config(self): - config = { - 'some': 0 - } - self.assertEqual(None, get_factory("apikey", config=config)) - self.logger_error \ - .assert_called_once_with('no ready parameter has been set - incorrect control ' - + 'treatments could be logged') - - def test_factory_with_invalid_null_ready(self): - config = { - 'ready': None - } - self.assertEqual(None, get_factory("apikey", config=config)) - self.logger_error \ - .assert_called_once_with('no ready parameter has been set - incorrect control ' - + 'treatments could be logged') - - def test_factory_with_invalid_ready(self): - config = { - 'ready': True - } - self.assertEqual(None, get_factory("apikey", config=config)) - self.logger_error \ - .assert_called_once_with('no ready parameter has been set - incorrect control ' - + 'treatments could be logged') diff --git a/splitio/tests/test_matchers.py b/splitio/tests/test_matchers.py deleted file mode 100644 index d65db4ed..00000000 --- a/splitio/tests/test_matchers.py +++ /dev/null @@ -1,1125 +0,0 @@ -''' -Unit tests for the matchers module - -''' -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - - -import os.path -import json -from unittest import TestCase - -from splitio.matchers import AndCombiner, CombiningMatcher, AllKeysMatcher, \ - NegatableMatcher, AttributeMatcher, BetweenMatcher, DateTimeBetweenMatcher,\ - NumberBetweenMatcher, DataType, EqualToCompareMixin, \ - 
GreaterOrEqualToCompareMixin, LessThanOrEqualToCompareMixin, \
-    CompareMatcher, UserDefinedSegmentMatcher, WhitelistMatcher, \
-    DateEqualToMatcher, NumberEqualToMatcher, EqualToMatcher, \
-    DateTimeGreaterThanOrEqualToMatcher, NumberGreaterThanOrEqualToMatcher, \
-    GreaterThanOrEqualToMatcher, DateTimeLessThanOrEqualToMatcher, \
-    NumberLessThanOrEqualToMatcher, LessThanOrEqualToMatcher, \
-    StartsWithMatcher, EndsWithMatcher, ContainsStringMatcher, \
-    ContainsAllOfSetMatcher, ContainsAnyOfSetMatcher, PartOfSetMatcher, \
-    EqualToSetMatcher, DependencyMatcher, BooleanMatcher, RegexMatcher
-
-from splitio.transformers import AsDateHourMinuteTimestampTransformMixin, \
-    AsNumberTransformMixin, AsDateTimestampTransformMixin
-from splitio.tests.utils import MockUtilsMixin
-from splitio.splits import SplitParser
-
-
-class AndCombinerTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_attributes = mock.MagicMock()
-        self.some_client = mock.MagicMock()
-        self.combiner = AndCombiner()
-
-    def test_combine_returns_false_on_none_matchers(self):
-        '''
-        Tests that combine returns False if matchers is None
-        '''
-        self.assertFalse(
-            self.combiner.combine(None, self.some_key, self.some_attributes)
-        )
-
-    def test_combine_returns_false_on_empty_matchers(self):
-        '''
-        Tests that combine returns False if matchers is empty
-        '''
-        self.assertFalse(
-            self.combiner.combine([], self.some_key, self.some_attributes)
-        )
-
-    def test_combine_calls_match_on_all_matchers(self):
-        '''
-        Tests that combine calls match on all matchers
-        '''
-        matchers = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()]
-
-        # We set all return values of match to True to avoid short circuiting
-        for matcher in matchers:
-            matcher.match.return_value = True
-
-        self.combiner.combine(matchers, self.some_key, self.some_attributes, self.some_client)
-
-        for matcher in matchers:
-            matcher.match.assert_called_once_with(
-                self.some_key, self.some_attributes, self.some_client
-            )
-
-    def test_combine_short_circuits_check(self):
-        '''
-        Tests that combine stops checking after the first False
-        '''
-        matchers = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()]
-
-        # We set the second return value of match to False to short-circuit the
-        # check
-        matchers[0].match.return_value = True
-        matchers[1].match.return_value = False
-
-        self.combiner.combine(matchers, self.some_key, self.some_attributes, self.some_client)
-
-        matchers[0].match.assert_called_once_with(
-            self.some_key, self.some_attributes, self.some_client
-        )
-        matchers[1].match.assert_called_once_with(
-            self.some_key, self.some_attributes, self.some_client
-        )
-        matchers[2].match.assert_not_called()
-
-    def test_returns_result_of_calling_all(self):
-        '''
-        Tests that combine returns the result of calling all() on the match
-        results
-        '''
-        matchers = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()]
-        all_mock = self.patch_builtin('all')
-        self.assertEqual(
-            all_mock.return_value,
-            self.combiner.combine(matchers, self.some_key, self.some_attributes)
-        )
-
-
-class CombiningMatcherTests(TestCase):
-    def setUp(self):
-        self.some_combiner = mock.MagicMock()
-        self.some_delegates = [
-            mock.MagicMock(), mock.MagicMock(), mock.MagicMock()
-        ]
-        self.some_key = mock.MagicMock()
-        self.some_attributes = mock.MagicMock()
-
-        self.matcher = CombiningMatcher(self.some_combiner, self.some_delegates)
-
-    def test_match_calls_combiner_combine(self):
-        '''
-        Tests that match calls combine() on the combiner
-        '''
-        self.matcher.match(self.some_key, self.some_attributes)
-
-        self.assertEqual(1, self.some_combiner.combine.call_count)
-        args, _ = self.some_combiner.combine.call_args
-
-        self.assertListEqual(list(args[0]), list(self.some_delegates))
-        self.assertEqual(args[1], self.some_key)
-        self.assertEqual(args[2], self.some_attributes)
-
-    def test_match_returns_combiner_combine_result(self):
-        '''
-        Tests that match returns the result of the combiner combine() method
-        '''
-        self.assertEqual(
-            self.some_combiner.combine.return_value,
-            self.matcher.match(self.some_key, self.some_attributes)
-        )
-
-
-class AllKeysMatcherTests(TestCase):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-
-        self.matcher = AllKeysMatcher()
-
-    def test_match_returns_true_if_key_is_not_none(self):
-        '''
-        Tests that match returns True if the key is not None
-        '''
-        self.assertTrue(self.matcher.match(self.some_key))
-
-    def test_match_returns_false_if_key_is_none(self):
-        '''
-        Tests that match returns False if the key is None
-        '''
-        self.assertFalse(self.matcher.match(None))
-
-
-class NegatableMatcherTests(TestCase):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_delegate = mock.MagicMock()
-        self.some_client = mock.MagicMock()
-        self.some_attributes = mock.MagicMock()
-
-    def test_match_calls_delegate_match(self):
-        '''
-        Tests that match calls the delegate match method
-        '''
-        matcher = NegatableMatcher(True, self.some_delegate)
-
-        matcher.match(self.some_key, self.some_attributes, self.some_client)
-
-        self.some_delegate.match.assert_called_once_with(self.some_key, self.some_attributes,
-                                                         self.some_client)
-
-    def test_if_negate_true_match_negates_result_of_delegate_match(self):
-        '''
-        Tests that if negate is True, match negates the result of the delegate
-        match
-        '''
-        matcher = NegatableMatcher(True, self.some_delegate)
-
-        self.some_delegate.match.return_value = True
-        self.assertFalse(matcher.match(self.some_key))
-
-        self.some_delegate.match.return_value = False
-        self.assertTrue(matcher.match(self.some_key))
-
-    def test_if_negate_false_match_doesnt_negate_result_of_delegate_match(self):
-        '''
-        Tests that if negate is False, match doesn't negate the result of the
-        delegate match
-        '''
-        matcher = NegatableMatcher(False, self.some_delegate)
-
-        self.some_delegate.match.return_value = True
-        self.assertTrue(matcher.match(self.some_key))
-
-        self.some_delegate.match.return_value = False
-        self.assertFalse(matcher.match(self.some_key))
-
-
-class AttributeMatcherTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.negatable_matcher_mock = (
-            self.patch('splitio.matchers.NegatableMatcher').return_value
-        )
-        self.some_attribute = mock.MagicMock()
-        self.some_key = mock.MagicMock()
-        self.some_client = mock.MagicMock()
-        self.some_attribute_value = mock.MagicMock()
-        self.some_attributes = mock.MagicMock()
-        self.some_attributes.__contains__.return_value = True
-        self.some_attributes.__getitem__.return_value = self.some_attribute_value
-
-        self.some_matcher = mock.MagicMock()
-        self.some_negate = mock.MagicMock()
-
-        self.matcher = AttributeMatcher(
-            self.some_attribute, self.some_matcher, self.some_negate
-        )
-
-    def test_match_calls_negatable_matcher_match_with_key_if_attribute_is_none(self):
-        '''
-        Tests that match calls the negatable matcher match method with the
-        supplied key if attribute is None
-        '''
-        matcher = AttributeMatcher(None, self.some_matcher, self.some_negate)
-        matcher.match(self.some_key, self.some_attributes, self.some_client)
-
-        self.negatable_matcher_mock.match.assert_called_once_with(self.some_key,
-                                                                  self.some_attributes,
-                                                                  self.some_client)
-
-    def test_match_returns_false_attributes_is_none(self):
-        '''
-        Tests that match returns False if attributes is None
-        '''
-        self.assertFalse(self.matcher.match(self.some_key, None))
-
-    def test_match_returns_false_attribute_is_not_in_attributes(self):
-        '''
-        Tests that match returns False if the attribute is not in the attributes
-        dictionary
-        '''
-        self.some_attributes.__contains__.return_value = None
-        self.assertFalse(
-            self.matcher.match(self.some_key, self.some_attributes)
-        )
-
-    def test_match_returns_false_attribute_value_is_none(self):
-        '''
-        Tests that match returns False if the value of the attribute is None
-        '''
-        self.some_attributes.__getitem__.return_value = None
-        self.assertFalse(
-            self.matcher.match(self.some_key, self.some_attributes)
-        )
-
-    def test_match_calls_negatable_matcher_match_with_attribute_value(self):
-        '''
-        Tests that match calls match on the negatable matcher with the attribute
-        value as key
-        '''
-        self.matcher.match(self.some_key, self.some_attributes)
-        self.negatable_matcher_mock.match.assert_called_once_with(
-            self.some_attribute_value
-        )
-
-    def test_match_returns_result_negatable_matcher_match(self):
-        '''
-        Tests that match returns the result of invoking match on the negatable
-        matcher
-        '''
-        self.assertEqual(
-            self.negatable_matcher_mock.match.return_value,
-            self.matcher.match(self.some_key, self.some_attributes)
-        )
-
-
-class BetweenMatcherForDataTypeTests(TestCase):
-    def test_for_data_type_returns_date_time_between_matcher_for_datetime(self):
-        '''
-        Tests that for_data_type returns a DateTimeBetweenMatcher matcher with
-        the DataType.DATETIME data type
-        '''
-        matcher = BetweenMatcher.for_data_type(
-            DataType.DATETIME, 1461601825000, 1461609025000
-        )
-        self.assertIsInstance(matcher, DateTimeBetweenMatcher)
-
-    def test_for_data_type_returns_number_between_matcher_for_number(self):
-        '''
-        Tests that for_data_type returns a NumberBetweenMatcher matcher with the
-        DataType.NUMBER data type
-        '''
-        matcher = BetweenMatcher.for_data_type(DataType.NUMBER, 1, 100)
-        self.assertIsInstance(matcher, NumberBetweenMatcher)
-
-
-class BetweenMatcherTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_start = mock.MagicMock()
-        self.some_start.__le__.return_value = True
-        self.some_end = mock.MagicMock()
-        self.transformed_key = mock.MagicMock()
-        self.transformed_key.__le__.return_value = True
-        self.some_data_type = mock.MagicMock()
-
-        self.transform_mock = self.patch(
-            'splitio.matchers.BetweenMatcher.transform_key',
-            return_value=self.transformed_key
-        )
-
-        self.matcher = BetweenMatcher(
-            self.some_start, self.some_end, self.some_data_type
-        )
-
-    def test_match_calls_transform_on_key(self):
-        '''
-        Tests that match calls transform on key
-        '''
-        self.matcher.match(self.some_key)
-
-        self.transform_mock.assert_called_once_with(self.some_key)
-
-    def test_match_returns_none_if_transform_returns_none(self):
-        '''
-        Tests that match returns None if transform returns None
-        '''
-        self.transform_mock.side_effect = None
-        self.transform_mock.return_value = None
-
-        self.assertIsNone(self.matcher.match(self.some_key))
-
-    def test_match_checks_transformed_key_between_start_and_end(self):
-        '''
-        Tests that match checks that the transformed key is between the start
-        and end
-        '''
-        self.matcher.match(self.some_key)
-
-        self.some_start.__le__.assert_called_once_with(self.transformed_key)
-        self.transformed_key.__le__.assert_called_once_with(self.some_end)
-
-    def test_match_returns_true_if_key_between_start_and_end(self):
-        '''
-        Tests that match returns True if key is between the start and end
-        '''
-        self.assertTrue(self.matcher.match(self.some_key))
-
-    def test_match_returns_false_if_key_less_than_start(self):
-        '''
-        Tests that match returns False if key is less than start
-        '''
-        self.some_start.__le__.return_value = False
-        self.assertFalse(self.matcher.match(self.some_key))
-
-    def test_match_returns_false_if_key_greater_than_end(self):
-        '''
-        Tests that match returns False if key is greater than end
-        '''
-        self.transformed_key.__le__.return_value = False
-        self.assertFalse(self.matcher.match(self.some_key))
-
-
-class EqualToCompareMixinTests(TestCase):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_key.__eq__.return_value = True
-        self.some_value = mock.MagicMock()
-        self.mixin = EqualToCompareMixin()
-
-    def test_compare_checks_if_key_is_equal_to_value(self):
-        '''
-        Tests that compare checks if the key and the value are equal
-        '''
-        self.mixin.compare(self.some_key, self.some_value)
-        self.some_key.__eq__.assert_called_once_with(self.some_value)
-
-    def test_compare_returns_true_if_key_and_value_are_equal(self):
-        '''
-        Tests that compare returns True if the key and the value are equal
-        '''
-        self.assertTrue(self.mixin.compare(self.some_key, self.some_value))
-
-    def test_compare_returns_false_if_key_and_value_are_not_equal(self):
-        '''
-        Tests that compare returns False if the key and the value are not equal
-        '''
-        self.some_key.__eq__.return_value = False
-        self.assertFalse(self.mixin.compare(self.some_key, self.some_value))
-
-
-class GreaterOrEqualToCompareMixinTests(TestCase):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_key.__ge__.return_value = True
-        self.some_value = mock.MagicMock()
-        self.mixin = GreaterOrEqualToCompareMixin()
-
-    def test_compare_checks_if_key_is_greater_than_or_equal_to_value(self):
-        '''
-        Tests that compare checks if the key is greater than or equal to value
-        '''
-        self.mixin.compare(self.some_key, self.some_value)
-        self.some_key.__ge__.assert_called_once_with(self.some_value)
-
-    def test_compare_returns_true_if_key_greater_than_or_equal_to_value(self):
-        '''
-        Tests that compare returns True if the key is greater than or equal to
-        value
-        '''
-        self.assertTrue(self.mixin.compare(self.some_key, self.some_value))
-
-    def test_compare_returns_false_if_key_not_greater_than_or_equal_to_value(self):
-        '''
-        Tests that compare returns False if the key is not greater than or equal
-        to value
-        '''
-        self.some_key.__ge__.return_value = False
-        self.assertFalse(self.mixin.compare(self.some_key, self.some_value))
-
-
-class LessThanOrEqualToCompareMixinTests(TestCase):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_key.__le__.return_value = True
-        self.some_value = mock.MagicMock()
-        self.mixin = LessThanOrEqualToCompareMixin()
-
-    def test_compare_checks_if_key_is_less_than_or_equal_to_value(self):
-        '''
-        Tests that compare checks if the key is less than or equal to value
-        '''
-        self.mixin.compare(self.some_key, self.some_value)
-        self.some_key.__le__.assert_called_once_with(self.some_value)
-
-    def test_compare_returns_true_if_key_less_than_or_equal_to_value(self):
-        '''
-        Tests that compare returns True if the key is less than or equal to
-        value
-        '''
-        self.assertTrue(self.mixin.compare(self.some_key, self.some_value))
-
-    def test_compare_returns_false_if_key_not_less_than_or_equal_to_value(self):
-        '''
-        Tests that compare returns False if the key is not less than or equal to
-        value
-        '''
-        self.some_key.__le__.return_value = False
-        self.assertFalse(self.mixin.compare(self.some_key, self.some_value))
-
-
-class CompareMatcherTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_compare_to = mock.MagicMock()
-        self.transformed_key = mock.MagicMock()
-        self.some_data_type = mock.MagicMock()
-
-        self.transform_mock = self.patch(
-            'splitio.matchers.CompareMatcher.transform_key',
-            return_value=self.transformed_key
-        )
-        self.matcher = CompareMatcher(self.some_compare_to, self.some_data_type)
-        self.compare_mock = self.patch_object(self.matcher, 'compare')
-
-    def test_match_calls_transform_on_the_key(self):
-        '''
-        Tests that match calls transform on the supplied key
-        '''
-        self.matcher.match(self.some_key)
-
-        self.transform_mock.assert_called_once_with(self.some_key)
-
-    def test_match_returns_none_if_transformed_key_is_none(self):
-        '''
-        Tests that match returns None if the transformed key is None
-        '''
-        self.transform_mock.side_effect = None
-        self.transform_mock.return_value = None
-
-        self.assertIsNone(self.matcher.match(self.some_key))
-
-    def test_match_calls_compare_on_transformed_key_and_compare_to(self):
-        '''
-        Tests that match calls compare with the transformed key and the
-        compare_to value
-        '''
-        self.matcher.match(self.some_key)
-
-        self.compare_mock.assert_called_once_with(self.transformed_key,
-                                                  self.some_compare_to)
-
-    def test_match_returns_compare_result(self):
-        '''
-        Tests that match returns the result of running compare with the
-        transformed key and the compare_to value
-        '''
-        self.assertEqual(
-            self.compare_mock.return_value, self.matcher.match(self.some_key)
-        )
-
-
-class UserDefinedSegmentMatcherTests(TestCase):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-        self.some_segment = mock.MagicMock()
-        self.matcher = UserDefinedSegmentMatcher(self.some_segment)
-
-    def test_match_calls_contain_in_segment(self):
-        '''
-        Tests that match calls contains on the associated segment
-        '''
-        self.matcher.match(self.some_key)
-
-        self.some_segment.contains.assert_called_once_with(self.some_key)
-
-    def test_match_returns_result_of_contains(self):
-        '''
-        Tests that match returns the result of calling contains on the
-        associated segment
-        '''
-        self.assertEqual(self.some_segment.contains.return_value,
-                         self.matcher.match(self.some_key))
-
-
-class WhitelistMatcherTests(TestCase, MockUtilsMixin):
-    def setUp(self):
-        self.some_key = mock.MagicMock()
-
-        self.some_whitelist = mock.MagicMock()
-        self.whitelist_frozenset = mock.MagicMock()
-
-        self.frozenset_mock = mock.MagicMock()
-        self.frozenset_mock.return_value = self.whitelist_frozenset
-
-        with self.patch_builtin('frozenset', self.frozenset_mock):
-            self.matcher = WhitelistMatcher(self.some_whitelist)
-
-    def test_match_calls_in_on_whitelist(self):
-        '''
-        Tests that match checks if the key is in the (frozen) whitelist
-        '''
-        self.matcher.match(self.some_key)
-
-        self.whitelist_frozenset.__contains__.assert_called_once_with(
-            self.some_key
-        )
-
-    def test_match_returns_result_of_checking_in(self):
-        '''
-        Tests that match returns the result of checking if the key is in the
-        (frozen) whitelist
-        '''
-        self.assertEqual(self.whitelist_frozenset.__contains__.return_value,
-                         self.matcher.match(self.some_key))
-
-
-class DateTimeBetweenMatcherTests(TestCase):
-    def setUp(self):
-        self.some_start = mock.MagicMock()
-        self.some_end = mock.MagicMock()
-        self.matcher = DateTimeBetweenMatcher(self.some_start, self.some_end)
-
-    def test_matcher_is_between_matcher(self):
-        '''
-        Tests that DateTimeBetweenMatcher is a BetweenMatcher
-        '''
-        self.assertIsInstance(self.matcher, BetweenMatcher)
-
-    def test_matcher_has_proper_transform_mixin(self):
-        '''
-        Tests that DateTimeBetweenMatcher is an
-        AsDateHourMinuteTimestampTransformMixin
-        '''
-        self.assertIsInstance(
-            self.matcher, AsDateHourMinuteTimestampTransformMixin
-        )
-
-
-class NumberBetweenMatcherTests(TestCase):
-    def setUp(self):
-        self.some_start = mock.MagicMock()
-        self.some_end = mock.MagicMock()
-        self.matcher = NumberBetweenMatcher(self.some_start, self.some_end)
-
-    def test_matcher_is_between_matcher(self):
-        '''
-        Tests that NumberBetweenMatcher is a BetweenMatcher
-        '''
-        self.assertIsInstance(self.matcher, BetweenMatcher)
-
-    def test_matcher_has_proper_transform_mixin(self):
-        '''
-        Tests that NumberBetweenMatcher is an AsNumberTransformMixin
-        '''
-        self.assertIsInstance(self.matcher, AsNumberTransformMixin)
-
-
-class DateEqualToMatcherTests(TestCase):
-    def setUp(self):
-        self.some_compare_to = mock.MagicMock()
-        self.matcher = DateEqualToMatcher(self.some_compare_to)
-
-    def test_matcher_is_equal_to_matcher(self):
-        '''
-        Tests that DateEqualToMatcher is an EqualToMatcher
-        '''
-        self.assertIsInstance(self.matcher, EqualToMatcher)
-
-    def test_matcher_has_proper_transform_mixin(self):
-        '''
-        Tests that DateEqualToMatcher is an AsDateTimestampTransformMixin
-        '''
-        self.assertIsInstance(self.matcher, AsDateTimestampTransformMixin)
-
-
-class NumberEqualToMatcherTests(TestCase):
-    def setUp(self):
-        self.some_compare_to = mock.MagicMock()
-        self.matcher = NumberEqualToMatcher(self.some_compare_to)
-
-    def test_matcher_is_equal_to_matcher(self):
-        '''
-        Tests that NumberEqualToMatcher is an EqualToMatcher
-        '''
-        self.assertIsInstance(self.matcher, EqualToMatcher)
-
-    def test_matcher_has_proper_transform_mixin(self):
-        '''
-        Tests that NumberEqualToMatcher is an AsNumberTransformMixin
-        '''
-        self.assertIsInstance(self.matcher, AsNumberTransformMixin)
-
-
-class DateTimeGreaterThanOrEqualToMatcherTests(TestCase):
-    def setUp(self):
-        self.some_compare_to = mock.MagicMock()
-        self.matcher = DateTimeGreaterThanOrEqualToMatcher(self.some_compare_to)
-
-    def test_matcher_is_greater_than_or_equal_to_matcher(self):
-        '''
-        Tests that DateTimeGreaterThanOrEqualToMatcher is a
-        GreaterThanOrEqualToMatcher
-        '''
-        self.assertIsInstance(self.matcher, GreaterThanOrEqualToMatcher)
-
-    def test_matcher_has_proper_transform_mixin(self):
-        '''
-        Tests that DateTimeGreaterThanOrEqualToMatcher is an
-        AsDateHourMinuteTimestampTransformMixin
-        '''
-        self.assertIsInstance(
-            self.matcher, AsDateHourMinuteTimestampTransformMixin
-        )
-
-
-class NumberGreaterThanOrEqualToMatcherTests(TestCase):
-    def setUp(self):
-        self.some_compare_to = mock.MagicMock()
-        self.matcher = NumberGreaterThanOrEqualToMatcher(self.some_compare_to)
-
-    def test_matcher_is_greater_than_or_equal_to_matcher(self):
-        '''
-        Tests that NumberGreaterThanOrEqualToMatcher is a
-        GreaterThanOrEqualToMatcher
-        '''
-        self.assertIsInstance(self.matcher, GreaterThanOrEqualToMatcher)
-
-    def test_matcher_has_proper_transform_mixin(self):
-        '''
-        Tests that NumberGreaterThanOrEqualToMatcher is an AsNumberTransformMixin
-        '''
-        self.assertIsInstance(self.matcher, AsNumberTransformMixin)
-
-
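# The isinstance assertions in the classes above and below pin down a
# composition pattern: each concrete matcher combines a comparison base class
# with a key-transform mixin, so the comparison logic and the key coercion can
# vary independently. Only the inheritance (not the class bodies) is confirmed
# by these tests; sketched:
#
#     class NumberLessThanOrEqualToMatcher(LessThanOrEqualToMatcher,
#                                          AsNumberTransformMixin):
#         ...
#
#     class DateTimeLessThanOrEqualToMatcher(LessThanOrEqualToMatcher,
#                                            AsDateHourMinuteTimestampTransformMixin):
#         ...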
- def setUp(self): - self.some_compare_to = mock.MagicMock() - self.matcher = DateTimeLessThanOrEqualToMatcher(self.some_compare_to) - - def test_matcher_is_less_than_or_equal_to_matcher(self): - ''' - Tests that DateTimeLessThanOrEqualToMatcher is a - LessThanOrEqualToMatcher - ''' - self.assertIsInstance(self.matcher, LessThanOrEqualToMatcher) - - def test_matcher_has_proper_transform_mixin(self): - ''' - Tests that DateTimeLessThanOrEqualToMatcher is an - AsDateHourMinuteTimestampTransformMixin - ''' - self.assertIsInstance( - self.matcher, AsDateHourMinuteTimestampTransformMixin - ) - - -class NumberLessThanOrEqualToMatcherTests(TestCase): - def setUp(self): - self.some_compare_to = mock.MagicMock() - self.matcher = NumberLessThanOrEqualToMatcher(self.some_compare_to) - - def test_matcher_is_less_than_or_equal_to_matcher(self): - ''' - Tests that NumberLessThanOrEqualToMatcher is a LessThanOrEqualToMatcher - ''' - self.assertIsInstance(self.matcher, LessThanOrEqualToMatcher) - - def test_matcher_has_proper_transform_mixin(self): - ''' - Tests that NumberLessThanOrEqualToMatcher is an AsNumberTransformMixin - ''' - self.assertIsInstance(self.matcher, AsNumberTransformMixin) - - -class StartsWithMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'STARTS_WITH', - 'whitelistMatcherData': {'whitelist': ['ABC', 'DEF', 'GHI']} - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, StartsWithMatcher) - - def test_keys_with_prefix_match(self): - ''' - Test that keys starting with one of the prefixes in the condition match - ''' - self.assertTrue(self._matcher.match('ABCtest')) - self.assertTrue(self._matcher.match('DEFtest')) - self.assertTrue(self._matcher.match('GHItest')) - - def test_keys_without_prefix_dont_match(self): - ''' - Test that keys that don't start with one of the prefixes don't match. - ''' - self.assertFalse(self._matcher.match('JKLtest')) - self.assertFalse(self._matcher.match('123test')) - self.assertFalse(self._matcher.match('dl_test')) - - def test_empty_string_doesnt_match(self): - ''' - Tests that the empty string doesn't match. - ''' - self.assertFalse(self._matcher.match('')) - - def test_none_doesnt_match(self): - ''' - Tests that None doesn't match. - ''' - self.assertFalse(self._matcher.match(None)) - - -class EndsWithMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'ENDS_WITH', - 'whitelistMatcherData': {'whitelist': ['ABC', 'DEF', 'GHI']} - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, EndsWithMatcher) - - def test_keys_with_suffix(self): - ''' - Test that keys ending with one of the suffixes in the condition match - ''' - self.assertTrue(self._matcher.match('testABC')) - self.assertTrue(self._matcher.match('testDEF')) - self.assertTrue(self._matcher.match('testGHI')) - - def test_keys_without_suffix_dont_match(self): - ''' - Test that keys that don't end with one of the suffixes don't match.
- ''' - self.assertFalse(self._matcher.match('testJKL')) - self.assertFalse(self._matcher.match('test123')) - self.assertFalse(self._matcher.match('testdl_')) - - def test_empty_string_doesnt_match(self): - ''' - Tests that the empty string doesn't match. - ''' - self.assertFalse(self._matcher.match('')) - - def test_none_doesnt_match(self): - ''' - Tests that None doesn't match. - ''' - self.assertFalse(self._matcher.match(None)) - - -class ContainsStringMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'CONTAINS_STRING', - 'whitelistMatcherData': {'whitelist': ['ABC', 'DEF', 'GHI']} - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, ContainsStringMatcher) - - def test_keys_with_string(self): - ''' - Test that keys containing one of the strings in the condition match - ''' - self.assertTrue(self._matcher.match('testABC')) - self.assertTrue(self._matcher.match('testDEFabc')) - self.assertTrue(self._matcher.match('GHI3214')) - - def test_keys_without_string_dont_match(self): - ''' - Test that keys that don't contain any of the strings don't match. - ''' - self.assertFalse(self._matcher.match('testJKL')) - self.assertFalse(self._matcher.match('test123')) - self.assertFalse(self._matcher.match('testdl_')) - - def test_empty_string_doesnt_match(self): - ''' - Tests that the empty string doesn't match. - ''' - self.assertFalse(self._matcher.match('')) - - def test_none_doesnt_match(self): - ''' - Tests that None doesn't match. - ''' - self.assertFalse(self._matcher.match(None)) - - -class ContainsAllOfSetMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'CONTAINS_ALL_OF_SET', - 'whitelistMatcherData': {'whitelist': ['ABC', 'DEF', 'GHI']} - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, ContainsAllOfSetMatcher) - - def test_set_with_all_keys(self): - ''' - Test that a set with all the keys matches - ''' - self.assertTrue(self._matcher.match(['ABC', 'DEF', 'GHI'])) - self.assertTrue(self._matcher.match(['ABC', 'DEF', 'GHI', 'AWE'])) - - def test_set_without_all_keys_doesnt_match(self): - ''' - Test that a set without all the keys doesn't match - ''' - self.assertFalse(self._matcher.match(['ABC', 'DEF'])) - self.assertFalse(self._matcher.match(['GHI'])) - - def test_empty_set_doesnt_match(self): - ''' - Tests that an empty set doesn't match - ''' - self.assertFalse(self._matcher.match([])) - - def test_none_doesnt_match(self): - ''' - Tests that None doesn't match.
- ''' - self.assertFalse(self._matcher.match(None)) - - -class ContainsAnyOfSetMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'CONTAINS_ANY_OF_SET', - 'whitelistMatcherData': {'whitelist': ['ABC', 'DEF', 'GHI']} - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, ContainsAnyOfSetMatcher) - - def test_set_with_at_least_one_key(self): - ''' - Test that a set with at least one key matches - ''' - self.assertTrue(self._matcher.match(['ABC', 'DEF', 'GHI'])) - self.assertTrue(self._matcher.match(['ABC', 'DEF', 'GHI', 'AWE'])) - self.assertTrue(self._matcher.match(['ABC', 'DEF'])) - - def test_set_without_any_key_doesnt_match(self): - ''' - Test that a set without any of the keys doesn't match - ''' - self.assertFalse(self._matcher.match(['AWE'])) - - def test_empty_set_doesnt_match(self): - ''' - Tests that an empty set doesn't match - ''' - self.assertFalse(self._matcher.match([])) - - def test_none_doesnt_match(self): - ''' - Tests that None doesn't match. - ''' - self.assertFalse(self._matcher.match(None)) - - -class EqualToSetMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'EQUAL_TO_SET', - 'whitelistMatcherData': {'whitelist': ['ABC', 'DEF', 'GHI']} - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, EqualToSetMatcher) - - def test_equal_set_matches(self): - ''' - Test that the exact same set matches - ''' - self.assertTrue(self._matcher.match(['ABC', 'DEF', 'GHI'])) - - def test_different_set_doesnt_match(self): - ''' - Test that a different set doesn't match - ''' - self.assertFalse(self._matcher.match(['ABC', 'DEF', 'GHI', 'AWE'])) - self.assertFalse(self._matcher.match(['ABC'])) - - def test_empty_set_doesnt_match(self): - ''' - Tests that an empty set doesn't match - ''' - self.assertFalse(self._matcher.match([])) - - def test_none_doesnt_match(self): - ''' - Tests that None doesn't match. - ''' - self.assertFalse(self._matcher.match(None)) - - -class PartOfSetMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'PART_OF_SET', - 'whitelistMatcherData': {'whitelist': ['ABC', 'DEF', 'GHI']} - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed.
- ''' - self.assertIsInstance(self._matcher, PartOfSetMatcher) - - def test_subset_of_set_matches(self): - ''' - Test that a subset of the set matches - ''' - self.assertTrue(self._matcher.match(['ABC', 'DEF', 'GHI'])) - self.assertTrue(self._matcher.match(['ABC'])) - - def test_not_subset_of_set_doesnt_match(self): - ''' - Test that any set with elements that are not in the split's set doesn't - match - ''' - self.assertFalse(self._matcher.match(['ABC', 'DEF', 'GHI', 'AWE'])) - self.assertFalse(self._matcher.match(['RFV'])) - - def test_empty_set_doesnt_match(self): - ''' - Tests that an empty set doesn't match - ''' - self.assertFalse(self._matcher.match([])) - - def test_none_doesnt_match(self): - ''' - Tests that None doesn't match. - ''' - self.assertFalse(self._matcher.match(None)) - - -class DependencyMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'IN_SPLIT_TREATMENT', - 'dependencyMatcherData': { - 'split': 'someSplit', - 'treatments': ['on'] - } - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - self._mock = self.patch('splitio.evaluator.Evaluator') - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, DependencyMatcher) - - def test_matcher_client_is_created_and_evaluate_treatment_called(self): - self._matcher.match('abc', None, self._mock) - self._mock.evaluate_treatment.assert_called_once_with('someSplit', 'abc', None, None) - self.assertTrue(True) - - -class RegexMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'MATCHES_STRING', - 'stringMatcherData': '[a-z]' - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed. - ''' - self.assertIsInstance(self._matcher, RegexMatcher) - - def test_regexes(self): - ''' - Test different regexes loaded from regex.txt - ''' - current_path = os.path.dirname(__file__) - with open(os.path.join(current_path, 'regex.txt')) as flo: - lines = [line for line in flo] - lines.pop() # Remove empty last line - for line in lines: - regex, text, res = line.split('#') - matcher = RegexMatcher(regex) - print(regex, text, res) - self.assertEqual(matcher.match(text), json.loads(res)) - - -class BooleanMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self._split_parser = SplitParser(object()) - matcher = { - 'matcherType': 'EQUAL_TO_BOOLEAN', - 'booleanMatcherData': True - } - split = {'conditions': [{'matcher': matcher}]} - self._matcher = (self._split_parser._parse_matcher(split, matcher) - ._matcher.delegate) - - def test_matcher_construction(self): - ''' - Tests that the correct matcher is constructed.
''' - self.assertIsInstance(self._matcher, BooleanMatcher) - - def test_different_keys(self): - ''' - Test how different types get parsed - ''' - self.assertTrue(self._matcher.match(True)) - self.assertTrue(self._matcher.match('tRue')) - self.assertFalse(self._matcher.match(False)) - self.assertFalse(self._matcher.match('False')) - self.assertFalse(self._matcher.match('')) - self.assertFalse(self._matcher.match({})) diff --git a/splitio/tests/test_metrics.py b/splitio/tests/test_metrics.py deleted file mode 100644 index 4766ed26..00000000 --- a/splitio/tests/test_metrics.py +++ /dev/null @@ -1,721 +0,0 @@ -"""Unit tests for the metrics module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase - -from splitio.metrics import (LatencyTracker, InMemoryMetrics, build_metrics_counter_data, - build_metrics_times_data, build_metrics_gauge_data, ApiMetrics, - AsyncMetrics, get_latency_bucket_index) -from splitio.tests.utils import MockUtilsMixin - - -class LatencyTrackerFindIndexTests(TestCase): - def setUp(self): - self.latency_tracker = LatencyTracker() - - def test_works_with_0(self): - """Test that get_latency_bucket_index works with 0""" - self.assertEqual(0, get_latency_bucket_index(0)) - - def test_works_with_max(self): - """Test that get_latency_bucket_index works with the max latency""" - self.assertEqual(22, get_latency_bucket_index(7481828)) - - def test_works_with_values_over_max(self): - """Test that get_latency_bucket_index works with values over the max latency""" - self.assertEqual(22, get_latency_bucket_index(7481829)) - self.assertEqual(22, get_latency_bucket_index(8481829)) - - def test_works_with_values_between_0_and_max(self): - """Test that get_latency_bucket_index works with values between 0 and max""" - self.assertEqual(0, get_latency_bucket_index(500)) - self.assertEqual(0, get_latency_bucket_index(1000)) - self.assertEqual(1, get_latency_bucket_index(1250)) - self.assertEqual(1, get_latency_bucket_index(1500)) - self.assertEqual(2, get_latency_bucket_index(2000)) - self.assertEqual(2, get_latency_bucket_index(2250)) - self.assertEqual(3, get_latency_bucket_index(3000)) - self.assertEqual(3, get_latency_bucket_index(3375)) - self.assertEqual(4, get_latency_bucket_index(4000)) - self.assertEqual(4, get_latency_bucket_index(5063)) - self.assertEqual(5, get_latency_bucket_index(6000)) - self.assertEqual(5, get_latency_bucket_index(7594)) - self.assertEqual(6, get_latency_bucket_index(10000)) - self.assertEqual(6, get_latency_bucket_index(11391)) - self.assertEqual(7, get_latency_bucket_index(15000)) - self.assertEqual(7, get_latency_bucket_index(17086)) - self.assertEqual(8, get_latency_bucket_index(20000)) - self.assertEqual(8, get_latency_bucket_index(25629)) - self.assertEqual(9, get_latency_bucket_index(30000)) - self.assertEqual(9, get_latency_bucket_index(38443)) - self.assertEqual(10, get_latency_bucket_index(50000)) - self.assertEqual(10, get_latency_bucket_index(57665)) - self.assertEqual(11, get_latency_bucket_index(80000)) - self.assertEqual(11, get_latency_bucket_index(86498)) - self.assertEqual(12, get_latency_bucket_index(100000)) - self.assertEqual(12, get_latency_bucket_index(129746)) - self.assertEqual(13, get_latency_bucket_index(150000)) - self.assertEqual(13, get_latency_bucket_index(194620)) - self.assertEqual(14, get_latency_bucket_index(200000)) - self.assertEqual(14, get_latency_bucket_index(291929)) - self.assertEqual(15, get_latency_bucket_index(300000)) -
self.assertEqual(15, get_latency_bucket_index(437894)) - self.assertEqual(16, get_latency_bucket_index(500000)) - self.assertEqual(16, get_latency_bucket_index(656841)) - self.assertEqual(17, get_latency_bucket_index(800000)) - self.assertEqual(17, get_latency_bucket_index(985261)) - self.assertEqual(18, get_latency_bucket_index(1000000)) - self.assertEqual(18, get_latency_bucket_index(1477892)) - self.assertEqual(19, get_latency_bucket_index(2000000)) - self.assertEqual(19, get_latency_bucket_index(2216838)) - self.assertEqual(20, get_latency_bucket_index(2500000)) - self.assertEqual(20, get_latency_bucket_index(3325257)) - self.assertEqual(21, get_latency_bucket_index(4000000)) - self.assertEqual(21, get_latency_bucket_index(4987885)) - self.assertEqual(22, get_latency_bucket_index(6000000)) - - -class LatencyTrackerTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_millis = 1000 - self.some_micros = 1000000 - self.some_latencies = list(range(23)) - self.latency_tracker = LatencyTracker() - self.get_latency_bucket_index_mock = self.patch( - 'splitio.metrics.get_latency_bucket_index', return_value=5) - - def test_add_latency_millis_calls_get_latency_bucket_index(self): - """Test that add_latency_millis calls get_latency_bucket_index""" - self.latency_tracker.add_latency_millis(self.some_millis) - self.get_latency_bucket_index_mock.assert_called_once_with(self.some_millis * 1000) - - def test_add_latency_millis_sets_right_element(self): - """Test that add_latency_millis adds 1 to the right element""" - self.latency_tracker.add_latency_millis(self.some_millis) - self.assertListEqual([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - self.latency_tracker._latencies) - - def test_add_latency_micros_calls_get_latency_bucket_index(self): - """Test that add_latency_micros calls get_latency_bucket_index""" - self.latency_tracker.add_latency_micros(self.some_micros) - self.get_latency_bucket_index_mock.assert_called_once_with(self.some_micros) - - def test_add_latency_micros_sets_right_element(self): - """Test that add_latency_micros adds 1 to the right element""" - self.latency_tracker.add_latency_micros(self.some_micros) - self.assertListEqual([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - self.latency_tracker._latencies) - - def test_clear_resets_latencies(self): - """Test that clear resets latencies""" - self.latency_tracker._latencies = self.some_latencies - self.latency_tracker.clear() - self.assertListEqual([0] * 23, self.latency_tracker._latencies) - - -class LatencyTrackerGetBucketTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_latency_millis = 1000 - self.some_latency_micros = 1000000 - self.some_latencies = [mock.MagicMock() for _ in range(23)] - self.latency_tracker = LatencyTracker(latencies=self.some_latencies) - self.get_latency_bucket_index_mock = self.patch( - 'splitio.metrics.get_latency_bucket_index', return_value=5) - - def test_get_bucket_for_latency_millis_calls_get_latency_bucket_index(self): - """Test that get_bucket_for_latency_millis calls get_latency_bucket_index""" - self.latency_tracker.get_bucket_for_latency_millis(self.some_latency_millis) - self.get_latency_bucket_index_mock.assert_called_once_with(self.some_latency_millis * 1000) - - def test_get_bucket_for_latency_millis_returns_right_element(self): - """Test that get_bucket_for_latency_millis returns the right element""" - self.assertEqual(self.some_latencies[self.get_latency_bucket_index_mock.return_value], - self.latency_tracker.get_bucket_for_latency_millis( -
self.some_latency_millis)) - - def test_get_bucket_for_latency_micros_calls_get_latency_bucket_index(self): - """Test that get_bucket_for_latency_micros calls get_latency_bucket_index""" - self.latency_tracker.get_bucket_for_latency_micros(self.some_latency_micros) - self.get_latency_bucket_index_mock.assert_called_once_with(self.some_latency_micros) - - def test_get_bucket_for_latency_micros_returns_right_element(self): - """Test that get_bucket_for_latency_micros returns the right element""" - self.assertEqual(self.some_latencies[self.get_latency_bucket_index_mock.return_value], - self.latency_tracker.get_bucket_for_latency_micros( - self.some_latency_micros)) - - -class InMemoryMetricsCountTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_counter = mock.MagicMock() - self.some_delta = 5 - self.rlock_mock = self.patch('splitio.metrics.RLock') - self.arrow_mock = self.patch('splitio.metrics.arrow') - self.arrow_mock.utcnow.return_value.timestamp = 1234567 - self.defaultdict_side_effect = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()] - self.defaultdict_side_effect[0].__getitem__.return_value = 3 - self.defaultdict_mock = self.patch('splitio.metrics.defaultdict', - side_effect=self.defaultdict_side_effect) - self.metrics = InMemoryMetrics() - self.update_count_mock = self.patch_object(self.metrics, 'update_count') - - def test_counter_not_set_if_ignore_metrics_is_true(self): - """Test that the counter is not set if ignore metrics is true""" - self.metrics._ignore_metrics = True - self.metrics.count(self.some_counter, self.some_delta) - self.defaultdict_side_effect[0].__setitem__.assert_not_called() - - def test_call_count_not_increased_if_ignore_metrics_is_true(self): - """Test that the call count is not increased if ignore metrics is true""" - self.metrics._ignore_metrics = True - self.metrics.count(self.some_counter, self.some_delta) - self.assertEqual(0, self.metrics._count_call_count) - - def test_increases_counter_by_delta(self): - """Test that the counter is increased by delta""" - self.metrics.count(self.some_counter, self.some_delta) - self.defaultdict_side_effect[0].__setitem__.assert_called_once_with( - self.some_counter, - self.defaultdict_side_effect[0].__getitem__.return_value + self.some_delta) - - def test_increases_call_count(self): - """Test that the call count is increased by one""" - self.metrics.count(self.some_counter, self.some_delta) - self.assertEqual(1, self.metrics._count_call_count) - - def test_that_if_no_update_conditions_are_met_update_count_not_called(self): - """Test that if neither update condition is met, update_count is not called""" - self.metrics.count(self.some_counter, self.some_delta) - self.update_count_mock.assert_not_called() - - def test_that_if_no_update_conditions_are_met_call_count_not_reset(self): - """Test that if neither update condition is met, the call count is not reset""" - self.metrics.count(self.some_counter, self.some_delta) - self.assertLess(0, self.metrics._count_call_count) - - def test_update_count_called_if_max_call_count_reached(self): - """Test that update_count is called if max call count is reached""" - self.metrics._max_call_count = 5 - self.metrics._count_call_count = 4 - self.metrics.count(self.some_counter, self.some_delta) - self.update_count_mock.assert_called_once_with() - - def test_call_count_reset_if_max_call_count_reached(self): - """Test that call count is reset if max call count is reached""" - self.metrics._max_call_count = 5 - self.metrics._count_call_count = 4 -
self.metrics.count(self.some_counter, self.some_delta) - self.assertEqual(0, self.metrics._count_call_count) - - def test_update_count_not_called_if_max_call_count_reached_and_ignore_metrics_is_true(self): - """Test that update_count is not called if max call count is reached but ignore_metrics - is True""" - self.metrics._max_call_count = 5 - self.metrics._count_call_count = 4 - self.metrics._ignore_metrics = True - self.metrics.count(self.some_counter, self.some_delta) - self.update_count_mock.assert_not_called() - - def test_call_count_not_reset_if_max_call_count_reached_and_ignore_metrics_is_true(self): - """Test that call count is not reset if max call count is reached but ignore_metrics is - True""" - self.metrics._max_call_count = 5 - self.metrics._count_call_count = 4 - self.metrics._ignore_metrics = True - self.metrics.count(self.some_counter, self.some_delta) - self.assertEqual(4, self.metrics._count_call_count) - - def test_update_count_called_if_max_time_between_calls_reached(self): - """Test that update_count is called if max time between calls is reached""" - self.metrics._max_time_between_calls = 10 - self.metrics._count_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics.count(self.some_counter, self.some_delta) - self.update_count_mock.assert_called_once_with() - - def test_call_count_reset_if_max_time_between_calls_reached(self): - """Test that call count is reset if max time between calls is reached""" - self.metrics._max_time_between_calls = 10 - self.metrics._count_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics.count(self.some_counter, self.some_delta) - self.assertEqual(0, self.metrics._count_call_count) - - def test_update_count_not_called_max_time_between_calls_reached_and_ignore_metrics_is_true(self): - """Test that update_count is not called if max time between calls is reached but - ignore_metrics is True""" - self.metrics._max_time_between_calls = 10 - self.metrics._count_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics._ignore_metrics = True - self.metrics.count(self.some_counter, self.some_delta) - self.update_count_mock.assert_not_called() - - def test_call_count_not_reset_max_time_between_calls_reached_and_ignore_metrics_is_true(self): - """Test that call count is not reset if max time between calls is reached but - ignore_metrics is True""" - self.metrics._count_call_count = 4 - self.metrics._max_time_between_calls = 10 - self.metrics._count_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics._ignore_metrics = True - self.metrics.count(self.some_counter, self.some_delta) - self.assertEqual(4, self.metrics._count_call_count) - - -class InMemoryMetricsTimeTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_operation = mock.MagicMock() - self.some_time_in_ms = mock.MagicMock() - self.rlock_mock = self.patch('splitio.metrics.RLock') - self.arrow_mock = self.patch('splitio.metrics.arrow') - self.arrow_mock.utcnow.return_value.timestamp = 1234567 - self.defaultdict_side_effect = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()] - self.latency_tracker_mock = self.defaultdict_side_effect[1].__getitem__.return_value - self.defaultdict_mock = self.patch('splitio.metrics.defaultdict', - side_effect=self.defaultdict_side_effect) - self.metrics = InMemoryMetrics() - self.update_time_mock = self.patch_object(self.metrics, 'update_time') - - def
test_add_latency_millis_not_called_if_ignore_metrics_is_true(self): - """Test that add_latency_millis is not called if ignore metrics is true""" - self.metrics._ignore_metrics = True - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.defaultdict_side_effect[1].__getitem__.return_value.\ - add_latency_millis.assert_not_called() - - def test_call_count_not_increased_if_ignore_metrics_is_true(self): - """Test that the call count is not increased if ignore metrics is true""" - self.metrics._ignore_metrics = True - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.assertEqual(0, self.metrics._time_call_count) - - def test_calls_add_latency_millis(self): - """Test that add_latency_millis is called""" - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.defaultdict_side_effect[1].__getitem__.return_value.\ - add_latency_millis.assert_called_once_with(self.some_time_in_ms) - - def test_increases_call_count(self): - """Test that the call count is increased by one""" - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.assertEqual(1, self.metrics._time_call_count) - - def test_that_if_no_update_conditions_are_met_update_count_not_called(self): - """Test that if neither update condition is met, update_time is not called""" - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.update_time_mock.assert_not_called() - - def test_that_if_no_update_conditions_are_met_call_count_not_reset(self): - """Test that if neither update condition is met, the call count is not reset""" - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.assertLess(0, self.metrics._time_call_count) - - def test_update_count_called_if_max_call_count_reached(self): - """Test that update_time is called if max call count is reached""" - self.metrics._max_call_count = 5 - self.metrics._time_call_count = 4 - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.update_time_mock.assert_called_once_with() - - def test_call_count_reset_if_max_call_count_reached(self): - """Test that call count is reset if max call count is reached""" - self.metrics._max_call_count = 5 - self.metrics._time_call_count = 4 - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.assertEqual(0, self.metrics._time_call_count) - - def test_update_count_not_called_if_max_call_count_reached_and_ignore_metrics_is_true(self): - """Test that update_time is not called if max call count is reached but ignore_metrics - is True""" - self.metrics._max_call_count = 5 - self.metrics._time_call_count = 4 - self.metrics._ignore_metrics = True - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.update_time_mock.assert_not_called() - - def test_call_count_not_reset_if_max_call_count_reached_and_ignore_metrics_is_true(self): - """Test that call count is not reset if max call count is reached but ignore_metrics is - True""" - self.metrics._max_call_count = 5 - self.metrics._time_call_count = 4 - self.metrics._ignore_metrics = True - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.assertEqual(4, self.metrics._time_call_count) - - def test_update_count_called_if_max_time_between_calls_reached(self): - """Test that update_time is called if max time between calls is reached""" - self.metrics._max_time_between_calls = 10 - self.metrics._time_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics.time(self.some_operation, self.some_time_in_ms) -
self.update_time_mock.assert_called_once_with() - - def test_call_count_reset_if_max_time_between_calls_reached(self): - """Test that call count is reset if max time between calls is reached""" - self.metrics._max_time_between_calls = 10 - self.metrics._time_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.assertEqual(0, self.metrics._time_call_count) - - def test_update_count_not_called_max_time_between_calls_reached_and_ignore_metrics_is_true(self): - """Test that update_time is not called if max time between calls is reached but - ignore_metrics is True""" - self.metrics._max_time_between_calls = 10 - self.metrics._time_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics._ignore_metrics = True - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.update_time_mock.assert_not_called() - - def test_call_count_not_reset_max_time_between_calls_reached_and_ignore_metrics_is_true(self): - """Test that call count is not reset if max time between calls is reached but - ignore_metrics is True""" - self.metrics._time_call_count = 4 - self.metrics._max_time_between_calls = 10 - self.metrics._time_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics._ignore_metrics = True - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.assertEqual(4, self.metrics._time_call_count) - - -class InMemoryMetricsGaugeTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_gauge = mock.MagicMock() - self.some_value = mock.MagicMock() - self.rlock_mock = self.patch('splitio.metrics.RLock') - self.arrow_mock = self.patch('splitio.metrics.arrow') - self.arrow_mock.utcnow.return_value.timestamp = 1234567 - self.defaultdict_side_effect = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()] - self.defaultdict_mock = self.patch('splitio.metrics.defaultdict', - side_effect=self.defaultdict_side_effect) - self.metrics = InMemoryMetrics() - self.update_gauge_mock = self.patch_object(self.metrics, 'update_gauge') - - def test_value_not_set_if_ignore_metrics_is_true(self): - """Test that the value is not set if ignore metrics is true""" - self.metrics._ignore_metrics = True - self.metrics.gauge(self.some_gauge, self.some_value) - self.defaultdict_side_effect[2].__setitem__.assert_not_called() - - def test_call_count_not_increased_if_ignore_metrics_is_true(self): - """Test that the call count is not increased if ignore metrics is true""" - self.metrics._ignore_metrics = True - self.metrics.gauge(self.some_gauge, self.some_value) - self.assertEqual(0, self.metrics._gauge_call_count) - - def test_increases_counter_by_delta(self): - """Test that the gauge is set to the value""" - self.metrics.gauge(self.some_gauge, self.some_value) - self.defaultdict_side_effect[2].__setitem__.assert_called_once_with( - self.some_gauge, - self.some_value) - - def test_increases_call_count(self): - """Test that the call count is increased by one""" - self.metrics.gauge(self.some_gauge, self.some_value) - self.assertEqual(1, self.metrics._gauge_call_count) - - def test_that_if_no_update_conditions_are_met_update_count_not_called(self): - """Test that if neither update condition is met, update_gauge is not called""" - self.metrics.gauge(self.some_gauge, self.some_value) - self.update_gauge_mock.assert_not_called() - - def test_that_if_no_update_conditions_are_met_call_count_not_reset(self): - """Test that if neither update condition is met,
the call count is not reset""" - self.metrics.gauge(self.some_gauge, self.some_value) - self.assertLess(0, self.metrics._gauge_call_count) - - def test_update_count_called_if_max_call_count_reached(self): - """Test that update_gauge is called if max call count is reached""" - self.metrics._max_call_count = 5 - self.metrics._gauge_call_count = 4 - self.metrics.gauge(self.some_gauge, self.some_value) - self.update_gauge_mock.assert_called_once_with() - - def test_call_count_reset_if_max_call_count_reached(self): - """Test that call count is reset if max call count is reached""" - self.metrics._max_call_count = 5 - self.metrics._gauge_call_count = 4 - self.metrics.gauge(self.some_gauge, self.some_value) - self.assertEqual(0, self.metrics._gauge_call_count) - - def test_update_count_not_called_if_max_call_count_reached_and_ignore_metrics_is_true(self): - """Test that update_gauge is not called if max call count is reached but ignore_metrics - is True""" - self.metrics._max_call_count = 5 - self.metrics._gauge_call_count = 4 - self.metrics._ignore_metrics = True - self.metrics.gauge(self.some_gauge, self.some_value) - self.update_gauge_mock.assert_not_called() - - def test_call_count_not_reset_if_max_call_count_reached_and_ignore_metrics_is_true(self): - """Test that call count is not reset if max call count is reached but ignore_metrics is - True""" - self.metrics._max_call_count = 5 - self.metrics._gauge_call_count = 4 - self.metrics._ignore_metrics = True - self.metrics.gauge(self.some_gauge, self.some_value) - self.assertEqual(4, self.metrics._gauge_call_count) - - def test_update_count_called_if_max_time_between_calls_reached(self): - """Test that update_gauge is called if max time between calls is reached""" - self.metrics._max_time_between_calls = 10 - self.metrics._gauge_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics.gauge(self.some_gauge, self.some_value) - self.update_gauge_mock.assert_called_once_with() - - def test_call_count_reset_if_max_time_between_calls_reached(self): - """Test that call count is reset if max time between calls is reached""" - self.metrics._max_time_between_calls = 10 - self.metrics._gauge_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics.gauge(self.some_gauge, self.some_value) - self.assertEqual(0, self.metrics._gauge_call_count) - - def test_update_count_not_called_max_time_between_calls_reached_and_ignore_metrics_is_true(self): - """Test that update_gauge is not called if max time between calls is reached but - ignore_metrics is True""" - self.metrics._max_time_between_calls = 10 - self.metrics._gauge_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics._ignore_metrics = True - self.metrics.gauge(self.some_gauge, self.some_value) - self.update_gauge_mock.assert_not_called() - - def test_call_count_not_reset_max_time_between_calls_reached_and_ignore_metrics_is_true(self): - """Test that call count is not reset if max time between calls is reached but - ignore_metrics is True""" - self.metrics._gauge_call_count = 4 - self.metrics._max_time_between_calls = 10 - self.metrics._gauge_last_call_time = 100 - self.arrow_mock.utcnow.return_value.timestamp = 1000 - self.metrics._ignore_metrics = True - self.metrics.gauge(self.some_gauge, self.some_value) - self.assertEqual(4, self.metrics._gauge_call_count) - - -class BuildMetricsCounterDataTests(TestCase): - def test_works_with_empty_data(self): - """Test that build_metrics_counter_data works with empty
data""" - self.assertListEqual([], build_metrics_counter_data(dict())) - - def test_works_with_non_empty_data(self): - """Tests that build_metrics_counter_data works with non-empty data""" - count_metrics = {'some_name': mock.MagicMock(), 'some_other_name': mock.MagicMock()} - self.assertListEqual( - [{'name': 'some_name', 'delta': count_metrics['some_name']}, - {'name': 'some_other_name', 'delta': count_metrics['some_other_name']}], - sorted(build_metrics_counter_data(count_metrics), key=lambda d: d['name']) - ) - - -class BuildMetricsTimeDataTests(TestCase): - def test_works_with_empty_data(self): - """Test that build_metrics_times_data works with empty data""" - self.assertListEqual([], build_metrics_times_data(dict())) - - def test_works_with_non_empty_data(self): - """Tests that build_metrics_times_data works with non-empty data""" - times_metrics = {'some_name': mock.MagicMock(), 'some_other_name': mock.MagicMock()} - self.assertListEqual( - [{'name': 'some_name', - 'latencies': times_metrics['some_name'].get_latencies.return_value}, - {'name': 'some_other_name', - 'latencies': times_metrics['some_other_name'].get_latencies.return_value}], - sorted(build_metrics_times_data(times_metrics), key=lambda d: d['name']) - ) - - -class BuildMetricsGaugeDataTests(TestCase): - def test_works_with_empty_data(self): - """Test that build_metrics_gauge_data works with empty data""" - self.assertListEqual([], build_metrics_gauge_data(dict())) - - def test_works_with_non_empty_data(self): - """Tests that build_metrics_gauge_data works with non-empty data""" - gauge_metrics = {'some_name': mock.MagicMock(), 'some_other_name': mock.MagicMock()} - self.assertListEqual( - [{'name': 'some_name', 'value': gauge_metrics['some_name']}, - {'name': 'some_other_name', 'value': gauge_metrics['some_other_name']}], - sorted(build_metrics_gauge_data(gauge_metrics), key=lambda d: d['name']) - ) - - -class ApiMetricsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_count_metrics = mock.MagicMock() - self.some_time_metrics = mock.MagicMock() - self.some_gauge_metrics = mock.MagicMock() - self.some_api = mock.MagicMock() - self.rlock_mock = self.patch('splitio.metrics.RLock') - self.thread_pool_executor_mock = self.patch('splitio.metrics.ThreadPoolExecutor') - self.build_metrics_counter_data_mock = self.patch( - 'splitio.metrics.build_metrics_counter_data') - self.build_metrics_times_data_mock = self.patch( - 'splitio.metrics.build_metrics_times_data') - self.build_metrics_gauge_data_mock = self.patch( - 'splitio.metrics.build_metrics_gauge_data') - self.metrics = ApiMetrics(self.some_api, count_metrics=self.some_count_metrics, - time_metrics=self.some_time_metrics, - gauge_metrics=self.some_gauge_metrics) - - def test_update_count_fn_resets_count_metrics(self): - """Test that _update_count_fn resets count metrics""" - self.metrics._update_count_fn() - self.assertDictEqual(dict(), self.metrics._count_metrics) - - def test_update_count_fn_calls_metrics_counters(self): - """Test that _update_count_fn calls metrics_counters""" - self.metrics._update_count_fn() - self.some_api.metrics_counters.assert_called_once_with( - self.build_metrics_counter_data_mock.return_value) - - def test_update_count_fn_sets_ignore_if_metrics_counters_raises_exception(self): - """Test that if metrics_counters raises an exception ignore_metrics is set to True""" - self.some_api.metrics_counters.side_effect = Exception() - self.metrics._update_count_fn() - self.assertTrue(self.metrics._ignore_metrics) - - def
test_update_time_fn_resets_time_metrics(self): - """Test that _update_time_fn resets time metrics""" - self.metrics._update_time_fn() - self.assertDictEqual(dict(), self.metrics._time_metrics) - - def test_update_time_fn_calls_metrics_counters(self): - """Test that _update_time_fn calls metrics_times""" - self.metrics._update_time_fn() - self.some_api.metrics_times.assert_called_once_with( - self.build_metrics_times_data_mock.return_value) - - def test_update_time_fn_sets_ignore_if_metrics_counters_raises_exception(self): - """Test that if metrics_times raises an exception ignore_metrics is set to True""" - self.some_api.metrics_times.side_effect = Exception() - self.metrics._update_time_fn() - self.assertTrue(self.metrics._ignore_metrics) - - def test_update_gauge_fn_resets_time_metrics(self): - """Test that _update_gauge_fn resets gauge metrics""" - self.metrics._update_gauge_fn() - self.assertDictEqual(dict(), self.metrics._gauge_metrics) - - def test_update_gauge_fn_calls_metrics_counters(self): - """Test that _update_gauge_fn calls metrics_gauge""" - self.metrics._update_gauge_fn() - self.some_api.metrics_gauge.assert_called_once_with( - self.build_metrics_gauge_data_mock.return_value) - - def test_update_gauge_fn_sets_ignore_if_metrics_counters_raises_exception(self): - """Test that if metrics_gauge raises an exception ignore_metrics is set to True""" - self.some_api.metrics_gauge.side_effect = Exception() - self.metrics._update_gauge_fn() - self.assertTrue(self.metrics._ignore_metrics) - - def test_update_count_calls_submit(self): - """Test that update_count calls thread pool executor submit""" - self.metrics.update_count() - self.thread_pool_executor_mock.return_value.submit.assert_called_once_with( - self.metrics._update_count_fn) - - def test_update_count_doesnt_raise_exceptions(self): - """Test that update_count doesn't raise an exception when submit does""" - self.thread_pool_executor_mock.return_value.submit.side_effect = Exception() - - try: - self.metrics.update_count() - except: - self.fail('Unexpected exception raised') - - def test_update_time_calls_submit(self): - """Test that update_time calls thread pool executor submit""" - self.metrics.update_time() - self.thread_pool_executor_mock.return_value.submit.assert_called_once_with( - self.metrics._update_time_fn) - - def test_update_time_doesnt_raise_exceptions(self): - """Test that update_time doesn't raise an exception when submit does""" - self.thread_pool_executor_mock.return_value.submit.side_effect = Exception() - - try: - self.metrics.update_time() - except: - self.fail('Unexpected exception raised') - - def test_update_gauge_calls_submit(self): - """Test that update_gauge calls thread pool executor submit""" - self.metrics.update_gauge() - self.thread_pool_executor_mock.return_value.submit.assert_called_once_with( - self.metrics._update_gauge_fn) - - def test_update_gauge_doesnt_raise_exceptions(self): - """Test that update_gauge doesn't raise an exception when submit does""" - self.thread_pool_executor_mock.return_value.submit.side_effect = Exception() - - try: - self.metrics.update_gauge() - except: - self.fail('Unexpected exception raised') - - -class AsyncMetricsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_counter = mock.MagicMock() - self.some_delta = mock.MagicMock() - self.some_operation = mock.MagicMock() - self.some_time_in_ms = mock.MagicMock() - self.some_gauge = mock.MagicMock() - self.some_value = mock.MagicMock() - self.some_delegate_metrics = mock.MagicMock() -
self.thread_pool_executor_mock = self.patch('splitio.metrics.ThreadPoolExecutor') - self.metrics = AsyncMetrics(self.some_delegate_metrics) - - def test_count_calls_submit_with_delegate_count(self): - """Test that count calls submit with the delegate count""" - self.metrics.count(self.some_counter, self.some_delta) - self.thread_pool_executor_mock.return_value.submit.assert_called_once_with( - self.some_delegate_metrics.count, self.some_counter, self.some_delta) - - def test_count_doesnt_raise_exceptions_if_submit_does(self): - """Test that count doesn't raise an exception even if submit does""" - self.thread_pool_executor_mock.return_value.submit.side_effect = Exception() - try: - self.metrics.count(self.some_counter, self.some_delta) - except: - self.fail('Unexpected exception raised') - - def test_time_calls_submit_with_delegate_time(self): - """Test that time calls submit with the delegate time""" - self.metrics.time(self.some_operation, self.some_time_in_ms) - self.thread_pool_executor_mock.return_value.submit.assert_called_once_with( - self.some_delegate_metrics.time, self.some_operation, self.some_time_in_ms) - - def test_time_doesnt_raise_exceptions_if_submit_does(self): - """Test that time doesn't raise an exception even if submit does""" - self.thread_pool_executor_mock.return_value.submit.side_effect = Exception() - try: - self.metrics.time(self.some_operation, self.some_time_in_ms) - except: - self.fail('Unexpected exception raised') - - def test_gauge_calls_submit_with_delegate_gauge(self): - """Test that gauge calls submit with the delegate gauge""" - self.metrics.gauge(self.some_gauge, self.some_value) - self.thread_pool_executor_mock.return_value.submit.assert_called_once_with( - self.some_delegate_metrics.gauge, self.some_gauge, self.some_value) - - def test_gauge_doesnt_raise_exceptions_if_submit_does(self): - """Test that gauge doesn't raise an exception even if submit does""" - self.thread_pool_executor_mock.return_value.submit.side_effect = Exception() - try: - self.metrics.gauge(self.some_gauge, self.some_value) - except: - self.fail('Unexpected exception raised') diff --git a/splitio/tests/test_prefix_decorator.py b/splitio/tests/test_prefix_decorator.py deleted file mode 100644 index ee779cec..00000000 --- a/splitio/tests/test_prefix_decorator.py +++ /dev/null @@ -1,133 +0,0 @@ -''' -Unit tests for the prefix decorator class -''' - -from __future__ import absolute_import, division, print_function, \ - unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase -from splitio.prefix_decorator import PrefixDecorator - - -class PrefixDecoratorTests(TestCase): - - def setUp(self): - self._prefix = 'test' - self._pattern = 'some_pattern' - self._prefixed_mock = mock.Mock() - self._unprefixed_mock = mock.Mock() - self._prefixed = PrefixDecorator(self._prefixed_mock, self._prefix) - self._unprefixed = PrefixDecorator(self._unprefixed_mock) - - def test_keys(self): - self._prefixed.keys('some_pattern') - self._prefixed_mock.keys.assert_called_once_with('test.some_pattern') - self._unprefixed.keys('some_pattern') - self._unprefixed_mock.keys.assert_called_once_with('some_pattern') - - def test_set(self): - self._prefixed.set('key1', 1) - self._prefixed_mock.set.assert_called_once_with('test.key1', 1) - self._unprefixed.set('key1', 1) - self._unprefixed_mock.set.assert_called_once_with('key1', 1) - - def test_get(self): - self._prefixed.get('key1') -
self._prefixed_mock.get.assert_called_once_with('test.key1') - self._unprefixed.get('key1') - self._unprefixed_mock.get.assert_called_once_with('key1') - - def test_setex(self): - self._prefixed.setex('key1', 1, 100) - self._prefixed_mock.setex.assert_called_once_with('test.key1', 1, 100) - self._unprefixed.setex('key1', 1, 100) - self._unprefixed_mock.setex.assert_called_once_with('key1', 1, 100) - - def test_delete(self): - self._prefixed.delete('key1') - self._prefixed_mock.delete.assert_called_once_with('test.key1') - self._unprefixed.delete('key1') - self._unprefixed_mock.delete.assert_called_once_with('key1') - - def test_exists(self): - self._prefixed.exists('key1') - self._prefixed_mock.exists.assert_called_once_with('test.key1') - self._unprefixed.exists('key1') - self._unprefixed_mock.exists.assert_called_once_with('key1') - - def test_mget(self): - self._prefixed.mget(['key1', 'key2']) - self._prefixed_mock.mget.assert_called_once_with( - ['test.key1', 'test.key2'] - ) - self._unprefixed.mget(['key1', 'key2']) - self._unprefixed_mock.mget.assert_called_once_with(['key1', 'key2']) - - def test_smembers(self): - self._prefixed.smembers('set1') - self._prefixed_mock.smembers.assert_called_once_with('test.set1') - self._unprefixed.smembers('set1') - self._unprefixed_mock.smembers.assert_called_once_with('set1') - - def test_sadd(self): - self._prefixed.sadd('set1', 1, 2, 3) - self._prefixed_mock.sadd.assert_called_once_with( - 'test.set1', - 1, - 2, - 3 - ) - self._unprefixed.sadd('set1', 1, 2, 3) - self._unprefixed_mock.sadd.assert_called_once_with('set1', 1, 2, 3) - - def test_srem(self): - self._prefixed.srem('set1', 1) - self._prefixed_mock.srem.assert_called_once_with('test.set1', 1) - self._unprefixed.srem('set1', 1) - self._unprefixed_mock.srem.assert_called_once_with('set1', 1) - - def test_sismember(self): - self._prefixed.sismember('set1', 1) - self._prefixed_mock.sismember.assert_called_once_with('test.set1', 1) - self._unprefixed.sismember('set1', 1) - self._unprefixed_mock.sismember.assert_called_once_with('set1', 1) - - def test_eval(self): - self._prefixed.eval('some_lua_script', 2, 'key1', 'key2') - self._prefixed_mock.eval.assert_called_once_with( - 'some_lua_script', 2, 'test.key1', 'test.key2' - ) - self._unprefixed.eval('some_lua_script', 2, 'key1', 'key2') - self._unprefixed_mock.eval.assert_called_once_with( - 'some_lua_script', 2, 'key1', 'key2' - ) - - def test_hset(self): - self._prefixed.hset('hash1', 'key', 1) - self._prefixed_mock.hset.assert_called_once_with('test.hash1', 'key', 1) - self._unprefixed.hset('hash1', 'key', 1) - self._unprefixed_mock.hset.assert_called_once_with('hash1', 'key', 1) - - def test_hget(self): - self._prefixed.hget('hash1', 'key') - self._prefixed_mock.hget.assert_called_once_with('test.hash1', 'key') - self._unprefixed.hget('hash1', 'key') - self._unprefixed_mock.hget.assert_called_once_with('hash1', 'key') - - def test_incr(self): - self._prefixed.incr('key1') - self._prefixed_mock.incr.assert_called_once_with('test.key1', 1) - self._unprefixed.incr('key1') - self._unprefixed_mock.incr.assert_called_once_with('key1', 1) - - def test_getset(self): - self._prefixed.getset('key1', 12) - self._prefixed_mock.getset.assert_called_once_with('test.key1', 12) - self._unprefixed.getset('key1', 12) - self._unprefixed_mock.getset.assert_called_once_with('key1', 12) diff --git a/splitio/tests/test_redis_cache.py b/splitio/tests/test_redis_cache.py deleted file mode 100644 index 3389300d..00000000 --- a/splitio/tests/test_redis_cache.py +++ 
/dev/null @@ -1,106 +0,0 @@ -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from os.path import dirname, join -from unittest import TestCase -from splitio.tests.utils import MockUtilsMixin -from json import load - -from splitio.redis_support import (RedisSplitCache, RedisSegmentCache, get_redis) -from redis import StrictRedis -from splitio.clients import Client -from splitio.prefix_decorator import PrefixDecorator -from splitio.brokers import RedisBroker - -class CacheInterfacesTests(TestCase): - def setUp(self): - self._segment_changes_file_name = join(dirname(__file__), 'segmentChanges.json') - self._split_changes_file_name = join(dirname(__file__), 'splitChanges.json') - - self._redis = get_redis({'redisPrefix': 'test'}) - self._redis_split_cache = RedisSplitCache(self._redis) - self._redis_segment_cache = RedisSegmentCache(self._redis) - - def test_split_cache_interface(self): - - with open(self._split_changes_file_name) as f: - self._json = load(f) - split_definition = self._json['splits'][0] - split_name = split_definition['name'] - - # Add and get Split - self._redis_split_cache.add_split(split_name, split_definition) - self.assertEqual(split_definition['name'], self._redis_split_cache.get_split(split_name).name) - self.assertEqual(split_definition['killed'], self._redis_split_cache.get_split(split_name).killed) - self.assertEqual(split_definition['seed'], self._redis_split_cache.get_split(split_name).seed) - - # Remove Split - self._redis_split_cache.remove_split(split_name) - self.assertIsNone(self._redis_split_cache.get_split(split_name)) - - # Change Number - self._redis_split_cache.set_change_number(1212) - self.assertEqual(1212, self._redis_split_cache.get_change_number()) - - # @TODO These tests should be removed since this is no longer supported by Redis. - # def testSegmentCacheInterface(self): - # with open(self._segment_changes_file_name) as f: - # self._json = load(f) - # segment_name = self._json['name'] - # segment_change_number = self._json['till'] - # segment_keys = self._json['added'] - - # self._redis_segment_cache.set_change_number(segment_name, segment_change_number) - # self.assertEqual(segment_change_number, self._redis_segment_cache.get_change_number(segment_name)) - - # self._redis_segment_cache.add_keys_to_segment(segment_name, segment_keys) - # self.assertTrue(self._redis_segment_cache.is_in_segment(segment_name, segment_keys[0])) - - # self._redis_segment_cache.remove_keys_from_segment(segment_name, [segment_keys[0]]) - # self.assertFalse(self._redis_segment_cache.is_in_segment(segment_name, segment_keys[0])) - -class ReadOnlyRedisMock(PrefixDecorator): - - def __init__(self, *args, **kwargs): - """ - Based on PrefixDecorator. - """ - PrefixDecorator.__init__(self, *args, **kwargs) - - def sadd(self, name, *values): - """ - Decorated sadd to simulate a read-only error.
- """ - raise Exception('ReadOnlyError') - - -class RedisReadOnlyTest(TestCase, MockUtilsMixin): - def setUp(self): - self._some_config = mock.MagicMock() - self._split_changes_file_name = join(dirname(__file__), 'splitChangesReadOnly.json') - - with open(self._split_changes_file_name) as f: - self._json = load(f) - split_definition = self._json['splits'][0] - split_name = split_definition['name'] - - self._redis = get_redis({'redisPrefix': 'test'}) - - self._mocked_redis = ReadOnlyRedisMock(self._redis) - self._redis_split_cache = RedisSplitCache(self._redis) - self._redis_split_cache.add_split(split_name, split_definition) - self._client = Client(RedisBroker(self._mocked_redis, self._some_config)) - - self._impression = mock.MagicMock() - self._start = mock.MagicMock() - self._operation = mock.MagicMock() - - def test_redis_read_only_mode(self): - self.assertEqual(self._client.get_treatment('valid', 'test_read_only_1'), 'on') - self.assertEqual(self._client.get_treatment('invalid', 'test_read_only_1'), 'off') - self.assertEqual(self._client.get_treatment('valid', 'test_read_only_1_invalid'), 'control') diff --git a/splitio/tests/test_redis_support.py b/splitio/tests/test_redis_support.py deleted file mode 100644 index d3eac11f..00000000 --- a/splitio/tests/test_redis_support.py +++ /dev/null @@ -1,432 +0,0 @@ -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase -from collections import defaultdict - -from splitio.version import __version__ -from splitio.metrics import BUCKETS -from splitio.impressions import Impression -from splitio.tests.utils import MockUtilsMixin - -from splitio.redis_support import (RedisSegmentCache, RedisSplitCache, RedisImpressionsCache, - RedisMetricsCache, RedisSplitParser) -from splitio.config import GLOBAL_KEY_PARAMETERS - - -class RedisSegmentCacheTests(TestCase): - def setUp(self): - self.some_segment_name = mock.MagicMock() - self.some_segment_name_str = 'some_segment_name' - self.some_segment_keys = [mock.MagicMock(), mock.MagicMock()] - self.some_key = mock.MagicMock() - self.some_change_number = mock.MagicMock() - self.some_redis = mock.MagicMock() - self.a_segment_cache = RedisSegmentCache(self.some_redis) - - def test_add_keys_to_segment_adds_keys_to_segment_set(self): - """Test that add_keys_to_segment adds the keys to the segment key set""" - self.a_segment_cache.add_keys_to_segment(self.some_segment_name_str, self.some_segment_keys) - self.some_redis.sadd.assert_called_once_with( - 'SPLITIO.segment.some_segment_name', self.some_segment_keys[0], - self.some_segment_keys[1]) - - def test_remove_keys_from_segment_remove_keys_from_segment_set(self): - """Test that remove_keys_from_segment removes the keys from the segment key set""" - self.a_segment_cache.remove_keys_from_segment(self.some_segment_name_str, - self.some_segment_keys) - self.some_redis.srem.assert_called_once_with( - 'SPLITIO.segment.some_segment_name', self.some_segment_keys[0], - self.some_segment_keys[1]) - - def test_is_in_segment_tests_whether_a_key_is_in_a_segments_key_set(self): - """Test that is_in_segment checks if a key is in a segment's key set""" - self.assertEqual(self.some_redis.sismember.return_value, - self.a_segment_cache.is_in_segment(self.some_segment_name_str, - self.some_key)) - self.some_redis.sismember.assert_called_once_with( - 'SPLITIO.segment.some_segment_name', self.some_key) - - def
test_set_change_number_sets_segment_change_number_key(self): - """Test that set_change_number sets the segment's change number key""" - self.a_segment_cache.set_change_number(self.some_segment_name_str, self.some_change_number) - self.some_redis.set.assert_called_once_with( - 'SPLITIO.segment.some_segment_name.till', self.some_change_number) - - def test_get_change_number_gets_segment_change_number_key(self): - """Test that get_change_number gets the segment's change number key""" - self.some_redis.get.return_value = '1234' - result = self.a_segment_cache.get_change_number(self.some_segment_name_str) - self.assertEqual(int(self.some_redis.get.return_value), result) - self.assertIsInstance(result, int) - self.some_redis.get.assert_called_once_with( - 'SPLITIO.segment.some_segment_name.till') - - def test_get_change_number_returns_default_value_if_not_set(self): - """Test that get_change_number returns -1 if the value is not set""" - self.some_redis.get.return_value = None - self.assertEqual(-1, - self.a_segment_cache.get_change_number(self.some_segment_name_str)) - - -class RedisSplitCacheTests(TestCase, MockUtilsMixin): - def setUp(self): - self.decode_mock = self.patch('splitio.redis_support.decode') - self.encode_mock = self.patch('splitio.redis_support.encode') - self.some_split_name = mock.MagicMock() - self.some_split_name_str = 'some_split_name' - self.some_split = mock.MagicMock() - self.some_change_number = mock.MagicMock() - self.some_redis = mock.MagicMock() - self.a_split_cache = RedisSplitCache(self.some_redis) - - def test_set_change_number_sets_change_number_key(self): - """Test that set_change_number sets the change number key""" - self.a_split_cache.set_change_number(self.some_change_number) - self.some_redis.set.assert_called_once_with( - 'SPLITIO.splits.till', self.some_change_number, None) - - def test_get_change_number_gets_segment_change_number_key(self): - """Test that get_change_number gets the change number key""" - self.some_redis.get.return_value = '1234' - result = self.a_split_cache.get_change_number() - self.assertEqual(int(self.some_redis.get.return_value), result) - self.assertIsInstance(result, int) - self.some_redis.get.assert_called_once_with( - 'SPLITIO.splits.till') - - def test_get_change_number_returns_default_value_if_not_set(self): - """Test that get_change_number returns -1 if the value is not set""" - self.some_redis.get.return_value = None - self.assertEqual(-1, self.a_split_cache.get_change_number()) - - def test_add_split_sets_split_key_with_pickled_split(self): - """Test that add_split sets the split key with pickled split""" - self.a_split_cache.add_split(self.some_split_name_str, self.some_split) - self.encode_mock.assert_called_once_with(self.some_split) - self.some_redis.set.assert_called_once_with('SPLITIO.split.some_split_name', - self.encode_mock.return_value) - - def test_get_split_returns_none_if_not_cached(self): - """Test that if a split is not cached get_split returns None""" - self.some_redis.get.return_value = None - self.assertEqual(None, self.a_split_cache.get_split(self.some_split_name_str)) - self.some_redis.get.assert_called_once_with('SPLITIO.split.some_split_name') - - def test_remove_split_deletes_split_key(self): - """Test that remove_split deletes the split key""" - self.a_split_cache.remove_split(self.some_split_name_str) - self.some_redis.delete.assert_called_once_with('SPLITIO.split.some_split_name') - - -class RedisImpressionsCacheTests(TestCase, MockUtilsMixin): - def setUp(self): - self.encode_mock = 
self.patch('splitio.redis_support.encode') - self.decode_mock = self.patch('splitio.redis_support.decode') - self.some_impression = mock.MagicMock() - self.some_redis = mock.MagicMock() - self.an_impressions_cache = RedisImpressionsCache(self.some_redis) - self.build_impressions_dict_mock = self.patch_object(self.an_impressions_cache, - '_build_impressions_dict') - - def test_fetch_all_doesnt_call_build_impressions_dict_if_no_impressions_cached(self): - """Test that fetch_all doesn't call _build_impressions_dict if no impressions are cached""" - self.some_redis.lrange.return_value = None - self.assertDictEqual(dict(), self.an_impressions_cache.fetch_all()) - self.build_impressions_dict_mock.assert_not_called() - - def test_clear_deletes_impressions_key(self): - """Test that clear deletes impressions key""" - self.an_impressions_cache.clear() - self.some_redis.eval.assert_called_once_with( - "return redis.call('del', unpack(redis.call('keys', ARGV[1])))", - 0, - RedisImpressionsCache._get_impressions_key('*') - ) - - def test_fetch_all_and_clear_calls_eval_with_fetch_and_clear_script(self): - """Test that fetch_all_and_clear fetches the impressions keys""" - self.an_impressions_cache.fetch_all_and_clear() - self.some_redis.keys.assert_called_once_with( - RedisImpressionsCache._get_impressions_key('*') - ) - - def test_fetch_all_and_clear_doesnt_call_build_impressions_dict_if_no_impressions_cached(self): - """Test that fetch_all_and_clear doesn't call _build_impressions_dict if no impressions are - cached""" - self.some_redis.eval.return_value = None - self.assertDictEqual(dict(), self.an_impressions_cache.fetch_all_and_clear()) - self.build_impressions_dict_mock.assert_not_called() - - -class RedisImpressionsCacheBuildImpressionsDictTests(TestCase): - def setUp(self): - self.some_redis = mock.MagicMock() - self.an_impressions_cache = RedisImpressionsCache(self.some_redis) - - def _build_impression(self, feature_name): - return Impression(matching_key=mock.MagicMock(), feature_name=feature_name, - treatment=mock.MagicMock(), label=mock.MagicMock(), time=mock.MagicMock(), - change_number=mock.MagicMock(), bucketing_key=mock.MagicMock()) - - def test_build_impressions_dict(self): - """Test that _build_impressions_dict builds the dictionary properly""" - some_feature_name = mock.MagicMock() - some_other_feature_name = mock.MagicMock() - some_feature_impressions = [self._build_impression(some_feature_name)] - some_other_feature_impressions = [self._build_impression(some_other_feature_name), - self._build_impression(some_other_feature_name)] - - self.assertDictEqual({some_feature_name: some_feature_impressions, - some_other_feature_name: some_other_feature_impressions}, - self.an_impressions_cache._build_impressions_dict( - some_feature_impressions + some_other_feature_impressions)) - - -class RedisMetricsCacheTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_redis = mock.MagicMock() - self.some_count_str = 'some_count' - self.some_delta = mock.MagicMock() - self.some_operation_str = 'some_operation' - self.some_bucket_index = 15 - self.some_gauge_str = 'some_gauge' - self.some_value = mock.MagicMock() - self.a_metrics_cache = RedisMetricsCache(self.some_redis) - self.build_metrics_from_cache_response_mock = self.patch_object( - self.a_metrics_cache, '_build_metrics_from_cache_response') - - def test_get_latency_calls_get(self): - """Test that get_latency calls get for each bucket, ending at the last bucket (22)""" - self.a_metrics_cache.get_latency(self.some_operation_str) - name = 
GLOBAL_KEY_PARAMETERS['instance-id'] - self.some_redis.get.assert_called_with( - 'SPLITIO/python-'+__version__+'/'+name+'/latency.some_operation.bucket.{}' - .format(22) - ) - - def test_get_latency_returns_result(self): - """Test that get_latency returns the result of calling get""" - self.some_redis.get.return_value = mock.MagicMock() - result = [self.some_redis.get.return_value] * 23 - self.assertListEqual(result, - self.a_metrics_cache.get_latency(self.some_operation_str)) - - def test_get_latency_sets_empty_results_to_zero(self): - """Test that get_latency sets the missing results from get to zero""" - self.some_redis.get.return_value = None - self.assertEqual(0, self.a_metrics_cache.get_latency(self.some_operation_str)[13]) - - def test_get_latency_bucket_counter_calls_get(self): - """Test that get_latency_bucket_counter calls get""" - self.a_metrics_cache.get_latency_bucket_counter(self.some_operation_str, - self.some_bucket_index) - name = GLOBAL_KEY_PARAMETERS['instance-id'] - self.some_redis.get.assert_called_once_with( - 'SPLITIO/python-'+__version__+'/'+name+'/latency.{0}.bucket.{1}'.format - ( - self.some_operation_str, - self.some_bucket_index - )) - - def test_get_latency_bucket_counter_returns_get_result(self): - """Test that get_latency_bucket_counter returns the result of calling get""" - self.assertEqual(1, - self.a_metrics_cache.get_latency_bucket_counter(self.some_operation_str, - self.some_bucket_index)) - - def test_get_latency_bucket_counter_returns_0_on_missing_value(self): - """Test that get_latency_bucket_counter returns 0 if the bucket value is not cached""" - self.some_redis.get.return_value = None - self.assertEqual(0, - self.a_metrics_cache.get_latency_bucket_counter(self.some_operation_str, - self.some_bucket_index)) - - def test_set_gauge_calls_hset(self): - """Test that set_gauge calls hset""" - self.a_metrics_cache.set_gauge(self.some_gauge_str, self.some_value) - self.some_redis.hset.assert_called_once_with('SPLITIO.metrics.metric', 'gauge.some_gauge', self.some_value) - - def test_set_latency_bucket_counter_calls_set(self): - """Test that set_latency_bucket_counter calls set with the latency bucket - counter key""" - self.a_metrics_cache.set_latency_bucket_counter(self.some_operation_str, - self.some_bucket_index, - self.some_value) - self.some_redis.set.assert_called_once_with( - self.a_metrics_cache._get_latency_bucket_key( - self.some_operation_str, bucket_number=self.some_bucket_index), - self.some_value) - - def test_increment_latency_bucket_counter_calls_delta(self): - """Test that increment_latency_bucket_counter calls incr with the given - delta""" - self.a_metrics_cache.increment_latency_bucket_counter(self.some_operation_str, - self.some_bucket_index, - self.some_delta) - self.some_redis.incr.assert_called_once_with( - self.a_metrics_cache._get_latency_bucket_key( - self.some_operation_str, self.some_bucket_index), - self.some_delta) - - def test_increment_latency_bucket_counter_calls_default_delta(self): - """Test that increment_latency_bucket_counter calls incr with the - default delta of 1""" - self.a_metrics_cache.increment_latency_bucket_counter(self.some_operation_str, - self.some_bucket_index) - self.some_redis.incr.assert_called_once_with( - self.a_metrics_cache._get_latency_bucket_key( - self.some_operation_str, self.some_bucket_index - ), - 1 - ) - - -class RedisMetricsCacheConditionalEvalTests(TestCase): - def setUp(self): - self.some_script = 'some_script' -
self.some_num_keys = mock.MagicMock() - self.some_key = mock.MagicMock() - self.some_other_key = mock.MagicMock() - self.some_value = mock.MagicMock() - self.some_redis = mock.MagicMock() - self.a_metrics_cache = RedisMetricsCache(self.some_redis) - - -class RedisMetricsCacheBuildMetricsFromCacheResponseTests(TestCase): - def setUp(self): - self.some_redis = mock.MagicMock() - self.a_metrics_cache = RedisMetricsCache(self.some_redis) - - def test_returns_default_empty_dict_on_none_response(self): - """Test that _build_metrics_from_cache_response returns the empty default dict if response - is None""" - self.assertDictEqual({'count': [], 'gauge': []}, - self.a_metrics_cache._build_metrics_from_cache_response(None)) - - def test_returns_default_empty_dict_on_empty_response(self): - """Test that _build_metrics_from_cache_response returns the empty default dict if response - is empty""" - self.assertDictEqual({'count': [], 'gauge': []}, - self.a_metrics_cache._build_metrics_from_cache_response([])) - - def test_returns_count_metrics(self): - """Test that _build_metrics_from_cache_response returns count metrics""" - some_count = 'some_count' - some_count_value = mock.MagicMock() - some_other_count = 'some_other_count' - some_other_count_value = mock.MagicMock() - name = GLOBAL_KEY_PARAMETERS['instance-id'] - count_metrics = [ - 'SPLITIO/python-'+__version__+'/'+name+'/count.some_count', some_count_value, - 'SPLITIO/python-'+__version__+'/'+name+'/count.some_other_count', - some_other_count_value - ] - result_count_metrics = [{'name': some_count, 'delta': some_count_value}, - {'name': some_other_count, 'delta': some_other_count_value}] - self.assertDictEqual({'count': result_count_metrics, 'gauge': []}, - self.a_metrics_cache._build_metrics_from_cache_response(count_metrics)) - - def test_returns_time_metrics(self): - """Test that _build_metrics_from_cache_response returns time metrics""" - some_time = 'some_time' - some_time_latencies = [0] * 23 - some_time_latencies[2] = mock.MagicMock() - some_time_latencies[13] = mock.MagicMock() - - some_other_time = 'some_other_time' - some_other_time_latencies = [0] * 23 - some_other_time_latencies[0] = mock.MagicMock() - some_other_time_latencies[1] = mock.MagicMock() - some_other_time_latencies[20] = mock.MagicMock() - - time_metrics = defaultdict(lambda: [0] * len(BUCKETS)) - - time_metrics[some_time] = some_time_latencies - time_metrics[some_other_time] = some_other_time_latencies - - result_time_metrics = [{'name': some_other_time, 'latencies': some_other_time_latencies}, - {'name': some_time, 'latencies': some_time_latencies}] - - self.assertListEqual(result_time_metrics, - self.a_metrics_cache._build_metrics_times_data(time_metrics)) - - def test_returns_gauge_metrics(self): - """Test that _build_metrics_from_cache_response returns gauge metrics""" - some_gauge = 'some_gauge' - some_gauge_value = mock.MagicMock() - some_other_gauge = 'some_other_gauge' - some_other_gauge_value = mock.MagicMock() - name = GLOBAL_KEY_PARAMETERS['instance-id'] - gauge_metrics = [ - 'SPLITIO/python-'+__version__+'/'+name+'/gauge.some_gauge', some_gauge_value, - 'SPLITIO/python-'+__version__+'/'+name+'/gauge.some_other_gauge', - some_other_gauge_value - ] - result_gauge_metrics = [{'name': some_gauge, 'value': some_gauge_value}, - {'name': some_other_gauge, 'value': some_other_gauge_value}] - self.assertDictEqual({'count': [], 'gauge': result_gauge_metrics}, - self.a_metrics_cache._build_metrics_from_cache_response(gauge_metrics)) - - -class
RedisSplitParserTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_matcher = mock.MagicMock() - self.some_segment_cache = mock.MagicMock() - self.split_parser = RedisSplitParser(self.some_segment_cache) - self.redis_split_mock = self.patch('splitio.redis_support.RedisSplit') - - self.some_split = { - 'name': mock.MagicMock(), - 'seed': mock.MagicMock(), - 'killed': mock.MagicMock(), - 'defaultTreatment': mock.MagicMock(), - 'trafficTypeName': mock.MagicMock(), - 'status': mock.MagicMock(), - 'changeNumber': mock.MagicMock(), - 'algo': mock.MagicMock() - } - self.some_block_until_ready = mock.MagicMock() - self.some_partial_split = mock.MagicMock() - self.some_in_segment_matcher = { - 'matcherType': 'IN_SEGMENT', - 'userDefinedSegmentMatcherData': { - 'segmentName': mock.MagicMock() - } - } - - def test_parse_split_returns_redis_split(self): - """Test that _parse_split returns a RedisSplit""" - self.assertEqual(self.redis_split_mock.return_value, - self.split_parser._parse_split( - self.redis_split_mock, block_until_ready=self.some_block_until_ready)) - - def test_parse_split_calls_redis_split_constructor(self): - """Test that _parse_split calls RedisSplit constructor""" - self.split_parser._parse_split(self.some_split, - block_until_ready=self.some_block_until_ready) - self.redis_split_mock.assert_called_once_with( - self.some_split['name'], - self.some_split['seed'], - self.some_split['killed'], - self.some_split['defaultTreatment'], - self.some_split['trafficTypeName'], - self.some_split['status'], - self.some_split['changeNumber'], - segment_cache=self.some_segment_cache, - traffic_allocation=self.some_split.get('trafficAllocation'), - traffic_allocation_seed=self.some_split.get('trafficAllocationSeed'), - algo=self.some_split['algo'] - ) - - def test_parse_matcher_in_segment_registers_segment(self): - """Test that _parse_matcher_in_segment registers segment""" - self.split_parser._parse_matcher_in_segment(self.some_partial_split, - self.some_in_segment_matcher, - block_until_ready=self.some_block_until_ready) - self.some_segment_cache.register_segment.assert_called() diff --git a/splitio/tests/test_segments.py b/splitio/tests/test_segments.py deleted file mode 100644 index abeaae32..00000000 --- a/splitio/tests/test_segments.py +++ /dev/null @@ -1,396 +0,0 @@ -"""Unit tests for the segments module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase - -from splitio.segments import (InMemorySegment, SelfRefreshingSegmentFetcher, SelfRefreshingSegment, - SegmentChangeFetcher, ApiSegmentChangeFetcher, - CacheBasedSegmentFetcher, CacheBasedSegment) -from splitio.tests.utils import MockUtilsMixin - - -class InMemorySegmentTests(TestCase): - def setUp(self): - self.some_name = 'some_name' - self.some_key_set = ['user_id_1', 'user_id_2', 'user_id_3'] - self.key_set_mock = mock.MagicMock() - self.some_key = 'some_key' - - def test_empty_segment_by_default(self): - """Tests that the segments are empty by default""" - segment = InMemorySegment(self.some_name) - self.assertEqual(0, len(segment._key_set)) - - def test_key_set_is_initialized(self): - """Tests that the segments can be initialized to a specific key_set""" - segment = InMemorySegment(self.some_name, key_set=self.some_key_set) - self.assertSetEqual(set(self.some_key_set), segment._key_set) - - def test_contains_calls_in(self): - """Tests that contains delegates the membership check
to the key set""" - segment = InMemorySegment(self.some_name) - segment._key_set = self.key_set_mock - - segment.contains(self.some_key) - - self.key_set_mock.__contains__.assert_called_once_with(self.some_key) - - -class SelfRefreshingSegmentFetcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_name = mock.MagicMock() - self.some_interval = mock.MagicMock() - self.some_max_workers = 5 - self.some_segment = mock.MagicMock() - - self.segment_change_fetcher_mock = mock.MagicMock() - self.self_refreshing_segment_mock = self.patch('splitio.segments.SelfRefreshingSegment') - - self.segment_fetcher = SelfRefreshingSegmentFetcher(self.segment_change_fetcher_mock, - interval=self.some_interval, - max_workers=self.some_max_workers) - self.segments_mock = mock.MagicMock() - self.segment_fetcher._segments = self.segments_mock - - def test_cached_segment_are_returned(self): - """Tests that if a segment is cached, it is returned""" - self.segments_mock.__contains__.return_value = True - self.segments_mock.__getitem__.return_value = self.some_segment - - segment = self.segment_fetcher.fetch(self.some_name) - - self.assertEqual(self.some_segment, segment) - - def test_if_segment_is_cached_no_new_segments_are_created(self): - """Tests that if a segment is cached no calls to the segment constructor are made""" - self.segments_mock.__contains__.return_value = True - self.segments_mock.__getitem__.return_value = self.some_segment - - self.segment_fetcher.fetch(self.some_name) - - self.self_refreshing_segment_mock.assert_not_called() - - def test_if_segment_is_not_cached_constructor_is_called(self): - """Tests that if a segment is not cached the SelfRefreshingSegment constructor is called""" - self.segments_mock.__contains__.return_value = False - - self.segment_fetcher.fetch(self.some_name) - - self.self_refreshing_segment_mock.assert_called_once_with(self.some_name, - self.segment_change_fetcher_mock, - self.segment_fetcher._executor, - self.some_interval) - - def test_if_segment_is_not_cached_new_segment_inserted_on_cache(self): - """Tests that if a segment is not cached a new segment is inserted into the cache""" - self.segments_mock.__contains__.return_value = False - - self.segment_fetcher.fetch(self.some_name) - - self.segments_mock.__setitem__.assert_called_once_with( - self.some_name, self.self_refreshing_segment_mock.return_value) - - def test_new_segment_is_returned(self): - """Tests that the newly created segment is returned""" - self.segments_mock.__contains__.return_value = False - - segment = self.segment_fetcher.fetch(self.some_name) - - self.assertEqual(self.self_refreshing_segment_mock.return_value, segment) - - -class SelfRefreshingSegmentTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_name = mock.MagicMock() - self.some_segment_change_fetcher = mock.MagicMock() - self.some_executor = mock.MagicMock() - self.some_interval = mock.MagicMock() - - self.rlock_mock = self.patch('splitio.segments.RLock') - - self.segment = SelfRefreshingSegment(self.some_name, self.some_segment_change_fetcher, - self.some_executor, self.some_interval) - self.refresh_segment_mock = self.patch_object(self.segment, 'refresh_segment') - self.timer_refresh_mock = self.patch_object(self.segment, '_timer_refresh') - - def test_greedy_by_default(self): - """Tests that _greedy is set to True by default""" - - self.assertTrue(self.segment._greedy) - - def test_start_calls_timer_refresh_if_not_already_started(self): - """Tests that start calls _timer_refresh if it hasn't already
been started""" - self.segment._stopped = True - - self.segment.start() - - self.timer_refresh_mock.assert_called_once_with() - - def test_start_sets_stopped_to_false(self): - """Tests that start sets stopped to False if it hasn't been started""" - self.segment._stopped = True - - self.segment.start() - - self.assertFalse(self.segment.stopped) - - def test_start_doesnt_call_timer_refresh_if_already_started(self): - """Tests that start doesn't call _timer_refresh if it has already been started""" - self.segment._stopped = False - - self.segment.start() - - self.timer_refresh_mock.assert_not_called() - - -class SelfRefreshingSegmentRefreshSegmentTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_name = mock.MagicMock() - self.some_segment_change_fetcher = mock.MagicMock() - self.some_segment_change_fetcher.fetch.side_effect = [ # Two updates - { - 'name': 'some_name', - 'added': ['user_id_6'], - 'removed': ['user_id_1', 'user_id_2'], - 'since': -1, - 'till': 1 - }, - { - 'name': 'some_name', - 'added': ['user_id_7'], - 'removed': ['user_id_4'], - 'since': 1, - 'till': 2 - }, - { - 'name': 'some_name', - 'added': [], - 'removed': [], - 'since': 2, - 'till': 2 - } - ] - self.some_executor = mock.MagicMock() - self.some_interval = mock.MagicMock() - self.some_key_set = frozenset(['user_id_1', 'user_id_2', 'user_id_3', - 'user_id_4', 'user_id_5']) - - self.segment = SelfRefreshingSegment(self.some_name, self.some_segment_change_fetcher, - self.some_executor, self.some_interval, - key_set=self.some_key_set) - - def test_refreshes_key_set_with_all_changes_on_greedy(self): - """ - Tests that refresh_segment updates the key set properly consuming all changes if greedy is - set - """ - self.segment.refresh_segment() - - self.assertSetEqual( - {'user_id_3', 'user_id_5', 'user_id_6', 'user_id_7'}, - self.segment._key_set - ) - - def test_refreshes_key_set_with_all_changes_on_non_greedy(self): - """ - Tests that refresh_segment updates the key set properly consuming all changes if greedy is - not set - """ - self.segment._greedy = False - - self.segment.refresh_segment() - - self.assertSetEqual( - {'user_id_3', 'user_id_4', 'user_id_5', 'user_id_6'}, - self.segment._key_set - ) - - def test_key_set_is_not_updated_if_no_changes_were_received(self): - """ - Tests that refresh_segment doesn't update key set if no changes are received from the - server - """ - original_key_set = self.segment._key_set - - self.segment._change_number = 2 - self.some_segment_change_fetcher.fetch.side_effect = [ - { - 'name': 'some_name', - 'added': [], - 'removed': [], - 'since': 2, - 'till': 2 - } - ] - - self.segment.refresh_segment() - - self.assertEqual( - original_key_set, - self.segment._key_set - ) - - def test_updates_change_number(self): - """ - Tests that refresh_segment updates the change number with the last "till" value from the - response of the segment change fetcher - """ - self.segment.refresh_segment() - - self.assertEqual( - 2, - self.segment._change_number - ) - - -class SelfRefreshingSegmentTimerRefreshTests(TestCase, MockUtilsMixin): - def setUp(self): - self.timer_mock = self.patch('splitio.segments.Timer') - - self.some_name = mock.MagicMock() - self.some_segment_change_fetcher = mock.MagicMock() - self.some_executor = mock.MagicMock() - self.some_interval = mock.MagicMock() - self.segment = SelfRefreshingSegment(self.some_name, self.some_segment_change_fetcher, - self.some_executor, self.some_interval) - self.segment._stopped = False - - def test_calls_executor_submit_if_not_stopped(self):
- """Tests that if the segment refresh is not stopped, a call to the executor submit method - is made""" - self.segment._timer_refresh() - - self.some_executor.submit.assert_called_once_with(self.segment.refresh_segment) - - def test_new_timer_created_if_not_stopped(self): - """Tests that if the segment refresh is not stopped, a new Timer is created and started""" - self.segment._timer_refresh() - - self.timer_mock.assert_called_once_with(self.segment._interval.return_value, - self.segment._timer_refresh) - self.timer_mock.return_value.start.assert_called_once_with() - - def test_new_timer_created_if_not_stopped_with_random_interval(self): - """Tests that if the segment refresh is not stopped, a new Timer is created and started - calling the interval""" - self.segment._timer_refresh() - - self.timer_mock.assert_called_once_with(self.some_interval.return_value, - self.segment._timer_refresh) - self.timer_mock.return_value.start.assert_called_once_with() - - def test_doesnt_call_executor_submit_if_stopped(self): - """Tests that if the segment refresh is stopped, no call to the executor submit method is - made""" - self.segment._stopped = True - self.segment._timer_refresh() - - self.some_executor.submit.assert_not_called() - - def test_new_timer_not_created_if_stopped(self): - """Tests that if the segment refresh is stopped, no new Timer is created""" - self.segment._stopped = True - self.segment._timer_refresh() - - self.timer_mock.assert_called() - - -class SegmentChangeFetcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_name = mock.MagicMock() - self.some_since = mock.MagicMock() - - self.segment_change_fetcher = SegmentChangeFetcher() - - self.fetch_from_backend_mock = self.patch_object(self.segment_change_fetcher, - 'fetch_from_backend') - - def test_fetch_calls_fetch_from_backend(self): - """Tests that fetch calls fetch_from_backend""" - self.segment_change_fetcher.fetch(self.some_name, self.some_since) - - self.fetch_from_backend_mock.assert_called_once_with(self.some_name, self.some_since) - - def test_fetch_doesnt_raise_exceptions(self): - """ - Tests that if fetch_from_backend raises an exception, no exception is raised from fetch - """ - self.fetch_from_backend_mock.side_effect = Exception() - - try: - self.segment_change_fetcher.fetch(self.some_name, self.some_since) - except: - self.fail('Unexpected exception raised') - - def test_returns_empty_segment_if_backend_raises_an_exception(self): - """ - Tests that if fetch_from_backend raises an exception, an empty segment change is returned - """ - self.fetch_from_backend_mock.side_effect = Exception() - - segment_change = self.segment_change_fetcher.fetch(self.some_name, self.some_since) - - self.assertDictEqual( - { - 'name': self.some_name, - 'added': [], - 'removed': [], - 'since': self.some_since, - 'till': self.some_since - }, - segment_change - ) - - -class ApiSegmentChangeFetcherTests(TestCase): - def setUp(self): - self.some_name = mock.MagicMock() - self.some_since = mock.MagicMock() - self.some_api = mock.MagicMock() - - self.segment_change_fetcher = ApiSegmentChangeFetcher(self.some_api) - - def test_fetch_from_backend_cals_api_segment_changes(self): - """Tests that fetch_from_backend calls segment_changes on the api""" - - self.segment_change_fetcher.fetch_from_backend(self.some_name, self.some_since) - - self.some_api.segment_changes.assert_called_once_with(self.some_name, self.some_since) - - -class CacheBasedSegmentFetcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_segment_name = 
mock.MagicMock() - self.some_segment_cache = mock.MagicMock() - self.segment_fetcher = CacheBasedSegmentFetcher(self.some_segment_cache) - - def test_fetch_creates_cache_based_segment(self): - segment = self.segment_fetcher.fetch(self.some_segment_name) - self.assertIsInstance(segment, CacheBasedSegment) - self.assertEqual(self.some_segment_cache, segment._segment_cache) - self.assertEqual(self.some_segment_name, segment.name) - - -class CacheBasedSegmentTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_key = mock.MagicMock() - self.some_name = mock.MagicMock() - self.some_segment_cache = mock.MagicMock() - self.segment = CacheBasedSegment(self.some_name, self.some_segment_cache) - - def test_contains_calls_segment_cache_is_in_segment(self): - """Test that contains calls segment_cache is_in_segment method""" - self.segment.contains(self.some_key) - self.some_segment_cache.is_in_segment.assert_called_once_with(self.some_name, - self.some_key) - - def test_contains_returns_segment_cache_is_in_segment_results(self): - """Test that contains returns the result of calling segment_cache is_in_segment method""" - self.assertEqual(self.some_segment_cache.is_in_segment.return_value, - self.segment.contains(self.some_key)) diff --git a/splitio/tests/test_splits.py b/splitio/tests/test_splits.py deleted file mode 100644 index 4702fb1c..00000000 --- a/splitio/tests/test_splits.py +++ /dev/null @@ -1,1123 +0,0 @@ -"""Unit tests for the splits module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase - -import json -from splitio.splits import (InMemorySplitFetcher, SelfRefreshingSplitFetcher, SplitChangeFetcher, - ApiSplitChangeFetcher, SplitParser, AllKeysSplit, - CacheBasedSplitFetcher, HashAlgorithm, ConditionType) -from splitio.matchers import (AndCombiner, AllKeysMatcher, UserDefinedSegmentMatcher, - WhitelistMatcher, AttributeMatcher) -from splitio.tests.utils import MockUtilsMixin -from os.path import join, dirname -from splitio.hashfns import _murmur_hash, get_hash_fn -from splitio.hashfns.legacy import legacy_hash -from splitio.redis_support import get_redis, RedisSegmentCache, RedisSplitParser -from splitio.uwsgi import get_uwsgi, UWSGISegmentCache, UWSGISplitParser -from splitio.clients import Client -from splitio.brokers import RedisBroker -from splitio.splitters import Splitter - - -class InMemorySplitFetcherTests(TestCase): - def setUp(self): - self.some_feature = mock.MagicMock() - self.some_splits = mock.MagicMock() - self.some_splits.values.return_value.__iter__.return_value = [mock.MagicMock(), - mock.MagicMock(), - mock.MagicMock()] - - self.fetcher = InMemorySplitFetcher(self.some_splits) - - def test_fetch_calls_get_on_splits(self): - """Test that fetch calls get on splits""" - self.fetcher.fetch(self.some_feature) - - self.some_splits.get.assert_called_once_with(self.some_feature) - - def test_fetch_returns_result_of_get(self): - """Test that fetch returns the result of calling get""" - self.assertEqual(self.some_splits.get.return_value, self.fetcher.fetch(self.some_feature)) - - def test_fetch_all_calls_values_on_splits(self): - """Test that fetch_all calls values on splits""" - self.fetcher.fetch_all() - - self.some_splits.values.assert_called_once_with() - - def test_fetch_all_returns_list_of_splits_values(self): - """Test that fetch_all returns a list of values of splits""" - 
self.assertListEqual(self.some_splits.values.return_value.__iter__.return_value, - self.fetcher.fetch_all()) - - -class SelfRefreshingSplitFetcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self.rlock_mock = self.patch('splitio.splits.RLock') - self.some_split_change_fetcher = mock.MagicMock() - self.some_split_parser = mock.MagicMock() - - self.fetcher = SelfRefreshingSplitFetcher(self.some_split_change_fetcher, - self.some_split_parser) - self.refresh_splits_mock = self.patch_object(self.fetcher, 'refresh_splits') - self.timer_refresh_mock = self.patch_object(self.fetcher, '_timer_refresh') - - def test_start_calls_timer_refresh_if_stopped(self): - """Tests that if stopped is True, start calls _timer_refresh""" - self.fetcher.stopped = True - self.fetcher.start() - - self.timer_refresh_mock.assert_called_once_with() - - def test_start_sets_stopped_False_if_stopped(self): - """Tests that if stopped is True, start sets it to False""" - self.fetcher.stopped = True - self.fetcher.start() - - self.assertFalse(self.fetcher.stopped) - - def test_start_doesnt_call_timer_refresh_if_not_stopped(self): - """Tests that if stopped is false, start doesn't call _timer_refresh""" - self.fetcher.stopped = False - self.fetcher.start() - - self.timer_refresh_mock.assert_not_called() - - -class SelfRefreshingSplitFetcherUpdateSplitsFromChangeFetcherResponseTests(TestCase, - MockUtilsMixin): - def setUp(self): - self.rlock_mock = self.patch('splitio.splits.RLock') - - self.some_feature = mock.MagicMock() - self.some_splits = mock.MagicMock() - self.some_split_change_fetcher = mock.MagicMock() - self.some_split_parser = mock.MagicMock() - - self.fetcher = SelfRefreshingSplitFetcher(self.some_split_change_fetcher, - self.some_split_parser, splits=self.some_splits) - - self.split_to_add = {'status': 'ACTIVE', 'name': 'split_to_add'} - self.split_to_remove = {'status': 'ARCHIVED', 'name': 'split_to_remove'} - self.some_response = { - 'splits': [self.split_to_add, self.split_to_remove] - } - - def test_pop_is_called_on_removed_split(self): - """Tests that pop is called on splits for removed features""" - self.fetcher._update_splits_from_change_fetcher_response(self.some_response) - - self.some_splits.pop.assert_called_once_with(self.split_to_remove['name'], None) - - def test_split_parser_parse_is_called_for_split_to_add(self): - """Tests that parse is called on split_parser with the split to add with the default value - for block_until_ready""" - self.fetcher._update_splits_from_change_fetcher_response(self.some_response) - - self.some_split_parser.parse.assert_called_once_with(self.split_to_add, - block_until_ready=False) - - def test_split_parser_parse_is_called_for_split_to_add_with_block_until_ready(self): - """Tests that parse is called on split_parser with the split to add with the supplied value - for block_until_ready""" - some_block_until_ready = mock.MagicMock() - self.fetcher._update_splits_from_change_fetcher_response( - self.some_response, block_until_ready=some_block_until_ready) - - self.some_split_parser.parse.assert_called_once_with( - self.split_to_add, block_until_ready=some_block_until_ready) - - def test_setitem_is_called_with_parsed_split_to_add(self): - """Tests that __setitem__ is called on splits with the parsed split to add""" - self.fetcher._update_splits_from_change_fetcher_response(self.some_response) - - self.some_splits.__setitem__.assert_called_once_with( - self.split_to_add['name'], self.some_split_parser.parse.return_value) - - def test_pop_is_called_if_split_parse_fails(self): - 
"""Tests that pop is called on splits if parse returns None""" - self.some_split_parser.parse.return_value = None - - self.fetcher._update_splits_from_change_fetcher_response( - {'splits': [self.split_to_add]}) - - self.some_splits.pop.assert_called_once_with(self.split_to_add['name'], None) - - -class SelfRefreshingSplitFetcherRefreshSplitsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.rlock_mock = self.patch('splitio.splits.RLock') - - self.some_feature = mock.MagicMock() - self.some_splits = mock.MagicMock() - self.some_split_change_fetcher = mock.MagicMock() - self.some_split_parser = mock.MagicMock() - - self.response_0 = {'till': 1, 'splits': [mock.MagicMock()]} - self.response_1 = {'till': 2, 'splits': [mock.MagicMock()]} - self.response_2 = {'till': 2, 'splits': []} - - self.some_split_change_fetcher.fetch.side_effect = [ - self.response_0, - self.response_1, - self.response_2 - ] - - self.fetcher = SelfRefreshingSplitFetcher(self.some_split_change_fetcher, - self.some_split_parser, - change_number=-1, - splits=self.some_splits) - self.update_splits_from_change_fetcher_response_mock = self.patch_object( - self.fetcher, '_update_splits_from_change_fetcher_response') - - def test_calls_split_change_fetcher_until_change_number_ge_till_if_greedy(self): - """ - Tests that if greedy is True _refresh_splits calls fetch on split_change_fetcher until - change_number >= till - """ - self.fetcher.refresh_splits() - - self.assertListEqual( - [mock.call(-1), mock.call(1), mock.call(2)], - self.some_split_change_fetcher.fetch.call_args_list) - - def test_calls_split_change_fetcher_once_if_non_greedy(self): - """ - Tests that if greedy is False _refresh_splits calls fetch on split_change_fetcher once - """ - self.fetcher._greedy = False - self.fetcher.refresh_splits() - self.some_split_change_fetcher.fetch.assert_called_once_with(-1) - - def test_calls_update_splits_from_change_fetcher_response_on_each_response_if_greedy(self): - """ - Tests that if greedy is True _refresh_splits calls - _update_splits_from_change_fetcher_response on all responses from split change fetcher - with the default value for block_until_ready - """ - self.fetcher.refresh_splits() - self.assertListEqual( - [mock.call(self.response_0, block_until_ready=False), - mock.call(self.response_1, block_until_ready=False)], - self.update_splits_from_change_fetcher_response_mock.call_args_list) - - def test_calls_update_splits_from_change_fetcher_response_on_each_response_greedy_block(self): - """ - Tests that if greedy is True _refresh_splits calls - _update_splits_from_change_fetcher_response on all responses from split change fetcher - with the supplied value for block_until_ready - """ - some_block_until_ready = mock.MagicMock() - self.fetcher.refresh_splits(block_until_ready=some_block_until_ready) - self.assertListEqual( - [mock.call(self.response_0, block_until_ready=some_block_until_ready), - mock.call(self.response_1, block_until_ready=some_block_until_ready)], - self.update_splits_from_change_fetcher_response_mock.call_args_list) - - def test_calls_update_splits_from_change_fetcher_response_once_if_non_greedy(self): - """ - Tests that if greedy is False _refresh_splits calls - _update_splits_from_change_fetcher_response once with the default value for - block_until_ready - """ - self.fetcher._greedy = False - self.fetcher.refresh_splits() - self.update_splits_from_change_fetcher_response_mock.assert_called_once_with( - self.response_0, block_until_ready=False) - - def 
test_calls_update_splits_from_change_fetcher_response_once_if_non_greedy_blocking(self): - """ - Tests that if greedy is False _refresh_splits calls - _update_splits_from_change_fetcher_response once with the supplied value for - block_until_ready - """ - some_block_until_ready = mock.MagicMock() - self.fetcher._greedy = False - self.fetcher.refresh_splits(block_until_ready=some_block_until_ready) - self.update_splits_from_change_fetcher_response_mock.assert_called_once_with( - self.response_0, block_until_ready=some_block_until_ready) - - def test_sets_change_number_to_latest_value_of_till_response(self): - """ - Tests that _refresh_splits sets change_number to the largest value of "till" in the - response of the split_change_fetcher fetch response. - """ - self.fetcher.refresh_splits() - self.assertEqual(2, self.fetcher.change_number) - - def test_stop_set_true_on_exception(self): - """ - Tests that stopped is set to True if an exception is raised - """ - self.some_split_change_fetcher.fetch.side_effect = [ - self.response_0, - Exception() - ] - self.fetcher.refresh_splits() - self.assertTrue(self.fetcher.stopped) - - def test_change_number_set_value_till_latest_successful_iteration(self): - """ - Tests that change_number is set to the value of "till" in the latest successful iteration - before an exception is raised - """ - self.some_split_change_fetcher.fetch.side_effect = [ - self.response_0, - Exception() - ] - self.fetcher.refresh_splits() - self.assertEqual(1, self.fetcher.change_number) - - -class SelfRefreshingSplitFetcherTimerRefreshTests(TestCase, MockUtilsMixin): - def setUp(self): - self.rlock_mock = self.patch('splitio.splits.RLock') - self.thread_mock = self.patch('splitio.splits.Thread') - self.some_split_change_fetcher = mock.MagicMock() - self.some_split_parser = mock.MagicMock() - self.some_interval = mock.NonCallableMagicMock() - - self.fetcher = SelfRefreshingSplitFetcher(self.some_split_change_fetcher, - self.some_split_parser, - interval=self.some_interval) - self.timer_start_mock = self.patch_object(self.fetcher, '_timer_start') - self.fetcher.stopped = False - - def test_thread_created_and_started_with_refresh_splits(self): - """Tests that _timer_refresh creates and starts a Thread with _refresh_splits target""" - self.fetcher._timer_refresh() - self.thread_mock.assert_called_once_with(target=self.fetcher.refresh_splits) - self.thread_mock.return_value.start.assert_called_once_with() - - def test_calls_timer_start(self): - """Tests that _timer_refresh calls _timer_start""" - self.fetcher._timer_refresh() - self.timer_start_mock.assert_called_once_with() - - def test_no_thread_created_if_stopped(self): - """Tests that _timer_refresh doesn't create a Thread if it is stopped""" - self.fetcher.stopped = True - self.fetcher._timer_refresh() - self.thread_mock.assert_not_called() - - def test_timer_start_not_called_if_stopped(self): - """Tests that _timer_refresh doesn't call _timer_start if it is stopped""" - self.fetcher.stopped = True - self.fetcher._timer_refresh() - self.timer_start_mock.assert_not_called() - - def test_timer_start_called_if_thread_raises_exception(self): - """ - Tests that _timer_refresh calls timer_start even if the _refresh_splits thread - setup raises an exception - """ - self.thread_mock.return_value = None - self.thread_mock.side_effect = Exception() - self.fetcher._timer_refresh() - self.timer_start_mock.assert_called_once_with() - - -class SplitChangeFetcherTests(TestCase, MockUtilsMixin): - def setUp(self): - 
self.some_since = mock.MagicMock() - self.fetcher = SplitChangeFetcher() - self.fetch_from_backend_mock = self.patch_object(self.fetcher, 'fetch_from_backend') - - def test_fetch_calls_fetch_from_backend(self): - """Tests that fetch calls fetch_from_backend""" - self.fetcher.fetch(self.some_since) - self.fetch_from_backend_mock.assert_called_once_with(self.some_since) - - def test_fetch_returns_fetch_from_backend_result(self): - """Tests that fetch returns fetch_from_backend call return""" - self.assertEqual(self.fetch_from_backend_mock.return_value, - self.fetcher.fetch(self.some_since)) - - def test_fetch_doesnt_raise_an_exception_fetch_from_backend_raises_one(self): - """Tests that fetch doesn't raise an exception if fetch_from_backend does""" - self.fetch_from_backend_mock.side_effect = Exception() - - try: - self.fetcher.fetch(self.some_since) - except Exception: - self.fail('Unexpected exception raised') - - def test_fetch_returns_empty_response_if_fetch_from_backend_raises_an_exception(self): - """Tests that fetch returns an empty response if fetch_from_backend raises an exception""" - self.fetch_from_backend_mock.side_effect = Exception() - self.assertDictEqual({'since': self.some_since, 'till': self.some_since, 'splits': []}, - self.fetcher.fetch(self.some_since)) - - -class ApiSplitChangeFetcherTests(TestCase): - def setUp(self): - self.some_since = mock.MagicMock() - self.some_api = mock.MagicMock() - self.fetcher = ApiSplitChangeFetcher(self.some_api) - - def test_fetch_from_backend_calls_api_split_changes(self): - """Tests that fetch_from_backend calls api split_changes""" - self.fetcher.fetch_from_backend(self.some_since) - self.some_api.split_changes.assert_called_once_with(self.some_since) - - def test_fetch_from_backend_returns_api_split_changes_result(self): - """Tests that fetch_from_backend returns api split_changes call result""" - self.assertEqual(self.some_api.split_changes.return_value, - self.fetcher.fetch_from_backend(self.some_since)) - - def test_fetch_from_backend_raises_exception_if_api_split_changes_does(self): - """Tests that fetch_from_backend raises an exception if split_changes does""" - self.some_api.split_changes.side_effect = Exception() - with self.assertRaises(Exception): - self.fetcher.fetch_from_backend(self.some_since) - - -class SplitParserParseTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_split = mock.MagicMock() - self.some_segment_fetcher = mock.MagicMock() - - self.parser = SplitParser(self.some_segment_fetcher) - self.internal_parse_mock = self.patch_object(self.parser, '_parse') - - def test_parse_calls_internal_parse(self): - """Tests that parse calls _parse with block_until_ready as False""" - self.parser.parse(self.some_split) - self.internal_parse_mock.assert_called_once_with(self.some_split, - block_until_ready=False) - - def test_parse_calls_internal_parse_with_block_until_ready(self): - """Tests that parse calls _parse passing the value of block_until_ready""" - self.parser.parse(self.some_split, block_until_ready=True) - self.internal_parse_mock.assert_called_once_with(self.some_split, - block_until_ready=True) - - def test_parse_returns_none_if_internal_parse_raises_an_exception(self): - """Tests that parse returns None if _parse raises an exception""" - self.internal_parse_mock.side_effect = Exception() - self.assertIsNone(self.parser.parse(self.some_split)) - - -class SplitParserInternalParseTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_split = mock.MagicMock() - self.some_segment_fetcher = mock.MagicMock() - 
self.some_block_until_ready = mock.MagicMock() - - self.partition_mock = self.patch('splitio.splits.Partition') - self.partition_mock_side_effect = [mock.MagicMock() for _ in range(3)] - self.partition_mock.side_effect = self.partition_mock_side_effect - - self.condition_mock = self.patch('splitio.splits.Condition') - self.condition_mock_side_effect = [mock.MagicMock() for _ in range(2)] - self.condition_mock.side_effect = self.condition_mock_side_effect - - self.parser = SplitParser(self.some_segment_fetcher) - self.parse_split_mock = self.patch_object(self.parser, '_parse_split') - self.parse_matcher_group_mock = self.patch_object(self.parser, '_parse_matcher_group') - self.parse_matcher_group_mock_side_effect = [mock.MagicMock() for _ in range(2)] - self.parse_matcher_group_mock.side_effect = self.parse_matcher_group_mock_side_effect - - self.partition_0 = {'treatment': mock.MagicMock(), 'size': mock.MagicMock()} - self.partition_1 = {'treatment': mock.MagicMock(), 'size': mock.MagicMock()} - self.partition_2 = {'treatment': mock.MagicMock(), 'size': mock.MagicMock()} - - self.matcher_group_0 = mock.MagicMock() - self.matcher_group_1 = mock.MagicMock() - - self.label_0 = mock.MagicMock() - self.label_1 = mock.MagicMock() - - self.some_split = { - 'status': 'ACTIVE', - 'name': mock.MagicMock(), - 'seed': mock.MagicMock(), - 'killed': mock.MagicMock(), - 'defaultTreatment': mock.MagicMock(), - 'conditions': [ - { - 'matcherGroup': self.matcher_group_0, - 'partitions': [ - - self.partition_0 - ], - 'label': self.label_0 - }, - { - 'matcherGroup': self.matcher_group_1, - 'partitions': [ - self.partition_1, - self.partition_2 - ], - 'label': self.label_1 - } - ] - } - - def test_returns_none_if_status_is_not_active(self): - """Tests that _parse returns None if split is not ACTIVE""" - self.assertIsNone(self.parser._parse({'status': 'ARCHIVED'})) - - def test_creates_partition_on_each_condition_partition(self): - """Test that _parse calls Partition constructor on each partition""" - self.parser._parse(self.some_split) - - self.assertListEqual( - [mock.call(self.partition_0['treatment'], self.partition_0['size']), - mock.call(self.partition_1['treatment'], self.partition_1['size']), - mock.call(self.partition_2['treatment'], self.partition_2['size'])], - self.partition_mock.call_args_list - ) - - def test_calls_parse_matcher_group_on_each_matcher_group(self): - """Tests that _parse calls _parse_matcher_group on each matcher group with the default - value for block_until_ready""" - self.parser._parse(self.some_split) - - self.assertListEqual( - [mock.call(self.parse_split_mock.return_value, self.matcher_group_0, - block_until_ready=False), - mock.call(self.parse_split_mock.return_value, self.matcher_group_1, - block_until_ready=False)], - self.parse_matcher_group_mock.call_args_list - ) - - def test_calls_parse_matcher_group_on_each_matcher_group_with_block_until_ready(self): - """Tests that _parse calls _parse_matcher_group on each matcher group with the passed - value for block_until_ready""" - some_block_until_ready = mock.MagicMock() - self.parser._parse(self.some_split, block_until_ready=some_block_until_ready) - - self.assertListEqual( - [mock.call(self.parse_split_mock.return_value, self.matcher_group_0, - block_until_ready=some_block_until_ready), - mock.call(self.parse_split_mock.return_value, self.matcher_group_1, - block_until_ready=some_block_until_ready)], - self.parse_matcher_group_mock.call_args_list - ) - - def test_creates_condition_on_each_condition(self): - """Tests that 
_parse calls Condition constructor on each condition""" - self.parser._parse(self.some_split) - - self.assertListEqual( - [ - mock.call( - self.parse_matcher_group_mock_side_effect[0], - [self.partition_mock_side_effect[0]], - self.label_0, - ConditionType.WHITELIST - ), - mock.call( - self.parse_matcher_group_mock_side_effect[1], - [ - self.partition_mock_side_effect[1], - self.partition_mock_side_effect[2] - ], - self.label_1, - ConditionType.WHITELIST - ) - ], - self.condition_mock.call_args_list - ) - - def test_calls_parse_split(self): - """Tests that _parse calls _parse_split""" - self.parser._parse(self.some_split, block_until_ready=self.some_block_until_ready) - - self.parse_split_mock.assert_called_once_with( - self.some_split, block_until_ready=self.some_block_until_ready) - - -class SplitParserParseMatcherGroupTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_partial_split = mock.MagicMock() - self.some_segment_fetcher = mock.MagicMock() - - self.combining_matcher_mock = self.patch('splitio.splits.CombiningMatcher') - self.parser = SplitParser(self.some_segment_fetcher) - self.parse_matcher_mock = self.patch_object(self.parser, '_parse_matcher') - self.parse_matcher_side_effect = [mock.MagicMock() for _ in range(2)] - self.parse_matcher_mock.side_effect = self.parse_matcher_side_effect - self.parse_combiner_mock = self.patch_object(self.parser, '_parse_combiner') - - self.some_matchers = [mock.MagicMock(), mock.MagicMock()] - self.some_matcher_group = { - 'matchers': self.some_matchers, - 'combiner': mock.MagicMock() - } - - def test_calls_parse_matcher_on_each_matcher(self): - """Tests that _parse_matcher_group calls _parse_matcher on each matcher with the default - value for block_until_ready""" - self.parser._parse_matcher_group(self.some_partial_split, self.some_matcher_group) - self.assertListEqual([mock.call(self.some_partial_split, self.some_matchers[0], - block_until_ready=False), - mock.call(self.some_partial_split, self.some_matchers[1], - block_until_ready=False)], - self.parse_matcher_mock.call_args_list) - - def test_calls_parse_matcher_with_block_until_ready_parameter(self): - """Tests that _parse_matcher_group calls _parse_matcher on each matcher""" - some_block_until_ready = mock.MagicMock - self.parser._parse_matcher_group(self.some_partial_split, self.some_matcher_group, - block_until_ready=some_block_until_ready) - self.assertListEqual([mock.call(self.some_partial_split, self.some_matchers[0], - block_until_ready=some_block_until_ready), - mock.call(self.some_partial_split, self.some_matchers[1], - block_until_ready=some_block_until_ready)], - self.parse_matcher_mock.call_args_list) - - def test_calls_parse_combiner_on_combiner(self): - """Tests that _parse_matcher_group calls _parse_combiner on combiner""" - self.parser._parse_matcher_group(self.some_partial_split, self.some_matcher_group) - self.parse_combiner_mock.assert_called_once_with(self.some_matcher_group['combiner']) - - def test_creates_combining_matcher(self): - """Tests that _parse_matcher_group calls CombiningMatcher constructor""" - self.parser._parse_matcher_group(self.some_partial_split, self.some_matcher_group) - self.combining_matcher_mock.assert_called_once_with(self.parse_combiner_mock.return_value, - self.parse_matcher_side_effect) - - def test_returns_combining_matcher(self): - """Tests that _parse_matcher_group returns a CombiningMatcher""" - self.assertEqual(self.combining_matcher_mock.return_value, - self.parser._parse_matcher_group(self.some_partial_split, - 
self.some_matcher_group)) - - -class SplitParserParseCombinerTests(TestCase): - def setUp(self): - self.some_segment_fetcher = mock.MagicMock() - - self.parser = SplitParser(self.some_segment_fetcher) - - def test_returns_and_combiner(self): - """Tests that _parse_combiner returns an AndCombiner""" - self.assertIsInstance(self.parser._parse_combiner('AND'), AndCombiner) - - def test_raises_exception_on_invalid_combiner(self): - """Tests that _parse_combiner raises an exception on an invalid combiner""" - with self.assertRaises(ValueError): - self.parser._parse_combiner('foobar') - - -class SplitParserMatcherParseMethodsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_partial_split = mock.MagicMock() - self.some_segment_fetcher = mock.MagicMock() - self.some_matcher = mock.MagicMock() - - self.parser = SplitParser(self.some_segment_fetcher) - - self.get_matcher_data_data_type_mock = self.patch_object(self.parser, - '_get_matcher_data_data_type') - self.equal_to_matcher_mock = self.patch('splitio.splits.EqualToMatcher') - self.greater_than_or_equal_to_matcher_mock = self.patch( - 'splitio.splits.GreaterThanOrEqualToMatcher') - self.less_than_or_equal_to_matcher_mock = self.patch( - 'splitio.splits.LessThanOrEqualToMatcher') - self.between_matcher_mock = self.patch( - 'splitio.splits.BetweenMatcher') - - self.some_in_segment_matcher = { - 'matcherType': 'IN_SEGMENT', - 'userDefinedSegmentMatcherData': { - 'segmentName': mock.MagicMock() - } - } - self.some_whitelist_matcher = { - 'matcherType': 'WHITELIST', - 'whitelistMatcherData': { - 'whitelist': mock.MagicMock() - } - } - self.some_equal_to_matcher = self._get_unary_number_matcher('EQUAL_TO') - self.some_greater_than_or_equal_to_matcher = self._get_unary_number_matcher( - 'GREATER_THAN_OR_EQUAL_TO') - self.some_less_than_or_equal_to_matcher = self._get_unary_number_matcher( - 'LESS_THAN_OR_EQUAL_TO') - self.some_between_matcher = { - 'matcherType': 'BETWEEN', - 'betweenMatcherData': { - 'start': mock.MagicMock(), - 'end': mock.MagicMock(), - } - } - - def _get_unary_number_matcher(self, matcher_type): - return { - 'matcherType': matcher_type, - 'unaryNumericMatcherData': { - 'dataType': mock.MagicMock(), - 'value': mock.MagicMock() - } - } - - def test_parse_matcher_all_keys_returns_all_keys_matcher(self): - """Tests that _parse_matcher_all_keys returns an AllKeysMatcher""" - self.assertIsInstance(self.parser._parse_matcher_all_keys(self.some_partial_split, - self.some_matcher), - AllKeysMatcher) - - def test_parse_matcher_in_segment_calls_segment_fetcher_fetch(self): - """Tests that _parse_matcher_in_segment calls segment_fetcher fetch method with default - value for block_until_ready""" - self.parser._parse_matcher_in_segment(self.some_partial_split, - self.some_in_segment_matcher) - self.some_segment_fetcher.fetch.assert_called_once_with( - self.some_in_segment_matcher['userDefinedSegmentMatcherData']['segmentName'], - block_until_ready=False) - - def test_parse_matcher_in_segment_calls_segment_fetcher_fetch_block(self): - """Tests that _parse_matcher_in_segment calls segment_fetcher fetch method with supplied - value for block_until_ready""" - some_block_until_ready = mock.MagicMock() - self.parser._parse_matcher_in_segment(self.some_partial_split, self.some_in_segment_matcher, - block_until_ready=some_block_until_ready) - self.some_segment_fetcher.fetch.assert_called_once_with( - self.some_in_segment_matcher['userDefinedSegmentMatcherData']['segmentName'], - block_until_ready=some_block_until_ready) - - def
test_parse_matcher_in_segment_returns_user_defined_segment_matcher(self): - """Tests that _parse_matcher_in_segment returns a UserDefinedSegmentMatcher""" - self.assertIsInstance(self.parser._parse_matcher_in_segment(self.some_partial_split, - self.some_in_segment_matcher), - UserDefinedSegmentMatcher) - - def test_parse_matcher_whitelist_returns_whitelist_matcher(self): - """Tests that _parse_matcher_whitelist returns a WhitelistMatcher""" - self.assertIsInstance(self.parser._parse_matcher_whitelist(self.some_partial_split, - self.some_whitelist_matcher), - WhitelistMatcher) - - def test_parse_matcher_equal_to_calls_equal_to_matcher_for_data_type(self): - """Tests that _parse_matcher_equal_to calls EqualToMatcher.for_data_type""" - self.parser._parse_matcher_equal_to(self.some_partial_split, self.some_equal_to_matcher) - self.equal_to_matcher_mock.for_data_type.assert_called_once_with( - self.get_matcher_data_data_type_mock.return_value, - self.some_equal_to_matcher['unaryNumericMatcherData']['value']) - - def test_parse_matcher_equal_to_returns_equal_to_matcher(self): - """ - Tests that _parse_matcher_equal_to returns the result of calling - EqualToMatcher.for_data_type - """ - self.assertEqual(self.equal_to_matcher_mock.for_data_type.return_value, - self.parser._parse_matcher_equal_to(self.some_partial_split, - self.some_equal_to_matcher)) - - def test_parse_matcher_greater_than_or_equal_to_calls_equal_to_matcher_for_data_type(self): - """ - Tests that _parse_matcher_greater_than_or_equal_to calls - GreaterThanOrEqualToMatcher.for_data_type - """ - self.parser \ - ._parse_matcher_greater_than_or_equal_to(self.some_partial_split, - self.some_greater_than_or_equal_to_matcher) - self.greater_than_or_equal_to_matcher_mock.for_data_type.assert_called_once_with( - self.get_matcher_data_data_type_mock.return_value, - self.some_greater_than_or_equal_to_matcher['unaryNumericMatcherData']['value']) - - def test_parse_matcher_greater_than_or_equal_to_returns_equal_to_matcher(self): - """ - Tests that _parse_matcher_greater_than_or_equal_to returns the result of calling - GreaterThanOrEqualToMatcher.for_data_type - """ - self.assertEqual( - self.greater_than_or_equal_to_matcher_mock.for_data_type.return_value, - self.parser._parse_matcher_greater_than_or_equal_to( - self.some_partial_split, - self.some_greater_than_or_equal_to_matcher)) - - def test_parse_matcher_less_than_or_equal_to_calls_equal_to_matcher_for_data_type(self): - """ - Tests that _parse_matcher_less_than_or_equal_to calls - LessThanOrEqualToMatcher.for_data_type - """ - self.parser._parse_matcher_less_than_or_equal_to( - self.some_partial_split, - self.some_less_than_or_equal_to_matcher) - self.less_than_or_equal_to_matcher_mock.for_data_type.assert_called_once_with( - self.get_matcher_data_data_type_mock.return_value, - self.some_less_than_or_equal_to_matcher['unaryNumericMatcherData']['value']) - - def test_parse_matcher_less_than_or_equal_to_returns_equal_to_matcher(self): - """ - Tests that _parse_matcher_less_than_or_equal_to returns the result of calling - LessThanOrEqualToMatcher.for_data_type - """ - self.assertEqual( - self.less_than_or_equal_to_matcher_mock.for_data_type.return_value, - self.parser._parse_matcher_less_than_or_equal_to( - self.some_partial_split, - self.some_less_than_or_equal_to_matcher)) - - def test_parse_matcher_between_calls_between_matcher_for_data_type(self): - """Tests that _parse_matcher_between calls BetweenMatcher.for_data_type""" - self.parser._parse_matcher_between(self.some_partial_split, 
self.some_between_matcher) - self.between_matcher_mock.for_data_type.assert_called_once_with( - self.get_matcher_data_data_type_mock.return_value, - self.some_between_matcher['betweenMatcherData']['start'], - self.some_between_matcher['betweenMatcherData']['end']) - - def test_parse_matcher_between_returns_between_matcher(self): - """ - Tests that _parse_matcher_between returns the result of calling - BetweenMatcher.for_data_type - """ - self.assertEqual(self.between_matcher_mock.for_data_type.return_value, - self.parser._parse_matcher_between(self.some_partial_split, - self.some_between_matcher)) - - -class SplitParserParseMatcherTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_partial_split = mock.MagicMock() - self.some_segment_fetcher = mock.MagicMock() - self.some_matcher = mock.MagicMock() - - self.parser = SplitParser(self.some_segment_fetcher) - - self.parse_matcher_all_keys_mock = self.patch_object(self.parser, - '_parse_matcher_all_keys') - self.parse_matcher_in_segment_mock = self.patch_object(self.parser, - '_parse_matcher_in_segment') - self.parse_matcher_whitelist_mock = self.patch_object(self.parser, - '_parse_matcher_whitelist') - self.parse_matcher_equal_to_mock = self.patch_object(self.parser, - '_parse_matcher_equal_to') - self.parse_matcher_greater_than_or_equal_to_mock = self.patch_object( - self.parser, '_parse_matcher_greater_than_or_equal_to') - self.parse_matcher_less_than_or_equal_to_mock = self.patch_object( - self.parser, '_parse_matcher_less_than_or_equal_to') - self.parse_matcher_between_mock = self.patch_object(self.parser, '_parse_matcher_between') - - self.parser._parse_matcher_fake = mock.MagicMock() - - def _get_matcher(self, matcher_type): - return { - 'matcherType': matcher_type, - 'negate': mock.MagicMock(), - 'keySelector': { - 'attribute': mock.MagicMock() - } - } - - def test_calls_parse_matcher_all_keys(self): - """Test that _parse_matcher calls _parse_matcher_all_keys on ALL_KEYS matcher""" - matcher = self._get_matcher('ALL_KEYS') - self.parser._parse_matcher(self.some_partial_split, matcher) - self.parse_matcher_all_keys_mock.assert_called_once_with(self.some_partial_split, matcher, - block_until_ready=False) - - def test_calls_parse_matcher_in_segment(self): - """Test that _parse_matcher calls _parse_matcher_in_segment on IN_SEGMENT matcher""" - matcher = self._get_matcher('IN_SEGMENT') - self.parser._parse_matcher(self.some_partial_split, matcher) - self.parse_matcher_in_segment_mock.assert_called_once_with(self.some_partial_split, - matcher, block_until_ready=False) - - def test_calls_parse_matcher_whitelist(self): - """Test that _parse_matcher calls _parse_matcher_in_segment on WHITELIST matcher""" - matcher = self._get_matcher('WHITELIST') - self.parser._parse_matcher(self.some_partial_split, matcher) - self.parse_matcher_whitelist_mock.assert_called_once_with(self.some_partial_split, matcher, - block_until_ready=False) - - def test_calls_parse_matcher_equal_to(self): - """Test that _parse_matcher calls _parse_matcher_equal_to on EQUAL_TO matcher""" - matcher = self._get_matcher('EQUAL_TO') - self.parser._parse_matcher(self.some_partial_split, matcher) - self.parse_matcher_equal_to_mock.assert_called_once_with(self.some_partial_split, matcher, - block_until_ready=False) - - def test_calls_parse_matcher_greater_than_or_equal_to(self): - """ - Test that _parse_matcher calls _parse_matcher_greater_than_or_equal_to on - GREATER_THAN_OR_EQUAL_TO matcher - """ - matcher = self._get_matcher('GREATER_THAN_OR_EQUAL_TO') - 
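(For context on the dispatch these tests pin down: the parser derives a method name from the matcher's 'matcherType' field and fails loudly when nothing usable comes back. Below is a minimal sketch of that pattern with illustrative names, not the SDK's exact internals; the real parser additionally wraps the returned delegate in an AttributeMatcher using the 'negate' flag and keySelector attribute, as the surrounding tests assert.)

    class SplitParserSketch(object):
        """Illustrative dispatcher: 'EQUAL_TO' -> _parse_matcher_equal_to, etc."""

        def _parse_matcher(self, partial_split, matcher, block_until_ready=False):
            matcher_type = matcher['matcherType']
            # e.g. 'GREATER_THAN_OR_EQUAL_TO' -> '_parse_matcher_greater_than_or_equal_to'
            parse_fn = getattr(self, '_parse_matcher_' + matcher_type.lower(), None)
            if parse_fn is None:  # assumption: unknown types are rejected the same way
                raise ValueError('Invalid matcher type: %s' % matcher_type)
            delegate = parse_fn(partial_split, matcher, block_until_ready=block_until_ready)
            if delegate is None:
                raise ValueError('Unable to build matcher of type: %s' % matcher_type)
            return delegate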
self.parser._parse_matcher(self.some_partial_split, matcher) - self.parse_matcher_greater_than_or_equal_to_mock.assert_called_once_with( - self.some_partial_split, matcher, block_until_ready=False) - - def test_calls_parse_matcher_less_than_or_equal_to(self): - """ - Test that _parse_matcher calls _parse_matcher_less_than_or_equal_to on - LESS_THAN_OR_EQUAL_TO matcher - """ - matcher = self._get_matcher('LESS_THAN_OR_EQUAL_TO') - self.parser._parse_matcher(self.some_partial_split, matcher) - self.parse_matcher_less_than_or_equal_to_mock.assert_called_once_with( - self.some_partial_split, matcher, block_until_ready=False) - - def test_calls_parse_matcher_between(self): - """Test that _parse_matcher calls _parse_between on BETWEEN matcher""" - matcher = self._get_matcher('BETWEEN') - self.parser._parse_matcher(self.some_partial_split, matcher) - self.parse_matcher_between_mock.assert_called_once_with(self.some_partial_split, matcher, - block_until_ready=False) - - def test_raises_exception_if_parse_method_returns_none(self): - """ - Tests that _parse_matcher raises an exception if the specific parse method returns None - """ - self.parser._parse_matcher_fake.return_value = None - with self.assertRaises(ValueError): - self.parser._parse_matcher(self.some_partial_split, self._get_matcher('FAKE')) - - def test_returns_attribute_matcher(self): - """Tests that _parse_matcher returns an AttributeMatcher""" - self.assertIsInstance(self.parser._parse_matcher(self.some_partial_split, - self._get_matcher('FAKE')), - AttributeMatcher) - - -class AllKeysSplitTests(TestCase): - def setUp(self): - self.some_name = mock.MagicMock() - self.some_treatment = mock.MagicMock() - self.split = AllKeysSplit(self.some_name, self.some_treatment) - - def test_single_condition(self): - """Tests that it as a single condition""" - self.assertEqual(1, len(self.split.conditions)) - - def test_condition_as_attribute_matcher(self): - """Tests that the condition is an attribute matcher""" - self.assertIsInstance(self.split.conditions[0].matcher, AttributeMatcher) - - def test_condition_has_single_partition(self): - """Tests that the condition has a single partition""" - self.assertEqual(1, len(self.split.conditions[0].partitions)) - - def test_partition_is_100_percent(self): - """Tests that the partition has a size 100""" - self.assertEqual(100, self.split.conditions[0].partitions[0].size) - - def test_partition_has_treatment(self): - """Tests that the partition has the set treatment""" - self.assertEqual(self.some_treatment, self.split.conditions[0].partitions[0].treatment) - - -class CacheBasedSplitFetcherTests(TestCase): - def setUp(self): - self.some_feature = mock.MagicMock() - self.some_split_cache = mock.MagicMock() - self.split_fetcher = CacheBasedSplitFetcher(split_cache=self.some_split_cache) - - def test_fetch_calls_get_split(self): - """Test that fetch calls get_split on the split cache""" - self.split_fetcher.fetch(self.some_feature) - self.some_split_cache.get_split.assert_called_once_with(self.some_feature) - - def test_fetch_results_get_split_result(self): - """Test that fetch returns the result of calling get split on the cache""" - self.assertEqual(self.some_split_cache.get_split.return_value, - self.split_fetcher.fetch(self.some_feature)) - - -class RedisCacheAlgoFieldTests(TestCase): - def setUp(self): - ''' - ''' - fn = join(dirname(__file__), 'algoSplits.json') - with open(fn, 'r') as flo: - rawData = json.load(flo)['splits'] - self._testData = [ - { - 'body': rawData[0], - 'algo': HashAlgorithm.LEGACY, 
- 'hashfn': legacy_hash - }, - { - 'body': rawData[1], - 'algo': HashAlgorithm.MURMUR, - 'hashfn': _murmur_hash - }, - { - 'body': rawData[2], - 'algo': HashAlgorithm.LEGACY, - 'hashfn': legacy_hash - }, - { - 'body': rawData[3], - 'algo': HashAlgorithm.LEGACY, - 'hashfn': legacy_hash - }, - { - 'body': rawData[4], - 'algo': HashAlgorithm.LEGACY, - 'hashfn': legacy_hash - } - ] - - def testAlgoHandlers(self): - ''' - ''' - redis = get_redis({}) - segment_cache = RedisSegmentCache(redis) - split_parser = RedisSplitParser(segment_cache) - for sp in self._testData: - split = split_parser.parse(sp['body'], True) - self.assertEqual(split.algo, sp['algo']) - self.assertEqual(get_hash_fn(split.algo), sp['hashfn']) - - -class UWSGICacheAlgoFieldTests(TestCase): - def setUp(self): - ''' - ''' - fn = join(dirname(__file__), 'algoSplits.json') - with open(fn, 'r') as flo: - rawData = json.load(flo)['splits'] - self._testData = [ - { - 'body': rawData[0], - 'algo': HashAlgorithm.LEGACY, - 'hashfn': legacy_hash - }, - { - 'body': rawData[1], - 'algo': HashAlgorithm.MURMUR, - 'hashfn': _murmur_hash - }, - { - 'body': rawData[2], - 'algo': HashAlgorithm.LEGACY, - 'hashfn': legacy_hash - }, - { - 'body': rawData[3], - 'algo': HashAlgorithm.LEGACY, - 'hashfn': legacy_hash - }, - { - 'body': rawData[4], - 'algo': HashAlgorithm.LEGACY, - 'hashfn': legacy_hash - } - ] - - def testAlgoHandlers(self): - ''' - ''' - uwsgi = get_uwsgi(True) - segment_cache = UWSGISegmentCache(uwsgi) - split_parser = UWSGISplitParser(segment_cache) - for sp in self._testData: - split = split_parser.parse(sp['body'], True) - self.assertEqual(split.algo, sp['algo']) - self.assertEqual(get_hash_fn(split.algo), sp['hashfn']) - - -class TrafficAllocationTests(TestCase, MockUtilsMixin): - ''' - ''' - - def setUp(self): - ''' - ''' - self.some_config = mock.MagicMock() - redis = get_redis({}) - segment_cache = RedisSegmentCache(redis) - split_parser = RedisSplitParser(segment_cache) - self._client = Client(RedisBroker(redis, self.some_config)) - - self._splitObjects = {} - - raw_split = { - 'name': 'test1', - 'algo': 1, - 'killed': False, - 'status': 'ACTIVE', - 'defaultTreatment': 'default', - 'seed': -1222652054, - 'orgId': None, - 'environment': None, - 'trafficTypeId': None, - 'trafficTypeName': None, - 'changeNumber': 1, - 'conditions': [{ - 'conditionType': 'WHITELIST', - 'matcherGroup': { - 'combiner': 'AND', - 'matchers': [{ - 'matcherType': 'ALL_KEYS', - 'negate': False, - 'userDefinedSegmentMatcherData': None, - 'whitelistMatcherData': None - }] - }, - 'partitions': [{ - 'treatment': 'on', - 'size': 100 - }], - 'label': 'in segment all' - }] - } - self._splitObjects['whitelist'] = split_parser.parse(raw_split, True) - - raw_split['name'] = 'test2' - raw_split['conditions'][0]['conditionType'] = 'ROLLOUT' - self._splitObjects['rollout1'] = split_parser.parse(raw_split, True) - - raw_split['name'] = 'test3' - raw_split['trafficAllocation'] = 1 - raw_split['trafficAllocationSeed'] = -1 - self._splitObjects['rollout2'] = split_parser.parse(raw_split, True) - - raw_split['name'] = 'test4' - raw_split['trafficAllocation'] = None # must be mapped as 100 - raw_split['trafficAllocationSeed'] = -1 - self._splitObjects['rollout3'] = split_parser.parse(raw_split, True) - - raw_split['name'] = 'test5' - raw_split['trafficAllocation'] = 99 - raw_split['trafficAllocationSeed'] = -1 - self._splitObjects['rollout4'] = split_parser.parse(raw_split, True) - - def testTrafficAllocation(self): - ''' - ''' - treatment1, label1 = 
self._client._evaluator.get_treatment_for_split( - self._splitObjects['whitelist'], 'testKey', None - ) - self.assertEqual(treatment1, 'on') - - # Make sure traffic allocation is set to 100 at construction time if a - # value is not provided. - self.assertEqual( - self._splitObjects['whitelist'].traffic_allocation, - 100 - ) - - treatment2, label1 = self._client._evaluator.get_treatment_for_split( - self._splitObjects['rollout1'], 'testKey', None - ) - self.assertEqual(treatment2, 'on') - - treatment3, label1 = self._client._evaluator.get_treatment_for_split( - self._splitObjects['rollout3'], 'testKey', None - ) - self.assertEqual(treatment3, 'on') - - self.patch_object(Splitter, 'get_bucket', return_value=1) - treatment4, label1 = self._client._evaluator.get_treatment_for_split( - self._splitObjects['rollout2'], 'testKey', None - ) - self.assertEqual(treatment4, 'on') - - self.patch_object(Splitter, 'get_bucket', return_value=100) - treatment5, label1 = self._client._evaluator.get_treatment_for_split( - self._splitObjects['rollout4'], 'testKey', None - ) - self.assertEqual(treatment5, 'default') diff --git a/splitio/tests/test_splitters.py b/splitio/tests/test_splitters.py deleted file mode 100644 index 06426f7d..00000000 --- a/splitio/tests/test_splitters.py +++ /dev/null @@ -1,224 +0,0 @@ -"""Unit tests for the matchers module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from collections import Counter -from math import sqrt -from json import loads -from os.path import join, dirname -from random import randint -from unittest import TestCase, skip - -from splitio.splits import Partition, HashAlgorithm -from splitio.splitters import Splitter -from splitio.treatments import CONTROL -from splitio.hashfns import _HASH_ALGORITHMS -from splitio.tests.utils import MockUtilsMixin, random_alphanumeric_string -import io - -class SplitterGetTreatmentTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_key = mock.MagicMock() - self.some_seed = mock.MagicMock() - self.some_partitions = [mock.MagicMock(), mock.MagicMock()] - self.splitter = Splitter() -# self.hash_key_mock = self.patch_object(self.splitter, 'hash_key') - self.get_bucket_mock = self.patch_object(self.splitter, 'get_bucket') - self.get_treatment_for_bucket_mock = self.patch_object(self.splitter, - 'get_treatment_for_bucket') - - def test_get_treatment_returns_control_if_partitions_is_none(self): - """Test that get_treatment returns the control treatment if partitions is None""" - self.assertEqual(CONTROL, self.splitter.get_treatment(self.some_key, self.some_seed, None, HashAlgorithm.LEGACY)) - - def test_get_treatment_returns_control_if_partitions_is_empty(self): - """Test that get_treatment returns the control treatment if partitions is empty""" - self.assertEqual(CONTROL, self.splitter.get_treatment(self.some_key, self.some_seed, [], HashAlgorithm.LEGACY)) - - def test_get_treatment_returns_only_partition_treatment_if_it_is_100(self): - """Test that get_treatment returns the only partition treatment if it is 100%""" - some_partition = mock.MagicMock() - some_partition.size = 100 - self.assertEqual(some_partition.treatment, self.splitter.get_treatment(self.some_key, - self.some_seed, - [some_partition], - HashAlgorithm.LEGACY)) - - def test_get_treatment_calls_get_treatment_for_bucket_if_more_than_1_partition(self): - """ - Test that get_treatment calls get_treatment_for_bucket if there is more 
than one -        partition -        """ -        self.splitter.get_treatment(self.some_key, self.some_seed, self.some_partitions, HashAlgorithm.LEGACY) -        self.get_treatment_for_bucket_mock.assert_called_once_with( -            self.get_bucket_mock.return_value, self.some_partitions) - -    def test_get_treatment_returns_get_treatment_for_bucket_result_if_more_than_1_partition(self): -        """ -        Test that get_treatment returns the result of calling get_treatment_for_bucket if there -        is more than one partition -        """ -        self.assertEqual( -            self.get_treatment_for_bucket_mock.return_value, self.splitter.get_treatment( -                self.some_key, self.some_seed, self.some_partitions, HashAlgorithm.LEGACY)) - -    def test_get_treatment_calls_hash_key_if_more_than_1_partition(self): -        """ -        Test that get_treatment calls hash_key if there is more than one partition -        """ -        self.splitter.get_treatment(self.some_key, self.some_seed, self.some_partitions, HashAlgorithm.LEGACY) -#        self.hash_key_mock.assert_called_once_with(self.some_key, self.some_seed) - -    def test_get_treatment_calls_get_bucket_if_more_than_1_partition(self): -        """ -        Test that get_treatment calls get_bucket if there is more than one partition -        """ -        self.splitter.get_treatment(self.some_key, self.some_seed, self.some_partitions, HashAlgorithm.LEGACY) -#        self.get_bucket_mock.assert_called_once_with(self.hash_key_mock.return_value) - - -class SplitterGetTreatmentForBucket(TestCase): -    def setUp(self): -        self.some_partitions = [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()] -        self.some_partitions[0].size = 10 -        self.some_partitions[1].size = 20 -        self.some_partitions[2].size = 30 -        self.splitter = Splitter() - -    def test_returns_control_if_bucket_is_not_covered(self): -        """ -        Tests that get_treatment_for_bucket returns CONTROL if the bucket is beyond the range -        covered by the partitions -        """ -        self.assertEqual(CONTROL, self.splitter.get_treatment_for_bucket(100, self.some_partitions)) - -    def test_returns_treatment_of_partition_that_has_bucket(self): -        """ -        Tests that get_treatment_for_bucket returns the treatment of the partition that covers the -        bucket. 
- """ - self.assertEqual(self.some_partitions[0].treatment, - self.splitter.get_treatment_for_bucket(1, self.some_partitions)) - self.assertEqual(self.some_partitions[1].treatment, - self.splitter.get_treatment_for_bucket(15, self.some_partitions)) - self.assertEqual(self.some_partitions[2].treatment, - self.splitter.get_treatment_for_bucket(33, self.some_partitions)) - - -class SplitterHashKeyTests(TestCase): - def setUp(self): - self.splitter = Splitter() - - def test_with_sample_data(self): - """ - Tests basic hash against expected values using alphanumeric values - """ - hashfn = _HASH_ALGORITHMS[HashAlgorithm.LEGACY] - with open(join(dirname(__file__), 'sample-data.jsonl')) as f: - for line in map(loads, f): - seed, key, hash_, bucket = line - self.assertEqual(int(hash_), hashfn(key, int(seed))) - @skip - def test_with_non_alpha_numeric_sample_data(self): - """ - Tests basic hash against expected values using non alphanumeric values - """ - hashfn = _HASH_ALGORITHMS[HashAlgorithm.LEGACY] - with io.open(join(dirname(__file__), 'sample-data-non-alpha-numeric.jsonl'), 'r', encoding='utf-8') as f: - for line in map(loads, f): - seed, key, hash_, bucket = line - self.assertEqual(int(hash_), hashfn(key, int(seed))) - - def test_murmur_with_sample_data(self): - """ - Tests murmur32 hash against expected values using alphanumeric values - """ - hashfn = _HASH_ALGORITHMS[HashAlgorithm.MURMUR] - with open(join(dirname(__file__), 'murmur3-sample-data-v2.csv')) as f: - for line in f: - seed, key, hash_, bucket = line.split(',') - self.assertEqual(int(hash_), hashfn(key, int(seed))) - - def test_murmur_with_non_alpha_numeric_sample_data(self): - """ - Tests murmur32 hash against expected values using non alphanumeric values - """ - hashfn = _HASH_ALGORITHMS[HashAlgorithm.MURMUR] - with io.open(join(dirname(__file__), 'murmur3-sample-data-non-alpha-numeric-v2.csv'), 'r', encoding='utf-8') as f: - for line in f: - seed, key, hash_, bucket = line.split(',') - self.assertEqual(int(hash_), hashfn(key, int(seed))) - - def test_murmur_with_custom_uuids(self): - """ - Tests murmur32 hash against expected values using non alphanumeric values - """ - hashfn = _HASH_ALGORITHMS[HashAlgorithm.MURMUR] - with io.open(join(dirname(__file__), 'murmur3-custom-uuids.csv'), 'r', encoding='utf-8') as f: - for line in f: - seed, key, hash_, bucket = line.split(',') - self.assertEqual(int(hash_), hashfn(key, int(seed))) - - -class SplitterGetBucketUnitTests(TestCase): - def setUp(self): - self.splitter = Splitter() - - def test_with_sample_data(self): - """ - Tests hash_key against expected values using alphanumeric values - """ - with open(join(dirname(__file__), 'sample-data.jsonl')) as f: - for line in map(loads, f): - seed, key, hash_, bucket = line - self.assertEqual( - int(bucket), - self.splitter.get_bucket(key, seed, HashAlgorithm.LEGACY) - ) - - # This test is being skipped because apparently LEGACY hash for - # non-alphanumeric keys isn't working properly. - # TODO: Discuss with @sarrubia whether we should raise ticket for this. 
- @skip - def test_with_non_alpha_numeric_sample_data(self): - """ - Tests hash_key against expected values using non alphanumeric values - """ - with open(join(dirname(__file__), 'sample-data-non-alpha-numeric.jsonl')) as f: - for line in map(loads, f): - seed, key, hash_, bucket = line - self.assertEqual( - int(bucket), - self.splitter.get_bucket(key, seed, HashAlgorithm.LEGACY) - ) - - -@skip -class SplitterGetTreatmentDistributionTests(TestCase): - def setUp(self): - self.splitter = Splitter() - - def test_1_percent_treatments_evenly_distributed(self): - """Test that get_treatment distributes treatments according to partitions""" - seed = randint(-2147483649, 2147483648) - partitions = [Partition(mock.MagicMock(), 1) for _ in range(100)] - n = 100000 - p = 0.01 - - treatments = [self.splitter.get_treatment(random_alphanumeric_string(randint(16, 32)), - seed, partitions, HashAlgorithm.LEGACY) for _ in range(n)] - counter = Counter(treatments) - - mean = n * p - stddev = sqrt(mean * (1 - p)) - - count_min = int(mean - 4 * stddev) - count_max = int(mean + 4 * stddev) - - for count in counter.values(): - self.assertTrue(count_min <= count <= count_max) diff --git a/splitio/tests/test_tasks.py b/splitio/tests/test_tasks.py deleted file mode 100644 index 1d662695..00000000 --- a/splitio/tests/test_tasks.py +++ /dev/null @@ -1,370 +0,0 @@ -"""Unit tests for the tasks module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from unittest import TestCase - -from splitio.tests.utils import MockUtilsMixin - -from splitio.tasks import (report_impressions, report_metrics, update_segments, update_segment,\ - update_splits) - - -class ReportImpressionsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.build_impressions_data_mock = self.patch( - 'splitio.tasks.build_impressions_data') - self.some_impressions_cache = mock.MagicMock() - self.some_impressions_cache.is_enabled.return_value = True - self.some_api_sdk = mock.MagicMock() - - def test_calls_is_enabled(self): - """Test that report_impressions call is_enabled""" - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.some_impressions_cache.is_enabled.assert_called_once_with() - - def test_doesnt_call_fetch_all_and_clear_if_disabled(self): - """Test that report_impressions doesn't call fetch_all_and_clear if the cache is disabled""" - self.some_impressions_cache.is_enabled.return_value = False - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.some_impressions_cache.fetch_all_and_clear.assert_not_called() - - def test_doesnt_call_test_impressions_if_disabled(self): - """Test that report_impressions doesn't call test_impressions if the cache is disabled""" - self.some_impressions_cache.is_enabled.return_value = False - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.some_api_sdk.test_impressions.assert_not_called() - - def test_calls_fetch_all_and_clear(self): - """Test that report_impressions calls fetch_all_and_clear""" - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.some_impressions_cache.fetch_all_and_clear.assert_called_once_with() - - def test_calls_build_impressions_data(self): - """Test that report_impressions calls build_impressions_data_mock""" - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.build_impressions_data_mock.assert_called_once_with( - 
self.some_impressions_cache.fetch_all_and_clear.return_value) - - def test_doesnt_call_test_impressions_if_data_is_empty(self): - """Test that report_impressions doesn't call test_impressions if build_impressions_data - returns an empty list""" - self.build_impressions_data_mock.return_value = [] - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.some_api_sdk.test_impressions.assert_not_called() - - def test_calls_test_impressions(self): - """Test that report_impressions calls test_impression with the result of - build_impressions_data""" - self.build_impressions_data_mock.return_value = [mock.MagicMock()] - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.some_api_sdk.test_impressions.assert_called_once_with( - self.build_impressions_data_mock.return_value) - - def test_cache_disabled_if_exception_is_raised(self): - """Test that report_impressions disables the cache if an exception is raised""" - self.build_impressions_data_mock.side_effect = Exception() - report_impressions(self.some_impressions_cache, self.some_api_sdk) - self.some_impressions_cache.disable.assert_called_once_with() - - -class ReportMetricsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_metrics_cache = mock.MagicMock() - self.some_metrics_cache.is_enabled.return_value = True - self.some_api_sdk = mock.MagicMock() - - def test_doenst_call_fetch_all_and_clear_if_disabled(self): - """Test that report_metrics doesn't call fetch_all_and_clear if the cache is disabled""" - self.some_metrics_cache.is_enabled.return_value = False - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_metrics_cache.fetch_all_and_clear.assert_not_called() - - def test_doesnt_call_metrics_times_if_disabled(self): - """Test that report_metrics doesn't call metrics_times if the cache is disabled""" - self.some_metrics_cache.is_enabled.return_value = False - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_times.assert_not_called() - - def test_doesnt_call_metrics_counters_if_disabled(self): - """Test that report_metrics doesn't call metrics_counters if the cache is disabled""" - self.some_metrics_cache.is_enabled.return_value = False - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_counters.assert_not_called() - - def test_doesnt_call_metrics_gauge_if_disabled(self): - """Test that report_metrics doesn't call metrics_gauge if the cache is disabled""" - self.some_metrics_cache.is_enabled.return_value = False - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_gauge.assert_not_called() - - def test_doesnt_call_metrics_times_if_time_metrics_is_empty(self): - """Test that report_metrics doesn't call metrics_times if time metrics are empty""" - self.some_metrics_cache.fetch_all_and_clear.return_value = {'time': [], - 'count': mock.MagicMock(), - 'gauge': mock.MagicMock()} - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_times.assert_not_called() - - def test_calls_metrics_times(self): - """Test that report_metrics calls metrics_times if time metrics are not empty""" - self.some_metrics_cache.fetch_all_and_clear.return_value = {'count': mock.MagicMock(), - 'gauge': mock.MagicMock()} - - self.some_metrics_cache.fetch_all_times_and_clear.return_value = [mock.MagicMock()] - - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_times.assert_called_once_with( - 
self.some_metrics_cache.fetch_all_times_and_clear.return_value) - - def test_doesnt_call_metrics_counters_if_counter_metrics_is_empty(self): - """Test that report_metrics doesn't call metrics_counters if counter metrics are empty""" - self.some_metrics_cache.fetch_all_and_clear.return_value = {'time': mock.MagicMock(), - 'count': [], - 'gauge': mock.MagicMock()} - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_counters.assert_not_called() - - def test_calls_metrics_counters(self): - """Test that report_metrics calls metrics_counters if counter metrics are not empty""" - self.some_metrics_cache.fetch_all_and_clear.return_value = {'time': mock.MagicMock(), - 'count': [mock.MagicMock()], - 'gauge': mock.MagicMock()} - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_counters.assert_called_once_with( - self.some_metrics_cache.fetch_all_and_clear.return_value['count']) - - def test_doesnt_call_metrics_gauge_if_gauge_metrics_is_empty(self): - """Test that report_metrics doesn't call metrics_gauge if counter metrics are empty""" - self.some_metrics_cache.fetch_all_and_clear.return_value = {'time': mock.MagicMock(), - 'count': mock.MagicMock(), - 'gauge': []} - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_gauge.assert_not_called() - - def test_calls_metrics_gauge(self): - """Test that report_metrics calls metrics_gauge if gauge metrics are not empty""" - self.some_metrics_cache.fetch_all_and_clear.return_value = {'time': mock.MagicMock(), - 'count': mock.MagicMock(), - 'gauge': [mock.MagicMock()]} - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_api_sdk.metrics_gauge.assert_called_once_with( - self.some_metrics_cache.fetch_all_and_clear.return_value['gauge']) - - def test_disables_cache_if_exception_is_raised(self): - """Test that report_metrics disables cache if exception is raised""" - self.some_metrics_cache.fetch_all_and_clear.side_effect = Exception() - report_metrics(self.some_metrics_cache, self.some_api_sdk) - self.some_metrics_cache.disable.assert_called_once_with() - - -class UpdateSegmentsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_segments_cache = mock.MagicMock() - self.some_segments_cache.is_enabled.return_value = True - self.some_segment_change_fetcher = mock.MagicMock() - self.update_segment_mock = self.patch('splitio.tasks.update_segment') - - def test_doesnt_call_update_segment_if_cache_is_disabled(self): - """Test that update_segments doesn't call update_segment if the segment cache is disabled""" - self.some_segments_cache.is_enabled.return_value = False - update_segments(self.some_segments_cache, self.some_segment_change_fetcher) - self.update_segment_mock.assert_not_called() - - def test_doesnt_call_update_segment_if_there_are_no_registered_segments(self): - """Test that update_segments doesn't call update_segment there are no registered segments""" - self.some_segments_cache.get_registered_segments.return_value = [] - update_segments(self.some_segments_cache, self.some_segment_change_fetcher) - self.update_segment_mock.assert_not_called() - - def test_calls_update_segment_for_each_registered_segment(self): - """Test that update_segments calls update_segment on each registered segment""" - some_segment_name = mock.MagicMock() - some_other_segment_name = mock.MagicMock() - self.some_segments_cache.get_registered_segments.return_value = [some_segment_name, - some_other_segment_name] - 
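(The UpdateSegmentTests fixtures just below encode the segment sync protocol: fetch with the last known change number, apply removals and additions, record the new change number, and stop once the service echoes back the number that was sent. A sketch under those fixtures' assumptions, with the cache/fetcher method names taken from the tests:)

    def update_segment_sketch(segment_cache, name, change_fetcher):
        """Poll segment changes until since == till, applying each delta."""
        since = segment_cache.get_change_number(name)
        while True:
            change = change_fetcher.fetch(name, since)
            if change['till'] == since:  # up to date: server echoed our change number
                break
            segment_cache.remove_keys_from_segment(name, change['removed'])
            segment_cache.add_keys_to_segment(name, change['added'])
            segment_cache.set_change_number(name, change['till'])
            since = change['till']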
update_segments(self.some_segments_cache, self.some_segment_change_fetcher) - self.assertListEqual([mock.call(self.some_segments_cache, some_segment_name, - self.some_segment_change_fetcher), - mock.call(self.some_segments_cache, some_other_segment_name, - self.some_segment_change_fetcher)], - self.update_segment_mock.call_args_list) - - def test_disables_cache_if_update_segment_raises_exception(self): - """Test that update_segments disables segment_cache if update segment raises an exception""" - self.update_segment_mock.side_effect = Exception() - self.some_segments_cache.get_registered_segments.return_value = [mock.MagicMock()] - update_segments(self.some_segments_cache, self.some_segment_change_fetcher) - self.some_segments_cache.disable.assert_called_once_with() - - -class UpdateSegmentTests(TestCase): - def setUp(self): - self.some_segment_cache = mock.MagicMock() - self.some_segment_cache.get_change_number.return_value = -1 - self.some_segment_name = mock.MagicMock() - self.some_segment_change_fetcher = mock.MagicMock() - self.some_segment_change_fetcher.fetch.side_effect = [ # Two updates - { - 'name': 'some_segment_name', - 'added': ['user_id_6'], - 'removed': ['user_id_1', 'user_id_2'], - 'since': -1, - 'till': 1 - }, - { - 'name': 'some_segment_name', - 'added': ['user_id_7'], - 'removed': ['user_id_4'], - 'since': 1, - 'till': 2 - }, - { - 'name': 'some_segment_name', - 'added': [], - 'removed': [], - 'since': 2, - 'till': 2 - } - ] - - def test_calls_get_change_number(self): - """Test update_segment calls get_change_number on the segment cache""" - update_segment(self.some_segment_cache, self.some_segment_name, - self.some_segment_change_fetcher) - self.some_segment_cache.get_change_number.assert_called_once_with(self.some_segment_name) - - def test_calls_segment_change_fetcher_fetch(self): - """Test that update_segment calls segment_change_fetcher's fetch until change numbers - match""" - update_segment(self.some_segment_cache, self.some_segment_name, - self.some_segment_change_fetcher) - self.assertListEqual([mock.call(self.some_segment_name, -1), - mock.call(self.some_segment_name, 1), - mock.call(self.some_segment_name, 2)], - self.some_segment_change_fetcher.fetch.call_args_list) - - def test_calls_remove_keys_from_segment_for_all_removed_keys(self): - """Test update_segment calls remove_keys_from_segment for keys removed on each update""" - update_segment(self.some_segment_cache, self.some_segment_name, - self.some_segment_change_fetcher) - self.assertListEqual([mock.call(self.some_segment_name, ['user_id_1', 'user_id_2']), - mock.call(self.some_segment_name, ['user_id_4'])], - self.some_segment_cache.remove_keys_from_segment.call_args_list) - - def test_calls_add_keys_to_segment_for_all_added_keys(self): - """Test update_segment calls add_keys_to_segment for keys added on each update""" - update_segment(self.some_segment_cache, self.some_segment_name, - self.some_segment_change_fetcher) - self.assertListEqual([mock.call(self.some_segment_name, 1), - mock.call(self.some_segment_name, 2)], - self.some_segment_cache.set_change_number.call_args_list) - - def test_calls_set_change_number_for_updates(self): - """Test update_segment calls set_change_number on each update""" - update_segment(self.some_segment_cache, self.some_segment_name, - self.some_segment_change_fetcher) - self.assertListEqual([mock.call(self.some_segment_name, ['user_id_6']), - mock.call(self.some_segment_name, ['user_id_7'])], - self.some_segment_cache.add_keys_to_segment.call_args_list) - - -class 
UpdateSplitsTests(TestCase, MockUtilsMixin): - def setUp(self): - self.some_split_cache = mock.MagicMock() - self.some_split_cache.get_change_number.return_value = -1 - self.some_split_parser = mock.MagicMock() - self.parse_side_effect = ['PEPE',mock.MagicMock(), mock.MagicMock(), mock.MagicMock()] - self.some_split_parser.parse.side_effect = self.parse_side_effect - self.some_split_change_fetcher = mock.MagicMock() - self.some_split_change_fetcher.fetch.side_effect = [ - { - 'till': 1, - 'splits': [ - { - 'status': 'ACTIVE', - 'name': 'some_split' - }, - { - 'status': 'ACTIVE', - 'name': 'some_other_split' - } - ] - }, - { - 'till': 2, - 'splits': [ - { - 'status': 'ACTIVE', - 'name': 'some_split' - }, - { - 'status': 'ARCHIVED', - 'name': 'some_other_split' - } - ] - }, - { - 'till': 2, - 'splits': [] - } - ] - - def test_calls_get_change_number(self): - """Test that update_splits calls get_change_number on the split cache""" - update_splits(self.some_split_cache, self.some_split_change_fetcher, self.some_split_parser) - self.some_split_cache.get_change_number.assert_called_once_with() - - def test_calls_split_change_fetcher_fetch(self): - """Test that update_splits calls split_change_fetcher's fetch method until change numbers - match""" - update_splits(self.some_split_cache, self.some_split_change_fetcher, self.some_split_parser) - self.assertListEqual([mock.call(-1), mock.call(1), mock.call(2)], - self.some_split_change_fetcher.fetch.call_args_list) - - def test_calls_split_parser_parse(self): - """Test that update_split calls split_parser's parse method on all active splits on each - update""" - update_splits(self.some_split_cache, self.some_split_change_fetcher, self.some_split_parser) - self.assertListEqual([mock.call({'status': 'ACTIVE', 'name': 'some_split'}), - mock.call({'status': 'ACTIVE', 'name': 'some_other_split'}), - mock.call({'status': 'ACTIVE', 'name': 'some_split'})], - self.some_split_parser.parse.call_args_list) - - def test_calls_remove_split(self): - """Test that update_split calls split_cache's remove_split method on archived splits""" - update_splits(self.some_split_cache, self.some_split_change_fetcher, self.some_split_parser) - self.some_split_cache.remove_split.assert_called_once_with('some_other_split') - - def test_calls_add_split(self): - """Test that update_split calls split_cache's add_split method on active splits""" - update_splits(self.some_split_cache, self.some_split_change_fetcher, self.some_split_parser) - - """self.assertListEqual([mock.call('some_split', self.parse_side_effect[0]), - mock.call('some_other_split', self.parse_side_effect[1]), - mock.call('some_split', self.parse_side_effect[2])], - self.some_split_cache.add_split.call_args_list)""" - - self.assertListEqual([mock.call('some_split', {'status': 'ACTIVE', 'name': 'some_split'}), - mock.call('some_other_split', {'status': 'ACTIVE', 'name': 'some_other_split'}), - mock.call('some_split', {'status': 'ACTIVE', 'name': 'some_split'})], - self.some_split_cache.add_split.call_args_list) - - - - def test_calls_set_change_number(self): - """Test that update_split calls set_change_number on every update""" - update_splits(self.some_split_cache, self.some_split_change_fetcher, self.some_split_parser) - self.assertListEqual([mock.call(1), mock.call(2)], - self.some_split_cache.set_change_number.call_args_list) - - def test_disables_cache_on_exception(self): - """Test that update_split calls disable on the split_cache when an exception is raised""" - self.some_split_change_fetcher.fetch.side_effect 
= Exception() - update_splits(self.some_split_cache, self.some_split_change_fetcher, self.some_split_parser) - self.some_split_cache.disable.assert_called_once_with() diff --git a/splitio/tests/test_transformers.py b/splitio/tests/test_transformers.py deleted file mode 100644 index aecc85da..00000000 --- a/splitio/tests/test_transformers.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Unit tests for the transformers module""" -from __future__ import absolute_import, division, print_function, unicode_literals - -from builtins import int -from unittest import TestCase - -import arrow - -from splitio.transformers import (AsNumberTransformMixin, AsDateHourMinuteTimestampTransformMixin, - AsDateTimestampTransformMixin) -from splitio.tests.utils import MockUtilsMixin - - -class AsNumberTransformMixinTests(TestCase, MockUtilsMixin): - def setUp(self): - self.transformer = AsNumberTransformMixin() - - def test_transform_key_works_with_small_int(self): - """Tests that transform_key works with small integers""" - transformed = self.transformer.transform_key(12345) - self.assertIsInstance(transformed, int) - self.assertEqual(12345, transformed) - - def test_transform_key_works_with_large_int(self): - """Tests that transform_key works with large integers""" - transformed = self.transformer.transform_key(9223372036854775808) - self.assertIsInstance(transformed, int) - self.assertEqual(9223372036854775808, transformed) - - def test_transform_key_works_with_small_str(self): - """Tests that transform_key works with strings with small integers""" - transformed = self.transformer.transform_key('12345') - self.assertIsInstance(transformed, int) - self.assertEqual(12345, transformed) - - def test_transform_key_works_with_large_str(self): - """Tests that transform_key works with strings with large integers""" - transformed = self.transformer.transform_key('9223372036854775808') - self.assertIsInstance(transformed, int) - self.assertEqual(9223372036854775808, transformed) - - def test_transform_key_returns_none_with_invalid_number(self): - """Tests that transform_key returns none with strings with invalid integers""" - self.assertIsNone(self.transformer.transform_key('foobar')) - - def test_transform_condition_parameter_returns_transform_key_result(self): - """Tests that transform_condition_parameter returns the result of calling transform_key""" - transform_key_mock = self.patch_object(self.transformer, '_transform_key') - self.assertEqual(transform_key_mock.return_value, - self.transformer.transform_condition_parameter('foobar')) - - -class AsDateHourMinuteTimestampTransformMixinTests(TestCase): - def setUp(self): - self.transformer = AsDateHourMinuteTimestampTransformMixin() - - def test_transform_key_truncates_second_millisecond(self): - """Tests that transform_key truncates seconds and milliseconds""" - value = arrow.get(2016, 5, 1, 16, 35, 28, 19).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(1462120500000, transformed) - - def test_transform_key_works_on_epoch_lower_limit(self): - """Tests that transform_key works when supplied with the epoch timestamp lower limit""" - value = arrow.get(1970, 1, 1, 0, 0, 0, 0).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(0, transformed) - - def test_transform_key_works_on_epoch_upper_limit(self): - """Tests that transform_key works when supplied with the epoch timestamp upper limit""" - value = arrow.get(2038, 1, 19, 3, 14, 8, 0).timestamp - - transformed = self.transformer.transform_key(value) - 
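(As a concrete reading of the expected values asserted in these transformer tests — plain UTC arithmetic, not SDK code:)

    # 2016-05-01 16:35:28.000019 UTC is epoch second 1462120528.
    # Truncated to the minute:  1462120500 s -> 1462120500000 ms  (AsDateHourMinute...)
    # Truncated to midnight:    1462060800 s -> 1462060800000 ms  (AsDateTimestamp...)
    # Epoch upper limit: 2038-01-19 03:14:08 UTC is 2**31 = 2147483648 s;
    # truncated to the minute -> 2147483640 s -> 2147483640000 ms (asserted below).
    # Both mixins return milliseconds, hence the trailing * 1000 in the implementation.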
self.assertEqual(2147483640000, transformed) - - def test_transform_key_works_under_epoch_lower_limit(self): - """Tests that transform_key works when supplied with a value under the epoch timestamp - lower limit""" - value = arrow.get(1969, 1, 1, 20, 16, 13, 5).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(-31463040000, transformed) - - def test_transform_key_works_over_epoch_upper_limit(self): - """Tests that transform_key works when supplied with a value over the epoch timestamp upper - limit""" - value = arrow.get(2038, 2, 19, 3, 14, 8, 9).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(2150162040000, transformed) - - def test_transform_key_returns_none_with_invalid_numer(self): - """Tests that transform_key returns None when given an invalid number""" - self.assertIsNone(self.transformer.transform_key('foobar')) - - def test_transform_condition_parameter_truncates_second_millisecond(self): - """Tests that transform truncates seconds and milliseconds""" - value = arrow.get(2016, 5, 1, 16, 35, 28, 19).timestamp * 1000 - - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(1462120500000, transformed) - - def test_transform_condition_parameter_works_on_epoch_lower_limit(self): - """Tests that transform_condition_parameter works when supplied with the epoch timestamp - lower limit""" - value = arrow.get(1970, 1, 1, 0, 0, 0, 0).timestamp * 1000 - - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(0, transformed) - - def test_transform_condition_parameter_works_on_epoch_upper_limit(self): - """Tests that transform_condition_parameter works when supplied with the epoch timestamp - upper limit""" - value = arrow.get(2038, 1, 19, 3, 14, 8, 0).timestamp * 1000 - - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(2147483640000, transformed) - - def test_transform_condition_parameter_works_under_epoch_lower_limit(self): - """Tests that transform_condition_parameter works when supplied with a value under the - epoch timestamp lower limit""" - value = arrow.get(1969, 1, 1, 20, 16, 13, 5).timestamp * 1000 - - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(-31463040000, transformed) - - def test_transform_condition_parameter_works_over_epoch_upper_limit(self): - """Tests that transform_condition_parameter works when supplied with a value over the epoch - timestamp upper limit""" - value = arrow.get(2038, 2, 19, 3, 14, 8, 9).timestamp * 1000 - - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(2150162040000, transformed) - - def test_transform_condition_parameter_returns_none_with_invalid_numer(self): - """Tests that transform_key returns None when given an invalid number""" - self.assertIsNone(self.transformer.transform_condition_parameter('foobar')) - - -class AsDateTimestampTransformMixinTests(TestCase): - def setUp(self): - self.transformer = AsDateTimestampTransformMixin() - - def test_transform_key_truncates_second_millisecond(self): - """Tests that transform_key truncates seconds and milliseconds""" - value = arrow.get(2016, 5, 1, 16, 35, 28, 19).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(1462060800000, transformed) - - def test_transform_key_works_on_epoch_lower_limit(self): - """Tests that transform_key works when supplied with the epoch timestamp lower limit""" - value = arrow.get(1970, 1, 
1, 0, 0, 0, 0).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(0, transformed) - - def test_transform_key_works_on_epoch_upper_limit(self): - """Tests that transform_key works when supplied with the epoch timestamp upper limit""" - value = arrow.get(2038, 1, 19, 3, 14, 8, 0).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(2147472000000, transformed) - - def test_transform_key_works_under_epoch_lower_limit(self): - """Tests that transform_key works when supplied with a value under the epoch timestamp - lower limit""" - value = arrow.get(1969, 1, 1, 20, 12, 34, 8).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(-31536000000, transformed) - - def test_transform_key_works_over_epoch_upper_limit(self): - """Tests that transform_key works when supplied with a value over the epoch timestamp upper - limit""" - value = arrow.get(2038, 2, 19, 3, 14, 8, 15).timestamp - - transformed = self.transformer.transform_key(value) - self.assertEqual(2150150400000, transformed) - - def test_transform_key_returns_none_with_invalid_numer(self): - """Tests that transform_key returns None when given an invalid number""" - self.assertIsNone(self.transformer.transform_key('foobar')) - - def test_transform_condition_parameter_truncates_second_millisecond(self): - """Tests that transform_key truncates seconds and milliseconds""" - value = arrow.get(2016, 5, 1, 16, 35, 28, 19).timestamp * 1000 - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(1462060800000, transformed) - - def test_transform_condition_parameter_works_on_epoch_lower_limit(self): - """Tests that transform_key works when supplied with the epoch timestamp lower limit""" - value = arrow.get(1970, 1, 1, 0, 0, 0, 0).timestamp * 1000 - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(0, transformed) - - def test_transform_condition_parameter_works_on_epoch_upper_limit(self): - """Tests that transform_key works when supplied with the epoch timestamp upper limit""" - value = arrow.get(2038, 1, 19, 3, 14, 8, 0).timestamp * 1000 - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(2147472000000, transformed) - - def test_transform_condition_parameter_works_under_epoch_lower_limit(self): - """Tests that transform_key works when supplied with a value under the epoch timestamp - lower limit""" - value = arrow.get(1969, 1, 1, 20, 12, 34, 8).timestamp * 1000 - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(-31536000000, transformed) - - def test_transform_condition_parameter_works_over_epoch_upper_limit(self): - """Tests that transform_key works when supplied with a value over the epoch timestamp upper - limit""" - value = arrow.get(2038, 2, 19, 3, 14, 8, 15).timestamp * 1000 - transformed = self.transformer.transform_condition_parameter(value) - self.assertEqual(2150150400000, transformed) - - def test_transform_condition_parameter_returns_none_with_invalid_numer(self): - """Tests that transform_key returns None when given an invalid number""" - self.assertIsNone(self.transformer.transform_condition_parameter('foobar')) diff --git a/splitio/tests/test_uwsgi.py b/splitio/tests/test_uwsgi.py deleted file mode 100644 index 977e99dd..00000000 --- a/splitio/tests/test_uwsgi.py +++ /dev/null @@ -1,334 +0,0 @@ -"""Unit tests for the tasks module""" -from __future__ import absolute_import, division, print_function, 
unicode_literals - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -try: - from jsonpickle import decode, encode -except ImportError: - def missing_jsonpickle_dependencies(*args, **kwargs): - raise NotImplementedError('Missing jsonpickle support dependencies.') - decode = encode = missing_jsonpickle_dependencies - - -from unittest import TestCase - -from splitio.uwsgi import UWSGISplitCache, UWSGISegmentCache, UWSGIImpressionsCache, UWSGIMetricsCache, get_uwsgi -from splitio.impressions import Impression -from splitio.metrics import BUCKETS - - -class MockUWSGILock(object): - def __enter__(self, *args, **kwargs): - pass - - def __exit__(self, *args, **kwargs): - pass - -def mock_lock(*args, **kwargs): - pass - - - -class UWSGICacheEmulatorTests(TestCase): - def setUp(self): - self.uwsgi = get_uwsgi(emulator=True) - self.cache_namespace = 'splitio' - self.cache_key = 'some_key' - self.some_value = 'some_string_value' - self.some_other_value = 'some_other_string_value' - - def test_set_and_get(self): - self.uwsgi.cache_set(self.cache_key, str(self.some_value), 0, self.cache_namespace) - data = self.uwsgi.cache_get(self.cache_key, self.cache_namespace) - self.assertEqual(self.some_value, data) - - def test_update_and_get(self): - self.uwsgi.cache_update(self.cache_key, str(self.some_other_value), 0, self.cache_namespace) - data = self.uwsgi.cache_get(self.cache_key, self.cache_namespace) - self.assertEqual(self.some_other_value, data) - - def test_set_exists_del(self): - self.uwsgi.cache_set(self.cache_key, str(self.some_value), 0, self.cache_namespace) - self.assertTrue(self.uwsgi.cache_exists(self.cache_key, self.cache_namespace)) - self.uwsgi.cache_del(self.cache_key, self.cache_namespace) - self.assertFalse(self.uwsgi.cache_exists(self.cache_key, self.cache_namespace)) - - -class UWSGISplitCacheTests(TestCase): - def setUp(self): - mock.patch('splitio.uwsgi.UWSGILock', autospec=True) - self.uwsgi_adapter = get_uwsgi(emulator=True) - self.split_cache = UWSGISplitCache(self.uwsgi_adapter) - self.split_json = """ - { - "orgId":null, - "environment":null, - "trafficTypeId":null, - "trafficTypeName":"user", - "name":"test_multi_condition", - "seed":-1329591480, - "status":"ACTIVE", - "killed":false, - "defaultTreatment":"off", - "changeNumber":1325599980, - "conditions":[ - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_on" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":null, - "matcherType":"WHITELIST", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":{ - "whitelist":[ - "fake_id_off" - ] - }, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"off", - "size":100 - } - ] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":null - }, - "matcherType":"IN_SEGMENT", - "negate":false, - "userDefinedSegmentMatcherData":{ - "segmentName":"demo" - }, - "whitelistMatcherData":null, - "unaryNumericMatcherData":null, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - 
] - }, - { - "matcherGroup":{ - "combiner":"AND", - "matchers":[ - { - "keySelector":{ - "trafficType":"user", - "attribute":"some_attribute" - }, - "matcherType":"EQUAL_TO", - "negate":false, - "userDefinedSegmentMatcherData":null, - "whitelistMatcherData":null, - "unaryNumericMatcherData":{ - "dataType":"NUMBER", - "value":42 - }, - "betweenMatcherData":null - } - ] - }, - "partitions":[ - { - "treatment":"on", - "size":100 - }, - { - "treatment":"off", - "size":0 - } - ] - } - ] - }""" - - def test_add_get_split(self): - self.split_cache.add_split('test_multi_condition', decode(self.split_json)) - split = self.split_cache.get_split('test_multi_condition') - self.assertEqual(split.name, 'test_multi_condition') - - def test_remove_split(self): - self.split_cache.add_split('test_multi_condition', decode(self.split_json)) - self.split_cache.remove_split('test_multi_condition') - split = self.split_cache.get_split('test_multi_condition') - self.assertIsNone(split) - - # Because accessing the cache and blocking is costly in uwsgi mode, - # the split list is maintained by the sync task in order to minimize locking. - # So the add_split call doesn't update it. This is an inconsistency with the rest of - # the storages but yields for better performance. -# def test_get_split_keys(self): -# self.split_cache.add_split('test_multi_condition', decode(self.split_json)) -# current_keys = self.split_cache.get_splits_keys() -# self.assertIn('test_multi_condition', current_keys) - - @mock.patch('splitio.uwsgi.UWSGILock') - def test_get_splits(self, lock_mock): - lock_mock.return_value = MockUWSGILock() - self.split_cache.add_split('test_multi_condition', decode(self.split_json)) - self.split_cache.update_split_list(['test_multi_condition'], []) - current_splits = self.split_cache.get_splits() - self.assertEqual(self.split_cache.get_split('test_multi_condition').name, current_splits[0].name) - - def test_change_number(self): - change_number = 1325599980 - self.split_cache.set_change_number(change_number) - self.assertEqual(change_number, self.split_cache.get_change_number()) - - -class UWSGISegmentCacheTests(TestCase): - def setUp(self): - self.some_segment_name = mock.MagicMock() - self.some_segment_name_str = 'some_segment_name' - self.some_segment_keys = ['key_1', 'key_2'] - self.remove_segment_keys = ['key_1'] - self.some_key = mock.MagicMock() - self.some_change_number = mock.MagicMock() - - self.some_uwsgi = get_uwsgi(emulator=True) - self.segment_cache = UWSGISegmentCache(self.some_uwsgi) - - def test_register_segment(self): - """Test that register a segment""" - self.segment_cache.register_segment(self.some_segment_name_str) - registered_segments = self.segment_cache.get_registered_segments() - self.assertIn(self.some_segment_name_str, registered_segments) - - def test_unregister_segment(self): - self.segment_cache.unregister_segment(self.some_segment_name_str) - registered_segments = self.segment_cache.get_registered_segments() - self.assertNotIn(self.some_segment_name_str, registered_segments) - - def test_add_segment_keys(self): - self.segment_cache.register_segment(self.some_segment_name_str) - self.segment_cache.add_keys_to_segment(self.some_segment_name_str, self.some_segment_keys) - self.assertTrue(self.segment_cache.is_in_segment(self.some_segment_name_str, self.some_segment_keys[0])) - self.assertTrue(self.segment_cache.is_in_segment(self.some_segment_name_str, self.some_segment_keys[1])) - - def test_remove_segment_keys(self): - 
self.segment_cache.add_keys_to_segment(self.some_segment_name_str, self.some_segment_keys) - self.segment_cache.remove_keys_from_segment(self.some_segment_name_str, self.remove_segment_keys) - self.assertFalse(self.segment_cache.is_in_segment(self.some_segment_name_str, self.remove_segment_keys[0])) - self.assertTrue(self.segment_cache.is_in_segment(self.some_segment_name_str, self.some_segment_keys[1])) - - def test_change_number(self): - change_number = 1325599980 - self.segment_cache.set_change_number(self.some_segment_name_str, change_number) - self.assertEqual(change_number, self.segment_cache.get_change_number(self.some_segment_name_str)) - - -class UWSGIImpressionCacheTest(TestCase): - def setUp(self): - self._impression_1 = Impression(matching_key='matching_key', - feature_name='feature_name_1', - treatment='treatment', - label='label', - change_number='change_number', - bucketing_key='bucketing_key', - time=1325599980) - - self._impression_2 = Impression(matching_key='matching_key', - feature_name='feature_name_2', - treatment='treatment', - label='label', - change_number='change_number', - bucketing_key='bucketing_key', - time=1325599980) - - self.impression_cache = UWSGIImpressionsCache(get_uwsgi(emulator=True)) - - @mock.patch('splitio.uwsgi.UWSGILock') - def test_impression(self, lock_mock): - lock_mock.return_value = MockUWSGILock() - self.impression_cache.add_impression(self._impression_1) - self.impression_cache.add_impression(self._impression_2) - impressions = self.impression_cache.fetch_all_and_clear() - expected = {'feature_name_1': [self._impression_1], 'feature_name_2': [self._impression_2]} - self.assertEqual(expected,impressions) - - -class UWSGIMetricsCacheTest(TestCase): - def setUp(self): - self.metrics_cache = UWSGIMetricsCache(get_uwsgi(emulator=True)) - self.operation = 'sdk.getTreatment' - - def test_increment_latency(self): - self.metrics_cache.increment_latency_bucket_counter(self.operation,2) - self.assertEqual(1,self.metrics_cache.get_latency_bucket_counter(self.operation, 2)) - - def test_counter(self): - self.metrics_cache.set_count('some_counter',2) - self.metrics_cache.increment_count('some_counter') - self.assertEqual(3, self.metrics_cache.get_count('some_counter')) - - def test_gauge(self): - self.metrics_cache.set_gauge('some_gauge',123123123123) - self.assertEqual(123123123123, self.metrics_cache.get_gauge('some_gauge')) - - def test_times(self): - bucket_index = 5 - self.metrics_cache.increment_latency_bucket_counter(self.operation, bucket_index) - latencies = [0] * len(BUCKETS) - latencies[bucket_index]=1 - expected = [{'name':self.operation, 'latencies':latencies}] - self.assertEqual(expected,self.metrics_cache.fetch_all_times_and_clear()) - - def test_count_gauge(self): - self.metrics_cache.set_gauge('some_gauge', 10) - self.metrics_cache.set_count('some_count', 20) - expected = {'count': [{'name':'some_count', 'delta':20}], 'gauge': [{'name':'some_gauge', 'value': 10}]} - self.assertEqual(expected, self.metrics_cache.fetch_all_and_clear()) - diff --git a/splitio/tests/utils.py b/splitio/tests/utils.py deleted file mode 100644 index cbf619d8..00000000 --- a/splitio/tests/utils.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Unit test helpers""" -from __future__ import absolute_import, division, print_function, unicode_literals - -from sys import version_info - -try: - from unittest import mock -except ImportError: - # Python 2 - import mock - -from string import ascii_letters, digits, printable -from random import choice - - -class 
MockUtilsMixin(object): - """Add handy methods to reduce boilerplate on tests that patch things on initialization. - Example usage: - - class MyTest(TestCase, MockUtilsMixin): - def setUp(self): - self.obj = ObjClass() - self.some_mock = self.patch('some.module.function') - self.some_function = self.patch_object(self.obj, 'some_method') - """ - def patch_builtin(self, name, *args, **kwargs): - return self.patch(('builtins.{}' if version_info >= (3,) else '__builtin__.{}').format( - name), *args, **kwargs) - - def patch(self, *args, **kwargs): - patcher = mock.patch(*args, **kwargs) - patched = patcher.start() - patched.patcher = patcher - self.addCleanup(patcher.stop) - return patched - - def patch_object(self, *args, **kwargs): - patcher = mock.patch.object(*args, **kwargs) - patched = patcher.start() - patched.patcher = patcher - self.addCleanup(patcher.stop) - return patched - - -def random_alphanumeric_string(size): - """ - Generates a random alphanumeric string of a given size - :param size: The size of the string - :type size: int - :return: An alphanumeric string - :rtype: str - """ - return ''.join(choice(ascii_letters + digits) for _ in range(size)) - - -def random_printable_string(size): - """ - Generates a random printable string of a given size - :param size: The size of the string - :type size: int - :return: A printable string - :rtype: str - """ - return ''.join(choice(printable) for _ in range(size)) diff --git a/splitio/transformers.py b/splitio/transformers.py deleted file mode 100644 index 8eee0501..00000000 --- a/splitio/transformers.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -from builtins import int - -from arrow import Arrow - -logger = logging.getLogger(__name__) - - -class TransformMixin(object): - """Base for all transform mixins""" - def _transform_key(self, value): - """ - Transforms a value. Subclasses need to implement this method - :param value: The value to transform - :type value: any - :return: The transformed value - :rtype: any - """ - return value - - def transform_key(self, value): - """ - Transforms a value - :param value: The value to transform - :type value: any - :return: The transformed value - :rtype: any - """ - try: - return self._transform_key(value) - except: - logger.error('Error transforming value. value = %s', value) - - def _transform_condition_parameter(self, source_value): - """ - Transforms a condition parameter. Subclasses need to implement this method - :param source_value: The value to transform - :type source_value: any - :return: The transformed value - :rtype: any - """ - return self._transform_key(source_value) - - def transform_condition_parameter(self, source_value): - """ - Transforms a condition parameter - :param source_value: The value to transform - :type source_value: any - :return: The transformed value - :rtype: any - """ - try: - return self._transform_condition_parameter(source_value) - except: - logger.error('Error transforming source value. 
source_value = %s', source_value) - - -class AsNumberTransformMixin(TransformMixin): - """Mixin to allow transforming values to int (long)""" - def _transform_key(self, value): - """ - Transforms value to int (long in Python2) - :param value: Any value suitable to be transformed - :type value: - :return: The value transformed to int - :rtype: int - """ - if value is None: - return None - - return int(value) - - -class AsDateHourMinuteTimestampTransformMixin(TransformMixin): - """Mixin to allow truncating timestamp to the minute""" - def _transform_key(self, value): - """ - Truncates seconds and milliseconds from a long value of seconds from epoch - :param value: An int value representing the number of seconds from epoch (timestamp) - :type value: int - :return: The value truncated to minutes - :rtype: int - """ - if value is None: - return None - - return Arrow.utcfromtimestamp(value).replace(second=0, microsecond=0).timestamp * 1000 - - def _transform_condition_parameter(self, source_value): - """ - Truncates seconds and milliseconds from a long value of milliseconds from epoch - :param source_value: An int value representing the number of milliseconds from epoch (timestamp) - :type source_value: int - :return: The value truncated to minutes - :rtype: int - """ - if source_value is None: - return None - - return Arrow.utcfromtimestamp(source_value // 1000).replace(second=0, - microsecond=0).timestamp * 1000 - - -class AsDateTimestampTransformMixin(TransformMixin): - """Mixin to allow truncating timestamp to midnight""" - def _transform_key(self, value): - """ - Truncates seconds and milliseconds from a long value of seconds from epoch - :param value: An int value representing the number of seconds from epoch (timestamp) - :type value: int - :return: The value truncated to midnight of the same day - :rtype: int - """ - if value is None: - return None - - return Arrow.utcfromtimestamp(value).replace(hour=0, minute=0, second=0, - microsecond=0).timestamp * 1000 - - def _transform_condition_parameter(self, source_value): - """ - Truncates seconds and milliseconds from a long value of milliseconds from epoch - :param source_value: An int value representing the number of milliseconds from epoch (timestamp) - :type source_value: int - :return: The value truncated to midnight of the same day - :rtype: int - """ - if source_value is None: - return None - - return Arrow.utcfromtimestamp(source_value // 1000).replace(hour=0, minute=0, second=0, - microsecond=0).timestamp * 1000
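For reference, the truncation these mixins perform is plain calendar math; a minimal standalone sketch using only the Python 3 standard library instead of arrow (the helper name is hypothetical):

    from datetime import datetime, timezone

    def truncate_to_minute_ms(epoch_seconds):
        # Mirrors AsDateHourMinuteTimestampTransformMixin._transform_key:
        # drop seconds/microseconds, then express the result in milliseconds.
        dt = datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)
        return int(dt.replace(second=0, microsecond=0).timestamp()) * 1000

    # 1325599980 is minute-aligned, so 45 extra seconds truncate away:
    assert truncate_to_minute_ms(1325599980 + 45) == 1325599980 * 1000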
diff --git a/splitio/treatments.py b/splitio/treatments.py deleted file mode 100644 index 4a5e2e58..00000000 --- a/splitio/treatments.py +++ /dev/null @@ -1,5 +0,0 @@ -"""This module contains everything related to treatments""" -from __future__ import absolute_import, division, print_function, unicode_literals - - -CONTROL = 'control' diff --git a/splitio/update_scripts/post_impressions.py b/splitio/update_scripts/post_impressions.py deleted file mode 100644 index 0827dff5..00000000 --- a/splitio/update_scripts/post_impressions.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Post Split.io impressions. - -Usage: - post_impressions <config_file> - post_impressions -h | --help - post_impressions --version - -Options: - -h --help Show this screen. - --version Show version. -""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -from splitio.api import api_factory -from splitio.config import SDK_VERSION, parse_config_file -from splitio.redis_support import get_redis, RedisImpressionsCache -from splitio.tasks import report_impressions - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('splitio.update_scripts.post_impressions') - - -def run(arguments): - try: - config = parse_config_file(arguments['<config_file>']) - redis = get_redis(config) - impressions_cache = RedisImpressionsCache(redis) - sdk_api = api_factory(config) - report_impressions(impressions_cache, sdk_api) - except: - logger.error('Error posting impressions') - - -if __name__ == '__main__': - from docopt import docopt - arguments = docopt(__doc__, version=SDK_VERSION) - run(arguments) diff --git a/splitio/update_scripts/post_metrics.py b/splitio/update_scripts/post_metrics.py deleted file mode 100644 index 5a74bfc3..00000000 --- a/splitio/update_scripts/post_metrics.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Post Split.io metrics. - -Usage: - post_metrics <config_file> - post_metrics -h | --help - post_metrics --version - -Options: - -h --help Show this screen. - --version Show version. -""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -from splitio.api import api_factory -from splitio.config import SDK_VERSION, parse_config_file -from splitio.redis_support import get_redis, RedisMetricsCache -from splitio.tasks import report_metrics - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('splitio.update_scripts.post_metrics') - - -def run(arguments): - try: - config = parse_config_file(arguments['<config_file>']) - redis = get_redis(config) - metrics_cache = RedisMetricsCache(redis) - sdk_api = api_factory(config) - report_metrics(metrics_cache, sdk_api) - except: - logger.error('Error posting metrics') - - -if __name__ == '__main__': - from docopt import docopt - arguments = docopt(__doc__, version=SDK_VERSION) - run(arguments) diff --git a/splitio/update_scripts/update_segments.py b/splitio/update_scripts/update_segments.py deleted file mode 100644 index 3aa9b731..00000000 --- a/splitio/update_scripts/update_segments.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -Update Split.io segments. - -Usage: - update_segments <config_file> - update_segments -h | --help - update_segments --version - -Options: - -h --help Show this screen. - --version Show version. 
-""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -from splitio.api import api_factory -from splitio.config import SDK_VERSION, parse_config_file -from splitio.redis_support import get_redis, RedisSegmentCache -from splitio.segments import ApiSegmentChangeFetcher -from splitio.tasks import update_segments - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('splitio.update_scripts.update_segments') - - -def run(arguments): - try: - config = parse_config_file(arguments['<config_file>']) - redis = get_redis(config) - segment_cache = RedisSegmentCache(redis) - sdk_api = api_factory(config) - segment_change_fetcher = ApiSegmentChangeFetcher(sdk_api) - update_segments(segment_cache, segment_change_fetcher) - except: - logger.error('Error updating segments') - - -if __name__ == '__main__': - from docopt import docopt - arguments = docopt(__doc__, version=SDK_VERSION) - run(arguments) diff --git a/splitio/update_scripts/update_splits.py b/splitio/update_scripts/update_splits.py deleted file mode 100644 index 765f57be..00000000 --- a/splitio/update_scripts/update_splits.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Update Split.io splits. - -Usage: - update_splits <config_file> - update_splits -h | --help - update_splits --version - -Options: - -h --help Show this screen. - --version Show version. -""" -from __future__ import absolute_import, division, print_function, unicode_literals - -import logging - -from splitio.api import api_factory -from splitio.config import SDK_VERSION, parse_config_file -from splitio.redis_support import get_redis, RedisSplitCache, RedisSegmentCache, RedisSplitParser -from splitio.splits import ApiSplitChangeFetcher -from splitio.tasks import update_splits - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('splitio.update_scripts.update_splits') - - -def run(arguments): - try: - config = parse_config_file(arguments['<config_file>']) - redis = get_redis(config) - split_cache = RedisSplitCache(redis) - sdk_api = api_factory(config) - split_change_fetcher = ApiSplitChangeFetcher(sdk_api) - segment_cache = RedisSegmentCache(redis) - split_parser = RedisSplitParser(segment_cache) - update_splits(split_cache, split_change_fetcher, split_parser) - except: - logger.error('Error updating splits') - - -if __name__ == '__main__': - from docopt import docopt - arguments = docopt(__doc__, version=SDK_VERSION) - run(arguments) diff --git a/splitio/utils.py b/splitio/utils.py deleted file mode 100644 index fe6cb799..00000000 --- a/splitio/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import absolute_import, division, print_function, unicode_literals - -import socket - - -def bytes_to_string(bytes, encode='utf-8'): - if type(bytes).__name__ == 'bytes': - return str(bytes, encode) - - return bytes - - -def get_ip(): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - # doesn't even have to be reachable - s.connect(('10.255.255.255', 1)) - IP = s.getsockname()[0] - except Exception: - IP = 'unknown' - finally: - s.close() - return IP - - -def get_hostname(): - ip = get_ip() - return 'unknown' if ip == 'unknown' else 'ip-' + ip.replace('.', '-')
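A note on get_ip above: connecting a UDP socket sends no packets; it only makes the OS pick an outbound interface, which is why an unreachable address works. Illustrative round-trip (addresses hypothetical):

    ip = get_ip()           # e.g. '10.0.0.12' on a typical private network
    print(get_hostname())   # -> 'ip-10-0-0-12', or 'unknown' with no route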
diff --git a/splitio/uwsgi.py b/splitio/uwsgi.py deleted file mode 100644 index 9b86e628..00000000 --- a/splitio/uwsgi.py +++ /dev/null @@ -1,1053 +0,0 @@ -"""A module for Split.io SDK and uwsgi compatibility - -Strongly depends on the uWSGI Cache Framework -https://uwsgi-docs.readthedocs.io/en/latest/Caching.html - -Sample Command: uwsgi --http :9090 --wsgi-file mysite/wsgi.py --enable-threads --master - --cache2 name=splitio,items=5000,store=/tmp/uwsgi_cache.6 - - -Cache item size -; create a cache for images with dynamic size (images can be big, so do not waste memory) -cache2 = name=images,items=20,bitmap=1,blocks=100 - -; a cache for css (20k per-item is more than enough) -cache2 = name=stylesheets,items=30,blocksize=20000 - - -SPOOLER: -uwsgi --http :9090 --wsgi-file mysite/wsgi.py --processes 4 --threads 2 --enable-threads --master - --cache2 name=splitio,items=5000,store=/tmp/splitio.cache --spooler myspool --import myspool.spooltasks - -""" -from __future__ import absolute_import, division, print_function, unicode_literals - -try: - # uwsgi is loaded at runtime by the uwsgi app. - import uwsgi -except ImportError: - def missing_uwsgi_dependencies(*args, **kwargs): - raise NotImplementedError('Missing uWSGI support dependencies.') - uwsgi = missing_uwsgi_dependencies - -try: - from jsonpickle import decode, encode -except ImportError: - def missing_jsonpickle_dependencies(*args, **kwargs): - raise NotImplementedError('Missing jsonpickle support dependencies.') - decode = encode = missing_jsonpickle_dependencies - - -import re -import logging -import time - -from itertools import groupby -from six import iteritems -from collections import defaultdict - -from splitio.cache import SegmentCache, SplitCache, ImpressionsCache, MetricsCache -from splitio.api import api_factory -from splitio.tasks import update_splits, update_segments, report_metrics, report_impressions, EventsSyncTask -from splitio.splits import Split, ApiSplitChangeFetcher, SplitParser, HashAlgorithm -from splitio.segments import Segment, ApiSegmentChangeFetcher -from splitio.matchers import UserDefinedSegmentMatcher -from splitio.utils import bytes_to_string -from splitio.impressions import Impression -from splitio.metrics import BUCKETS -from splitio.config import DEFAULT_CONFIG -from splitio.events import Event - - - -_logger = logging.getLogger(__name__) - -# Cache used for locking & signaling keys -_SPLITIO_LOCK_CACHE_NAMESPACE = 'splitio_locks' - -# Cache where split definitions are stored -_SPLITIO_SPLITS_CACHE_NAMESPACE = 'splitio_splits' - -# Cache where segments are stored -_SPLITIO_SEGMENTS_CACHE_NAMESPACE = 'splitio_segments' - -# Cache where impressions are stored -_SPLITIO_IMPRESSIONS_CACHE_NAMESPACE = 'splitio_impressions' - -# Cache where metrics are stored -_SPLITIO_METRICS_CACHE_NAMESPACE = 'splitio_metrics' - -# Cache where events are stored (1 key with lots of blocks) -_SPLITIO_EVENTS_CACHE_NAMESPACE = 'splitio_events' - -# Cache where changeNumbers are stored -_SPLITIO_CHANGE_NUMBERS = 'splitio_changeNumbers' - -# Cache with a big block size used for lists -_SPLITIO_MISC_NAMESPACE = 'splitio_misc' - - -def _get_config(user_config): - sdk_config = DEFAULT_CONFIG.copy() - sdk_config.update(user_config) - return sdk_config - - -class UWSGILock: - """Context manager to be used for locking a key in the cache.""" - - def __init__(self, key, overwrite_lock_seconds=5): - """ - Initialize a lock with key `key`, waiting up to `overwrite_lock_seconds` - before the thread is released (if it hasn't been manually unlocked). - - :param key: Key to be used. - :type key: str - - :param overwrite_lock_seconds: How many seconds to wait before force-releasing. 
- :type overwrite_lock_seconds: int - """ - self._key = key - self._overwrite_lock_seconds = overwrite_lock_seconds - self._adapter = get_uwsgi() - - def __enter__(self): - """Loop until the lock is manually released or the timeout expires""" - initial_time = time.time() - while True: - if not self._adapter.cache_exists(self._key, _SPLITIO_LOCK_CACHE_NAMESPACE): - self._adapter.cache_set(self._key, str('locked'), 0, _SPLITIO_LOCK_CACHE_NAMESPACE) - return - else: - if time.time() - initial_time > self._overwrite_lock_seconds: - return - time.sleep(0.3) - - def __exit__(self, *args): - """Remove lock""" - self._adapter.cache_del(self._key, _SPLITIO_LOCK_CACHE_NAMESPACE) - - -def uwsgi_update_splits(user_config): - try: - config = _get_config(user_config) - seconds = config['featuresRefreshRate'] - while True: - split_cache = UWSGISplitCache(get_uwsgi()) - - sdk_api = api_factory(config) - split_change_fetcher = ApiSplitChangeFetcher(sdk_api) - - segment_cache = UWSGISegmentCache(get_uwsgi()) - split_parser = UWSGISplitParser(segment_cache) - - added, removed = update_splits(split_cache, split_change_fetcher, split_parser) - split_cache.update_split_list(added, removed) - - time.sleep(seconds) - except: - _logger.error('Error updating splits') - - -def uwsgi_update_segments(user_config): - try: - config = _get_config(user_config) - seconds = config['segmentsRefreshRate'] - while True: - segment_cache = UWSGISegmentCache(get_uwsgi()) - sdk_api = api_factory(config) - segment_change_fetcher = ApiSegmentChangeFetcher(sdk_api) - update_segments(segment_cache, segment_change_fetcher) - - time.sleep(seconds) - except: - _logger.error('Error updating segments') - - -def uwsgi_report_impressions(user_config): - try: - config = _get_config(user_config) - seconds = config['impressionsRefreshRate'] - while True: - impressions_cache = UWSGIImpressionsCache(get_uwsgi()) - sdk_api = api_factory(config) - report_impressions( - impressions_cache, - sdk_api) - - time.sleep(seconds) - except: - _logger.error('Error posting impressions') - - -def uwsgi_report_metrics(user_config): - try: - config = _get_config(user_config) - seconds = config['metricsRefreshRate'] - while True: - metrics_cache = UWSGIMetricsCache(get_uwsgi()) - sdk_api = api_factory(config) - report_metrics(metrics_cache, sdk_api) - - time.sleep(seconds) - except: - _logger.error('Error posting metrics') - - -def uwsgi_report_events(user_config): - try: - config = _get_config(user_config) - seconds = config.get('eventsRefreshRate', 30) - events_cache = UWSGIEventsCache(get_uwsgi()) - sdk_api = api_factory(config) - task = EventsSyncTask(sdk_api, events_cache, seconds, 500) - while True: - task._send_events() - for _ in range(0, seconds): - if uwsgi.cache_get(UWSGIEventsCache._EVENTS_FLUSH, _SPLITIO_LOCK_CACHE_NAMESPACE): - uwsgi.cache_del(UWSGIEventsCache._EVENTS_FLUSH, _SPLITIO_LOCK_CACHE_NAMESPACE) - break - time.sleep(1) - except: - _logger.error('Error posting events')
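The background workers above coordinate with request handlers only through uwsgi cache keys, serialized with UWSGILock. A minimal usage sketch (the key name is taken from UWSGISplitCache below; the body is hypothetical):

    with UWSGILock('splits.list.lock', overwrite_lock_seconds=5):
        # Read-modify-write of a shared cache entry goes here. If another
        # worker holds the key for more than 5 seconds, __enter__ stops
        # waiting and proceeds anyway, so a crashed holder cannot deadlock us.
        pass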
-class UWSGISplitCache(SplitCache): - _KEY_TEMPLATE = 'split.{suffix}' - _KEY_TILL_TEMPLATE = 'splits.till' - _KEY_FEATURE_LIST_LOCK = 'splits.list.lock' - _KEY_FEATURE_LIST = 'splits.list' - _OVERWRITE_LOCK_SECONDS = 5 - - def __init__(self, adapter): - """A SplitCache implementation that uses uwsgi cache as its back-end.""" - self._adapter = adapter - - def is_enabled(self): - """Returns whether the cache is enabled""" - return True - - def disable(self): - """Disable the cache. Present only to keep the interface consistent.""" - return True - - def add_split(self, split_name, split): - """ - Stores a Split under a name. - :param split_name: Name of the split (feature) - :type split_name: str - :param split: The split to store - :type split: Split - """ - self._adapter.cache_update( - self._KEY_TEMPLATE.format(suffix=split_name), - encode(split), - 0, - _SPLITIO_SPLITS_CACHE_NAMESPACE - ) - - - def remove_split(self, split_name): - """ - Evicts a Split from the cache. - :param split_name: Name of the split (feature) - :type split_name: str - """ - return self._adapter.cache_del(self._KEY_TEMPLATE.format(suffix=split_name), _SPLITIO_SPLITS_CACHE_NAMESPACE) - - def get_split(self, split_name): - """ - Retrieves a Split from the cache. - :param split_name: Name of the split (feature) - :type split_name: str - :return: The split under the name if it exists, None otherwise - :rtype: Split - """ - to_decode = self._adapter.cache_get(self._KEY_TEMPLATE.format(suffix=split_name), _SPLITIO_SPLITS_CACHE_NAMESPACE) - - if to_decode is None: - return None - - to_decode = bytes_to_string(to_decode) - - split_dump = decode(to_decode) - - if split_dump is not None: - segment_cache = UWSGISegmentCache(self._adapter) - split_parser = UWSGISplitParser(segment_cache) - split = split_parser.parse(split_dump) - return split - - return None - - def get_splits_keys(self): - if self._adapter.cache_exists(self._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE): - try: - return list(decode( - self._adapter.cache_get(self._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE) - )) - except TypeError: # Thrown by jsonpickle.decode when passed "None" - pass # Fall back to default return statement (empty list) - return [] - - def get_splits(self): - current_splits = self.get_splits_keys() - - to_return = [] - - for split_name in current_splits: - to_return.append(self.get_split(split_name)) - - return to_return - - - def set_change_number(self, change_number): - """ - Sets the value for the change number - :param change_number: The change number - :type change_number: int - """ - return self._adapter.cache_update(self._KEY_TILL_TEMPLATE, encode(change_number), 0, _SPLITIO_CHANGE_NUMBERS) - - def get_change_number(self): - """ - Retrieves the value of the change number - :return: The current change number value, -1 otherwise - :rtype: int - """ - try: - return decode(self._adapter.cache_get(self._KEY_TILL_TEMPLATE, _SPLITIO_CHANGE_NUMBERS)) - except TypeError: - return -1 - - def update_split_list(self, added, removed): - """ - Updates the list of splits used to keep track of impression keys in the cache. 
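- Illustrative call, as issued by uwsgi_update_splits above: - update_split_list(['added_feature'], ['removed_feature']) 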
- """ - added_set = set(added) - removed_set = set(removed) - with UWSGILock(self._KEY_FEATURE_LIST_LOCK): - try: - current = decode(self._adapter.cache_get(self._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE)) - except TypeError: - current = set() - current = (current.union(added_set)).difference(removed_set) - self._adapter.cache_update( - self._KEY_FEATURE_LIST, - encode(current), - 0, - _SPLITIO_MISC_NAMESPACE - ) - - -class UWSGISegmentCache(SegmentCache): - _KEY_TEMPLATE = 'segments.{suffix}' - _SEGMENT_DATA_KEY_TEMPLATE = 'segmentData.{segment_name}' - _SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE = 'segment.{segment_name}.till' - _SEGMENT_REGISTERED = _KEY_TEMPLATE.format(suffix='registered') - - - def __init__(self,adapter, disabled_period=300): - """A Segment Cache implementation that uses uWSGI as its back-end - :param adapter: The uwsgi module - :rtype uwsgi: uwsgi - :param disabled_period: The expiration period for the disabled key. - :param disabled_period: int - """ - self._adapter = adapter - self._disabled_period = disabled_period - - @property - def disabled_period(self): - return self._disabled_period - - @disabled_period.setter - def disabled_period(self, disabled_period): - self._disabled_period = disabled_period - - def disable(self): - """Disables the automatic update process. This method will be called if the update fails - for some reason. Use enable to re-enable the update process.""" - pass - - def enable(self): - """Enables the automatic update process.""" - pass - - def is_enabled(self): - """ - :return: Whether the update process is enabled or not. - :rtype: bool - """ - return True - - def register_segment(self, segment_name): - """Register a segment for inclusion in the automatic update process. - :param segment_name: Name of the segment. - :type segment_name: str - """ - try: - segments = decode(self._adapter.cache_get(self._SEGMENT_REGISTERED, _SPLITIO_MISC_NAMESPACE)) - except TypeError: - segments = set() - - segments.add(segment_name) - self._adapter.cache_update(self._SEGMENT_REGISTERED, encode(segments), 0, _SPLITIO_MISC_NAMESPACE) - - def unregister_segment(self, segment_name): - """Unregister a segment from the automatic update process. - :param segment_name: Name of the segment. - :type segment_name: str - """ - - try: - segments = decode(self._adapter.cache_get(self._SEGMENT_REGISTERED, _SPLITIO_MISC_NAMESPACE)) - #If segment is in set, remove it and update cache - if segment_name in segments: - segments.discard(segment_name) - self._adapter.cache_update(self._SEGMENT_REGISTERED, encode(segments), 0, _SPLITIO_MISC_NAMESPACE) - except TypeError: - pass - - def get_registered_segments(self): - """ - :return: All segments included in the automatic update process. 
- :rtype: set - """ - try: - return decode(self._adapter.cache_get(self._SEGMENT_REGISTERED, _SPLITIO_MISC_NAMESPACE)) - except TypeError: - return set() - - def add_keys_to_segment(self, segment_name, segment_keys): - _key = self._SEGMENT_DATA_KEY_TEMPLATE.format(segment_name=segment_name) - try: - segment_data = decode(self._adapter.cache_get(_key, _SPLITIO_SEGMENTS_CACHE_NAMESPACE)) - except TypeError: - segment_data = set() - - segment_data.update(segment_keys) - self._adapter.cache_update(_key, encode(segment_data), 0, _SPLITIO_SEGMENTS_CACHE_NAMESPACE) - - - def remove_keys_from_segment(self, segment_name, segment_keys): - _key = self._SEGMENT_DATA_KEY_TEMPLATE.format(segment_name=segment_name) - try: - segment_data = decode(self._adapter.cache_get(_key, _SPLITIO_SEGMENTS_CACHE_NAMESPACE)) - for segment_key in segment_keys: - segment_data.discard(segment_key) - self._adapter.cache_update(_key, encode(segment_data), 0, _SPLITIO_SEGMENTS_CACHE_NAMESPACE) - except TypeError: - pass - - def is_in_segment(self, segment_name, key): - _key = self._SEGMENT_DATA_KEY_TEMPLATE.format(segment_name=segment_name) - try: - segment_data = decode(self._adapter.cache_get(_key, _SPLITIO_SEGMENTS_CACHE_NAMESPACE)) - if key in segment_data: - return True - except TypeError: - pass - - return False - - def set_change_number(self, segment_name, change_number): - self._adapter.cache_update( - self._SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE.format(segment_name=segment_name), - encode(change_number), - 0, - _SPLITIO_CHANGE_NUMBERS - ) - - def get_change_number(self, segment_name): - try: - change_number = decode(self._adapter.cache_get( - self._SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE.format(segment_name=segment_name), - _SPLITIO_CHANGE_NUMBERS - )) - return int(change_number) if change_number is not None else -1 - except TypeError: - return -1 - - -class UWSGISplitParser(SplitParser): - def __init__(self, segment_cache): - """ - A SplitParser implementation that registers the segments with the uwsgi segment cache - implementation upon parsing an IN_SEGMENT matcher. - """ - super(UWSGISplitParser, self).__init__(None) - self._segment_cache = segment_cache - - def _parse_split(self, split, block_until_ready=False): - return UWSGISplit( - split['name'], split['seed'], split['killed'], - split['defaultTreatment'], split['trafficTypeName'], - split['status'], split['changeNumber'], - segment_cache=self._segment_cache, algo=split.get('algo'), - traffic_allocation=split.get('trafficAllocation'), - traffic_allocation_seed=split.get('trafficAllocationSeed') - ) - - def _parse_matcher_in_segment(self, partial_split, matcher, block_until_ready=False, *args, - **kwargs): - matcher_data = self._get_matcher_attribute('userDefinedSegmentMatcherData', matcher) - segment = UWSGISplitBasedSegment(matcher_data['segmentName'], partial_split) - delegate = UserDefinedSegmentMatcher(segment) - self._segment_cache.register_segment(delegate.segment.name) - return delegate - -class UWSGISplit(Split): - def __init__(self, name, seed, killed, default_treatment, traffic_type_name, status, change_number, conditions=None, segment_cache=None, algo=None, - traffic_allocation=None, - traffic_allocation_seed=None): - """A split implementation that maintains a reference to the segment cache so segments can - be easily pickled and unpickled. 
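- Illustrative round-trip (jsonpickle honors the __getstate__/__setstate__ hooks defined below): - restored = decode(encode(split)) # comes back with _segment_cache set to None - restored.segment_cache = cache # re-attach after unpickling 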
- :param name: Name of the feature - :type name: unicode - :param seed: Seed - :type seed: int - :param killed: Whether the split is killed or not - :type killed: bool - :param default_treatment: Default treatment for the split - :type default_treatment: str - :param conditions: Set of conditions to test - :type conditions: list - :param segment_cache: A segment cache - :type segment_cache: SegmentCache - """ - super(UWSGISplit, self).__init__( - name, seed, killed, default_treatment, traffic_type_name, status, - change_number, conditions, algo, traffic_allocation, - traffic_allocation_seed) - self._segment_cache = segment_cache - - @property - def segment_cache(self): - return self._segment_cache - - @segment_cache.setter - def segment_cache(self, segment_cache): - self._segment_cache = segment_cache - - def __getstate__(self): - old_dict = self.__dict__.copy() - del old_dict['_segment_cache'] - return old_dict - - def __setstate__(self, state): - self.__dict__.update(state) - self._segment_cache = None - - -class UWSGISplitBasedSegment(Segment): - def __init__(self, name, split): - """A Segment that uses a reference to a UWSGISplit instance to check if a key - is in a segment - :param name: The name of the segment - :type name: str - :param split: A UWSGISplit instance - :type split: UWSGISplit - """ - super(UWSGISplitBasedSegment, self).__init__(name) - self._split = split - - def contains(self, key): - return self._split.segment_cache.is_in_segment(self.name, key) - - -class UWSGIImpressionsCache(ImpressionsCache): - _IMPRESSIONS_KEY = 'impressions.{feature}' - _LOCK_IMPRESSION_KEY = 'impressions_lock.{feature}' - _MISSING = '__MISSING__' - _OVERWRITE_LOCK_SECONDS = 5 - - def __init__(self, adapter, disabled_period=300): - """An ImpressionsCache implementation that uses uWSGI as its back-end - :param adapter: The uwsgi module - :type adapter: module - :param disabled_period: The expiration period for the disabled key. - :type disabled_period: int - """ - self._adapter = adapter - self._disabled_period = disabled_period - - @property - def disabled_period(self): - return self._disabled_period - - @disabled_period.setter - def disabled_period(self, disabled_period): - self._disabled_period = disabled_period - - def enable(self): - """Enables the automatic impressions report process and the registration of impressions.""" - pass - - def disable(self): - """Disables the automatic impressions report process and the registration of any - impressions for the specified disabled period. This method will be called if there's an - exception while trying to send the impressions back to Split.""" - pass - - def is_enabled(self): - """ - :return: Whether the automatic report process and impressions registration are enabled. - :rtype: bool - """ - return True - - def _build_impressions_dict(self, impressions): - """Builds a dictionary of impressions that groups them based on their feature name. - :param impressions: List of impression tuples - :type impressions: list - :return: Dictionary of impressions grouped by feature name - :rtype: dict - """ - sorted_impressions = sorted(impressions, key=lambda impression: impression.feature_name) - grouped_impressions = groupby(sorted_impressions, - key=lambda impression: impression.feature_name) - return dict((feature_name, list(group)) for feature_name, group in grouped_impressions) - - def fetch_all(self): - """Fetches all impressions from the cache. It returns a dictionary with the impressions - grouped by feature name. 
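- For example (illustrative), two impressions for 'f1' and one for 'f2' come back as {'f1': [imp_a, imp_b], 'f2': [imp_c]}. 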
- :return: All cached impressions so far grouped by feature name - :rtype: dict - """ - return self.fetch_all_and_clear() - - def clear(self): - """Clears all cached impressions""" - pass - - def add_impression(self, impression): - """Adds an impression to the log if it is enabled, otherwise the impression is dropped. - :param impression: The impression tuple - :type impression: Impression - """ - features = self._adapter.cache_get(UWSGISplitCache._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE) - try: - features = decode(features) - except TypeError: - features = set() - - if impression.feature_name in features: - # Feature is known; add to its own impression collection - key = self._IMPRESSIONS_KEY.format(feature=impression.feature_name) - lock_key = self._LOCK_IMPRESSION_KEY.format(feature=impression.feature_name) - else: - # Feature is unknown; add to the `impressions.__MISSING__` key in the cache - key = self._IMPRESSIONS_KEY.format(feature=self._MISSING) - lock_key = self._LOCK_IMPRESSION_KEY.format(feature=self._MISSING) - - with UWSGILock(lock_key): - try: - impressions = decode(self._adapter.cache_get(key, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE)) - except TypeError: - impressions = set() - - impressions.add(tuple(impression)) - self._adapter.cache_update(key, encode(impressions), 0, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE) - - - def fetch_all_and_clear(self): - """Fetches all impressions from the cache and clears it. It returns a dictionary with the - impressions grouped by feature name. - :return: All cached impressions so far grouped by feature name - :rtype: dict - """ - features = self._adapter.cache_get(UWSGISplitCache._KEY_FEATURE_LIST, _SPLITIO_MISC_NAMESPACE) - try: - features = decode(features) - except TypeError: - features = set() - - # Include impressions for splits not in cache. - features.add(self._MISSING) - - impressions = [] - for feature in features: - key = self._IMPRESSIONS_KEY.format(feature=feature) - lock_key = self._LOCK_IMPRESSION_KEY.format(feature=feature) - with UWSGILock(lock_key): - raw = self._adapter.cache_get(key, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE) - self._adapter.cache_del(key, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE) - - try: - impressions.extend([Impression(*i) for i in decode(raw)]) - except TypeError: - pass - - return self._build_impressions_dict(impressions) - - -class UWSGIEventsCache: - _EVENTS_KEY = 'events' - _LOCK_EVENTS_KEY = 'events_lock' - _EVENTS_FLUSH = 'events_flush' - _OVERWRITE_LOCK_SECONDS = 5 - - def __init__(self, adapter, disabled_period=300, events_queue_size=500): - """An events cache implementation that uses uWSGI as its back-end - :param adapter: The uwsgi module - :type adapter: module - :param disabled_period: The expiration period for the disabled key. - :type disabled_period: int - :param events_queue_size: Maximum number of events to keep queued - :type events_queue_size: int - """ - self._adapter = adapter - self._disabled_period = disabled_period - self._events_queue_size = events_queue_size - - @property - def disabled_period(self): - return self._disabled_period - - @disabled_period.setter - def disabled_period(self, disabled_period): - self._disabled_period = disabled_period - - def log_event(self, event): - """Adds an event to the queue if there is room; otherwise the event is dropped and a flush is requested. 
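- Returns True when the event was queued, False when the queue was full. Illustrative (queue-full branch): - if not events_cache.log_event(event): - pass # flush flag already set; the sync task will send early 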
- :param event: The event tuple - :type event: Event - """ - cache_event = dict(event._asdict()) - - with UWSGILock(self._LOCK_EVENTS_KEY): - try: - events = decode(self._adapter.cache_get(self._EVENTS_KEY, _SPLITIO_EVENTS_CACHE_NAMESPACE)) - except TypeError: - events = [] - - if len(events) < self._events_queue_size: - events.append(cache_event) - _logger.debug('Adding event to cache: {}.'.format(event)) - self._adapter.cache_update(self._EVENTS_KEY, encode(events), 0, _SPLITIO_EVENTS_CACHE_NAMESPACE) - return True - - # Set a key to force an events flush - self._adapter.cache_set(self._EVENTS_FLUSH, '1', 0, _SPLITIO_LOCK_CACHE_NAMESPACE) - return False - - def pop_many(self, count): - """Pops the most recently queued `count` events from the cache and removes them from the queue. - :param count: Maximum number of events to pop - :type count: int - :return: A list of the popped events - :rtype: list - """ - try: - with UWSGILock(self._LOCK_EVENTS_KEY): - cached_events = decode(self._adapter.cache_get(self._EVENTS_KEY, _SPLITIO_EVENTS_CACHE_NAMESPACE)) - events_to_return = cached_events[(0 - count):] - cached_events = cached_events[:(0 - count)] - self._adapter.cache_update(self._EVENTS_KEY, encode(cached_events), 0, _SPLITIO_EVENTS_CACHE_NAMESPACE) - events = [Event(**e) for e in events_to_return] - return events - except TypeError: - return [] - - -class UWSGIMetricsCache(MetricsCache): - _KEY_TEMPLATE = 'metrics.{suffix}' - _METRIC_KEY = _KEY_TEMPLATE.format(suffix='metric') - _LATENCY_KEY = _KEY_TEMPLATE.format(suffix='latency') - _KEY_LATENCY_BUCKET = 'latency.{metric_name}.bucket.{bucket_number}' - _COUNT_FIELD_TEMPLATE = 'count.{counter}' - _TIME_FIELD_TEMPLATE = 'time.{operation}.{bucket_index}' - _GAUGE_FIELD_TEMPLATE = 'gauge.{gauge}' - - _LATENCY_FIELD_RE = re.compile(r'^latency\.(?P<operation>.+)\.bucket\.(?P<bucket_index>.+)$') - _COUNT_FIELD_RE = re.compile(r'^count\.(?P<counter>.+)$') - _TIME_FIELD_RE = re.compile(r'^time\.(?P<operation>.+)\.(?P<bucket_index>.+)$') - _GAUGE_FIELD_RE = re.compile(r'^gauge\.(?P<gauge>.+)$') - - def __init__(self, adapter, disabled_period=300): - """A MetricsCache implementation that uses uWSGI as its back-end - :param adapter: The uwsgi module - :type adapter: module - :param disabled_period: The expiration period for the disabled key. - :type disabled_period: int - """ - super(UWSGIMetricsCache, self).__init__() - self._adapter = adapter - self._disabled_period = disabled_period - - @property - def disabled_period(self): - return self._disabled_period - - @disabled_period.setter - def disabled_period(self, disabled_period): - self._disabled_period = disabled_period - - def enable(self): - """Enables the automatic metrics report process and the registration of new metrics.""" - pass - - def disable(self): - """Disables the automatic metrics report process and the registration of any - metrics for the specified disabled period. This method will be called if there's an - exception while trying to send the metrics back to Split.""" - pass - - def is_enabled(self): - """ - :return: Whether the automatic report process and metrics registration are enabled. - :rtype: bool - """ - return True - - def _get_count_field(self, counter): - """Builds the field name for a counter on the metrics. - :param counter: Name of the counter - :type counter: str - :return: Name of the field on the metrics hash for the given counter - :rtype: str - """ - return self._COUNT_FIELD_TEMPLATE.format(counter=counter) - - def _get_time_field(self, operation, bucket_index): - """Builds the field name for a latency counting bucket on the metrics. 
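- e.g. (illustrative) _get_time_field('sdk.getTreatment', 5) -> 'time.sdk.getTreatment.5' 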
- :param operation: Name of the operation - :type operation: str - :param bucket_index: Latency bucket index as returned by get_latency_bucket_index - :type bucket_index: int - :return: Name of the field on the metrics hash for the latency bucket counter - :rtype: str - """ - return self._TIME_FIELD_TEMPLATE.format(operation=operation, - bucket_index=bucket_index) - - def _get_all_buckets_time_fields(self, operation): - """Builds a list of all the fields in the metrics hash for the latency buckets for a given - operation. - :param operation: Name of the operation - :type operation: str - :return: List of field names - :rtype: list - """ - return [self._get_time_field(operation, bucket) for bucket in range(0, len(BUCKETS))] - - def _get_gauge_field(self, gauge): - """Builds the field name for a gauge on the metrics hash. - :param gauge: Name of the gauge - :type gauge: str - :return: Name of the field on the metrics hash for the given gauge - :rtype: str - """ - return self._GAUGE_FIELD_TEMPLATE.format(gauge=gauge) - - def _build_metrics_counter_data(self, count_metrics): - """Build metrics counter data in the format expected by the API from the contents of the - cache. - :param count_metrics: A dictionary of name/value counter metrics - :type count_metrics: dict - :return: A list of counter metrics - :rtype: list - """ - return [{'name': name, 'delta': delta} for name, delta in iteritems(count_metrics)] - - def _build_metrics_times_data(self, time_metrics): - """Build metrics times data in the format expected by the API from the contents of the - cache. - :param time_metrics: A dictionary of name/latencies time metrics - :type time_metrics: dict - :return: A list of time metrics - :rtype: list - """ - to_return = [{'name': name, 'latencies': latencies} - for name, latencies in iteritems(time_metrics)] - return to_return - - def _build_metrics_gauge_data(self, gauge_metrics): - """Build metrics gauge data in the format expected by the API from the contents of the - cache. - :param gauge_metrics: A dictionary of name/value gauge metrics - :type gauge_metrics: dict - :return: A list of gauge metrics - :rtype: list - """ - return [{'name': name, 'value': value} for name, value in iteritems(gauge_metrics)] - - def _build_metrics_from_cache_response(self, response): - """Builds a dictionary with time, count and gauge metrics based on the result of calling - fetch_all_and_clear (list of name/value pairs). Each entry in the dictionary is in the - format accepted by the events API. 
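- e.g. (illustrative, matching the uwsgi cache tests): {'count': [{'name': 'some_count', 'delta': 20}], 'gauge': [{'name': 'some_gauge', 'value': 10}]} 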
- :param response: Response given by the fetch_all_and_clear method - :type response: list - :return: Dictionary with time, count and gauge metrics - :rtype: dict - """ - if response is None: - return {'count': [], 'gauge': []} - - count = dict() - gauge = dict() - - for field, value in response.items(): - count_match = self._COUNT_FIELD_RE.match(field) - if count_match is not None: - count[count_match.group('counter')] = value - continue - - gauge_match = self._GAUGE_FIELD_RE.match(field) - if gauge_match is not None: - gauge[gauge_match.group('gauge')] = value - continue - - return { - 'count': self._build_metrics_counter_data(count), - 'gauge': self._build_metrics_gauge_data(gauge) - } - - - def _get_metric(self, field_name): - try: - metrics = decode(self._adapter.cache_get(self._METRIC_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE)) - if field_name in metrics: - return metrics[field_name] - except TypeError: - pass - - return None - - def _set_metric(self, field_name, value): - try: - metrics = decode(self._adapter.cache_get(self._METRIC_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE)) - except TypeError: - metrics = dict() - - metrics[field_name] = value - self._adapter.cache_update(self._METRIC_KEY, encode(metrics), 0, _SPLITIO_METRICS_CACHE_NAMESPACE) - _logger.debug(metrics) - - def get_latency(self, operation): - _latencies = [] - try: - latencies = decode(self._adapter.cache_get(self._LATENCY_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE)) - for bucket in range(0, len(BUCKETS)): - _key = self._KEY_LATENCY_BUCKET.format(metric_name=operation, bucket_number=bucket) - if _key in latencies: - _latencies.append(latencies[_key]) - else: - _latencies.append(0) - return _latencies - except TypeError: - return [0 for bucket in range(0, len(BUCKETS))] - - def get_latency_bucket_counter(self, operation, bucket_index): - try: - latencies = decode(self._adapter.cache_get(self._LATENCY_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE)) - _key = self._KEY_LATENCY_BUCKET.format(metric_name=operation, bucket_number=bucket_index) - if _key in latencies: - return latencies[_key] - except TypeError: - pass - return 0 - - def set_latency_bucket_counter(self, operation, bucket_index, value): - try: - latencies = decode(self._adapter.cache_get(self._LATENCY_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE)) - except TypeError: - latencies = {} - latencies[self._KEY_LATENCY_BUCKET.format(metric_name=operation, bucket_number=bucket_index)] = value - self._adapter.cache_update(self._LATENCY_KEY, encode(latencies), 0, _SPLITIO_METRICS_CACHE_NAMESPACE) - - def increment_latency_bucket_counter(self, operation, bucket_index, delta=1): - latency = self.get_latency_bucket_counter(operation, bucket_index) - self.set_latency_bucket_counter(operation, bucket_index, latency + delta) - - def set_count(self, counter, value): - metric_field = self._get_count_field(counter) - self._set_metric(metric_field, value) - - def get_count(self, counter): - value = self._get_metric(self._get_count_field(counter)) - if value is not None: - return value - return 0 - - def increment_count(self, counter, delta=1): - counter_value = self.get_count(counter) + delta - self.set_count(counter, counter_value) - - def set_gauge(self, gauge, value): - gauge_field = self._get_gauge_field(gauge) - self._set_metric(gauge_field, value) - - def get_gauge(self, gauge): - value = self._get_metric(self._get_gauge_field(gauge)) - if value is not None: - return value - return 0
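The counter helpers above are a plain read-modify-write over a single jsonpickled hash; a minimal sketch of the flow the uwsgi cache tests exercise (emulator construction as in those tests):

    cache = UWSGIMetricsCache(get_uwsgi(emulator=True))
    cache.set_count('some_counter', 2)
    cache.increment_count('some_counter')   # get_count + delta, then set_count
    assert cache.get_count('some_counter') == 3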
- def fetch_all_and_clear(self): - try: - metrics = decode(self._adapter.cache_get(self._METRIC_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE)) - return self._build_metrics_from_cache_response(metrics) - except TypeError: - return self._build_metrics_from_cache_response(None) - - def fetch_all_times_and_clear(self): - try: - latencies = decode(self._adapter.cache_get(self._LATENCY_KEY, _SPLITIO_METRICS_CACHE_NAMESPACE)) - time = defaultdict(lambda: [0] * len(BUCKETS)) - for key in latencies: - time_match = self._LATENCY_FIELD_RE.match(key) - if time_match is not None: - time[time_match.group('operation')][int(time_match.group('bucket_index'))] = int(latencies[key]) - latencies[key] = 0 - self._adapter.cache_update(self._LATENCY_KEY, encode(latencies), 0, _SPLITIO_METRICS_CACHE_NAMESPACE) - return self._build_metrics_times_data(time) - except TypeError: - return self._build_metrics_times_data({}) - - -class UWSGICacheEmulator(object): - def __init__(self): - """ - UWSGI Cache Emulator for unit tests. Implements the uwsgi cache framework interface - http://uwsgi-docs.readthedocs.io/en/latest/Caching.html#accessing-the-cache-from-your-applications-using-the-cache-api - """ - self._cache = dict() - - def _check_string_data_type(self, value): - if type(value).__name__ == 'str': - return True - raise TypeError('The value to add into uWSGI cache must be string and %s given' % type(value).__name__) - - def cache_get(self, key, cache_namespace='default'): - if self.cache_exists(key, cache_namespace): - return self._cache[cache_namespace][key] - return None - - def cache_set(self, key, value, expires=0, cache_namespace='default'): - self._check_string_data_type(value) - - if cache_namespace in self._cache: - self._cache[cache_namespace][key] = value - else: - self._cache[cache_namespace] = {key: value} - - def cache_update(self, key, value, expires=0, cache_namespace='default'): - self.cache_set(key, value, expires, cache_namespace) - - def cache_exists(self, key, cache_namespace='default'): - if cache_namespace in self._cache: - if key in self._cache[cache_namespace]: - return True - return False - - def cache_del(self, key, cache_namespace='default'): - if cache_namespace in self._cache: - self._cache[cache_namespace].pop(key, None) - - def cache_clear(self, cache_namespace='default'): - self._cache.pop(cache_namespace, None) - - -def get_uwsgi(emulator=False): - """Returns the uwsgi module, or an emulator to use in unit tests""" - if emulator: - return UWSGICacheEmulator() - - return uwsgi diff --git a/splitio/version.py b/splitio/version.py index 2f21dd16..775951e9 100644 --- a/splitio/version.py +++ b/splitio/version.py @@ -1 +1 @@ -__version__ = '7.0.1' +__version__ = '8.0.0-rc1' diff --git a/tests/api/test_events.py b/tests/api/test_events.py new file mode 100644 index 00000000..0947c4ed --- /dev/null +++ b/tests/api/test_events.py @@ -0,0 +1,57 @@ +"""Events API tests module.""" + +import pytest +from splitio.api import events, client, APIException +from splitio.models.events import Event +from splitio.client.util import SdkMetadata + + +class EventsAPITests(object): + """Events API test cases.""" + + def test_post_events(self, mocker): + """Test events posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '') + sdk_metadata = SdkMetadata('python-1.2.3', 'some_machine_name', '123.123.123.123') + events_api = events.EventsAPI(httpclient, 'some_api_key', sdk_metadata) + response = events_api.flush_events([ + Event('k1', 'user', 'purchase', 12.50, 123456), + Event('k2', 'user', 'purchase', 12.50, 123456), + 
Event('k3', 'user', 'purchase', None, 123456), + Event('k4', 'user', 'purchase', None, 123456) + ]) + + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('events', '/events/bulk', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-1.2.3', + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == [ + {'key': 'k1', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': 12.50, 'timestamp': 123456}, + {'key': 'k2', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': 12.50, 'timestamp': 123456}, + {'key': 'k3', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': None, 'timestamp': 123456}, + {'key': 'k4', 'trafficTypeName': 'user', 'eventTypeId': 'purchase', 'value': None, 'timestamp': 123456} + ] + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message', Exception('something')) + httpclient.post.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = events_api.flush_events([ + Event('k1', 'user', 'purchase', 12.50, 123456), + Event('k2', 'user', 'purchase', 12.50, 123456), + Event('k3', 'user', 'purchase', None, 123456), + Event('k4', 'user', 'purchase', None, 123456) + ]) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git a/tests/api/test_httpclient.py b/tests/api/test_httpclient.py new file mode 100644 index 00000000..694c9a22 --- /dev/null +++ b/tests/api/test_httpclient.py @@ -0,0 +1,139 @@ +"""HTTPClient test module.""" + +from splitio.api import client + +class HttpClientTests(object): + """Http Client test cases.""" + + def test_get(self, mocker): + """Test HTTP GET verb requests.""" + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.api.client.requests.get', new=get_mock) + httpclient = client.HttpClient() + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + client.HttpClient.SDK_URL + '/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + get_mock.reset_mock() + + response = httpclient.get('events', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + client.HttpClient.EVENTS_URL + '/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert get_mock.mock_calls == [call] + assert response.status_code == 200 + assert response.body == 'ok' + + def test_get_custom_urls(self, mocker): + """Test HTTP GET verb requests.""" + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + get_mock = mocker.Mock() + get_mock.return_value = response_mock + mocker.patch('splitio.api.client.requests.get', new=get_mock) + httpclient = client.HttpClient(sdk_url='https://sdk.com', events_url='https://events.com') + response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 
'https://sdk.com/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert get_mock.mock_calls == [call] + assert response.status_code == 200 + assert response.body == 'ok' + get_mock.reset_mock() + + response = httpclient.get('events', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://events.com/test1', + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert get_mock.mock_calls == [call] + + + def test_post(self, mocker): + """Test HTTP POST verb requests.""" + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + post_mock = mocker.Mock() + post_mock.return_value = response_mock + mocker.patch('splitio.api.client.requests.post', new=post_mock) + httpclient = client.HttpClient() + response = httpclient.post('sdk', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + client.HttpClient.SDK_URL + '/test1', + json={'p1': 'a'}, + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert post_mock.mock_calls == [call] + post_mock.reset_mock() + + response = httpclient.post('events', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + client.HttpClient.EVENTS_URL + '/test1', + json={'p1': 'a'}, + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert post_mock.mock_calls == [call] + + def test_post_custom_urls(self, mocker): + """Test HTTP POST verb requests.""" + response_mock = mocker.Mock() + response_mock.status_code = 200 + response_mock.text = 'ok' + post_mock = mocker.Mock() + post_mock.return_value = response_mock + mocker.patch('splitio.api.client.requests.post', new=post_mock) + httpclient = client.HttpClient(sdk_url='https://sdk.com', events_url='https://events.com') + response = httpclient.post('sdk', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://sdk.com' + '/test1', + json={'p1': 'a'}, + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert post_mock.mock_calls == [call] + post_mock.reset_mock() + + response = httpclient.post('events', '/test1', 'some_api_key', {'p1': 'a'}, {'param1': 123}, {'h1': 'abc'}) + call = mocker.call( + 'https://events.com' + '/test1', + json={'p1': 'a'}, + headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'}, + params={'param1': 123}, + timeout=None + ) + assert response.status_code == 200 + assert response.body == 'ok' + assert post_mock.mock_calls == [call] diff --git a/tests/api/test_impressions_api.py b/tests/api/test_impressions_api.py new file mode 100644 index 00000000..9bf5467e --- /dev/null +++ b/tests/api/test_impressions_api.py @@ -0,0 +1,63 @@ +"""Impressions API tests module.""" + +import pytest +from splitio.api import impressions, client, 
APIException +from splitio.models.impressions import Impression +from splitio.client.util import SdkMetadata + +class ImpressionsAPITests(object): + """Impressions API test cases.""" + + def test_post_impressions(self, mocker): + """Test impressions posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '') + sdk_metadata = SdkMetadata('python-1.2.3', 'some_machine_name', '123.123.123.123') + impressions_api = impressions.ImpressionsAPI(httpclient, 'some_api_key', sdk_metadata) + response = impressions_api.flush_impressions([ + Impression('k1', 'f1', 'on', 'l1', 123456, 'b1', 321654), + Impression('k2', 'f2', 'off', 'l1', 123456, 'b1', 321654), + Impression('k3', 'f1', 'on', 'l1', 123456, 'b1', 321654), + ]) + + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('events', '/testImpressions/bulk', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-1.2.3', + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == [ + { + 'testName': 'f1', + 'keyImpressions': [ + {'keyName': 'k1', 'bucketingKey': 'b1', 'treatment': 'on', 'label': 'l1', 'time': 321654, 'changeNumber': 123456}, + {'keyName': 'k3', 'bucketingKey': 'b1', 'treatment': 'on', 'label': 'l1', 'time': 321654, 'changeNumber': 123456}, + ], + }, + { + 'testName': 'f2', + 'keyImpressions': [ + {'keyName': 'k2', 'bucketingKey': 'b1', 'treatment': 'off', 'label': 'l1', 'time': 321654, 'changeNumber': 123456}, + ] + } + ] + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message', Exception('something')) + httpclient.post.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = impressions_api.flush_impressions([ + Impression('k1', 'f1', 'on', 'l1', 123456, 'b1', 321654), + Impression('k2', 'f2', 'off', 'l1', 123456, 'b1', 321654), + Impression('k3', 'f1', 'on', 'l1', 123456, 'b1', 321654), + ]) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git a/tests/api/test_segments_api.py b/tests/api/test_segments_api.py new file mode 100644 index 00000000..0344fd23 --- /dev/null +++ b/tests/api/test_segments_api.py @@ -0,0 +1,27 @@ +"""Segment API tests module.""" + +import pytest +from splitio.api import segments, client, APIException + + +class SegmentAPITests(object): + """Segment API test cases.""" + + def test_fetch_segment_changes(self, mocker): + """Test segment changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.get.return_value = client.HttpResponse(200, '{"prop1": "value1"}') + segment_api = segments.SegmentsAPI(httpclient, 'some_api_key') + response = segment_api.fetch_segment('some_segment', 123) + + assert response['prop1'] == 'value1' + assert httpclient.get.mock_calls == [mocker.call('sdk', '/segmentChanges/some_segment', 'some_api_key', {'since': 123})] + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message', Exception('something')) + httpclient.get.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = segment_api.fetch_segment('some_segment', 123) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git 
a/tests/api/test_splits_api.py b/tests/api/test_splits_api.py new file mode 100644 index 00000000..a3e17206 --- /dev/null +++ b/tests/api/test_splits_api.py @@ -0,0 +1,27 @@ +"""Split API tests module.""" + +import pytest +from splitio.api import splits, client, APIException + + +class SplitAPITests(object): + """Split API test cases.""" + + def test_fetch_split_changes(self, mocker): + """Test split changes fetching API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.get.return_value = client.HttpResponse(200, '{"prop1": "value1"}') + split_api = splits.SplitsAPI(httpclient, 'some_api_key') + response = split_api.fetch_splits(123) + + assert response['prop1'] == 'value1' + assert httpclient.get.mock_calls == [mocker.call('sdk', '/splitChanges', 'some_api_key', {'since': 123})] + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message', Exception('something')) + httpclient.get.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = split_api.fetch_splits(123) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git a/tests/api/test_telemetry.py b/tests/api/test_telemetry.py new file mode 100644 index 00000000..60c6f54c --- /dev/null +++ b/tests/api/test_telemetry.py @@ -0,0 +1,118 @@ +"""Telemetry API tests module.""" + +import pytest +from splitio.api import telemetry, client, APIException +from splitio.client.util import SdkMetadata + + +class TelemetryAPITests(object): + """Telemetry API test cases.""" + + def test_post_latencies(self, mocker): + """Test latencies posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '') + sdk_metadata = SdkMetadata('python-1.2.3', 'some_machine_name', '123.123.123.123') + telemetry_api = telemetry.TelemetryAPI(httpclient, 'some_api_key', sdk_metadata) + response = telemetry_api.flush_latencies({ + 'l1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + }) + + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('events', '/metrics/times', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-1.2.3', + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == [{ + 'name': 'l1', + 'latencies': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + }] + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message', Exception('something')) + httpclient.post.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = telemetry_api.flush_latencies({ + 'l1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + }) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + def test_post_counters(self, mocker): + """Test counters posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '') + sdk_metadata = SdkMetadata('python-1.2.3', 'some_machine_name', '123.123.123.123') + telemetry_api = telemetry.TelemetryAPI(httpclient, 'some_api_key', sdk_metadata) + response = telemetry_api.flush_counters({'counter1': 1, 'counter2': 2}) 
+ + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('events', '/metrics/counters', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-1.2.3', + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == [ + {'name': 'counter1', 'delta': 1}, + {'name': 'counter2', 'delta': 2} + ] + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message', Exception('something')) + httpclient.post.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = telemetry_api.flush_counters({'counter1': 1, 'counter2': 2}) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' + + def test_post_gauge(self, mocker): + """Test gauges posting API call.""" + httpclient = mocker.Mock(spec=client.HttpClient) + httpclient.post.return_value = client.HttpResponse(200, '') + sdk_metadata = SdkMetadata('python-1.2.3', 'some_machine_name', '123.123.123.123') + telemetry_api = telemetry.TelemetryAPI(httpclient, 'some_api_key', sdk_metadata) + response = telemetry_api.flush_gauges({'gauge1': 1, 'gauge2': 2}) + + call_made = httpclient.post.mock_calls[0] + + # validate positional arguments + assert call_made[1] == ('events', '/metrics/gauge', 'some_api_key') + + # validate key-value args (headers) + assert call_made[2]['extra_headers'] == { + 'SplitSDKVersion': 'python-1.2.3', + 'SplitSDKMachineIP': '123.123.123.123', + 'SplitSDKMachineName': 'some_machine_name' + } + + # validate key-value args (body) + assert call_made[2]['body'] == [ + {'name': 'gauge1', 'value': 1}, + {'name': 'gauge2', 'value': 2} + ] + + httpclient.reset_mock() + def raise_exception(*args, **kwargs): + raise client.HttpClientException('some_message', Exception('something')) + httpclient.post.side_effect = raise_exception + with pytest.raises(APIException) as exc_info: + response = telemetry_api.flush_gauges({'gauge1': 1, 'gauge2': 2}) + assert exc_info.type == APIException + assert exc_info.value.message == 'some_message' diff --git a/tests/client/test_client.py b/tests/client/test_client.py new file mode 100644 index 00000000..4486931e --- /dev/null +++ b/tests/client/test_client.py @@ -0,0 +1,189 @@ +"""SDK main client test module.""" +#pylint: disable=no-self-use,protected-access + +from splitio.client.client import Client +from splitio.client.factory import SplitFactory +from splitio.engine.evaluator import Evaluator +from splitio.models.impressions import Impression +from splitio.models.events import Event +from splitio.storage import EventStorage, ImpressionStorage, SegmentStorage, SplitStorage, \ + TelemetryStorage + +class ClientTests(object): #pylint: disable=too-few-public-methods + """Split client test cases.""" + + def test_get_treatment(self, mocker): + """Test get_treatment execution paths.""" + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + telemetry_storage = mocker.Mock(spec=TelemetryStorage) + def _get_storage_mock(name): + return { + 'splits': split_storage, + 'segments': segment_storage, + 'impressions': impression_storage, + 'events': event_storage, + 'telemetry': telemetry_storage + }[name] + + destroyed_property = 
mocker.PropertyMock() + destroyed_property.return_value = False + + factory = mocker.Mock(spec=SplitFactory) + factory._get_storage.side_effect = _get_storage_mock + type(factory).destroyed = destroyed_property + + mocker.patch('splitio.client.client.time.time', new=lambda: 1) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + client = Client(factory, True, None) + client._evaluator = mocker.Mock(spec=Evaluator) + client._evaluator.evaluate_treatment.return_value = { + 'treatment': 'on', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + } + } + client._logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + + assert client.get_treatment('some_key', 'some_feature') == 'on' + assert mocker.call( + [Impression('some_key', 'some_feature', 'on', 'some_label', 123, None, 1000)] + ) in impression_storage.put.mock_calls + assert mocker.call('sdk.getTreatment', 5) in telemetry_storage.inc_latency.mock_calls + assert client._logger.mock_calls == [] + assert mocker.call( + Impression('some_key', 'some_feature', 'on', 'some_label', 123, None, 1000), + None + ) in client._send_impression_to_listener.mock_calls + + # Test with exception: + split_storage.get_change_number.return_value = -1 + def _raise(*_): + raise Exception('something') + client._evaluator.evaluate_treatment.side_effect = _raise + assert client.get_treatment('some_key', 'some_feature') == 'control' + assert mocker.call( + [Impression('some_key', 'some_feature', 'control', 'exception', -1, None, 1000)] + ) in impression_storage.put.mock_calls + assert len(telemetry_storage.inc_latency.mock_calls) == 2 + + def test_get_treatments(self, mocker): + """Test get_treatments execution paths.""" + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + telemetry_storage = mocker.Mock(spec=TelemetryStorage) + def _get_storage_mock(name): + return { + 'splits': split_storage, + 'segments': segment_storage, + 'impressions': impression_storage, + 'events': event_storage, + 'telemetry': telemetry_storage + }[name] + + destroyed_property = mocker.PropertyMock() + destroyed_property.return_value = False + + factory = mocker.Mock(spec=SplitFactory) + factory._get_storage.side_effect = _get_storage_mock + type(factory).destroyed = destroyed_property + + mocker.patch('splitio.client.client.time.time', new=lambda: 1) + mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5) + + client = Client(factory, True, None) + client._evaluator = mocker.Mock(spec=Evaluator) + client._evaluator.evaluate_treatment.return_value = { + 'treatment': 'on', + 'impression': { + 'label': 'some_label', + 'change_number': 123 + } + } + client._logger = mocker.Mock() + client._send_impression_to_listener = mocker.Mock() + assert client.get_treatments('key', ['f1', 'f2']) == {'f1': 'on', 'f2': 'on'} + + impressions_called = impression_storage.put.mock_calls[0][1][0] + assert Impression('key', 'f1', 'on', 'some_label', 123, None, 1000) in impressions_called + assert Impression('key', 'f2', 'on', 'some_label', 123, None, 1000) in impressions_called + assert mocker.call('sdk.getTreatments', 5) in telemetry_storage.inc_latency.mock_calls + assert client._logger.mock_calls == [] + assert mocker.call( + Impression('key', 'f1', 'on', 'some_label', 123, None, 1000), + None + ) in client._send_impression_to_listener.mock_calls + assert 
mocker.call( + Impression('key', 'f2', 'on', 'some_label', 123, None, 1000), + None + ) in client._send_impression_to_listener.mock_calls + + # Test with exception: + split_storage.get_change_number.return_value = -1 + def _raise(*_): + raise Exception('something') + client._evaluator.evaluate_treatment.side_effect = _raise + assert client.get_treatments('key', ['f1', 'f2']) == {'f1': 'control', 'f2': 'control'} + assert len(telemetry_storage.inc_latency.mock_calls) == 2 + + + def test_destroy(self, mocker): + """Test that destroy/destroyed calls are forwarded to the factory.""" + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + telemetry_storage = mocker.Mock(spec=TelemetryStorage) + def _get_storage_mock(name): + return { + 'splits': split_storage, + 'segments': segment_storage, + 'impressions': impression_storage, + 'events': event_storage, + 'telemetry': telemetry_storage + }[name] + factory = mocker.Mock(spec=SplitFactory) + destroyed_mock = mocker.PropertyMock() + type(factory).destroyed = destroyed_mock + + client = Client(factory) + client.destroy() + assert factory.destroy.mock_calls == [mocker.call()] + assert client.destroyed is not None + assert destroyed_mock.mock_calls == [mocker.call()] + + def test_track(self, mocker): + """Test that track calls are forwarded to the events storage.""" + split_storage = mocker.Mock(spec=SplitStorage) + segment_storage = mocker.Mock(spec=SegmentStorage) + impression_storage = mocker.Mock(spec=ImpressionStorage) + event_storage = mocker.Mock(spec=EventStorage) + event_storage.put.return_value = True + telemetry_storage = mocker.Mock(spec=TelemetryStorage) + def _get_storage_mock(name): + return { + 'splits': split_storage, + 'segments': segment_storage, + 'impressions': impression_storage, + 'events': event_storage, + 'telemetry': telemetry_storage + }[name] + factory = mocker.Mock(spec=SplitFactory) + factory._get_storage = _get_storage_mock + destroyed_mock = mocker.PropertyMock() + destroyed_mock.return_value = False + type(factory).destroyed = destroyed_mock + mocker.patch('splitio.client.client.time.time', new=lambda: 1) + + client = Client(factory) + assert client.track('key', 'user', 'purchase', 12) is True + assert mocker.call([ + Event('key', 'user', 'purchase', 12, 1000) + ]) in event_storage.put.mock_calls diff --git a/tests/client/test_factory.py b/tests/client/test_factory.py new file mode 100644 index 00000000..2030d084 --- /dev/null +++ b/tests/client/test_factory.py @@ -0,0 +1,324 @@ +"""Split factory test module.""" +#pylint: disable=no-self-use,protected-access + +import time +import threading +from splitio.client.factory import get_factory +from splitio.client.config import DEFAULT_CONFIG +from splitio.storage import redis, inmemmory, uwsgi +from splitio.tasks import events_sync, impressions_sync, split_sync, segment_sync, telemetry_sync +from splitio.tasks.util import asynctask, workerpool +from splitio.api.splits import SplitsAPI +from splitio.api.segments import SegmentsAPI +from splitio.api.impressions import ImpressionsAPI +from splitio.api.events import EventsAPI +from splitio.api.telemetry import TelemetryAPI + + +class SplitFactoryTests(object): + """Split factory test cases.""" + + def test_inmemory_client_creation(self, mocker): + """Test that a client with in-memory storage is created correctly.""" + # Setup task mocks + def 
_split_task_init_mock(self, api, storage, period, event): + self._task = mocker.Mock() + self._api = api + self._storage = storage + self._period = period + self._event = event + event.set() + mocker.patch('splitio.client.factory.SplitSynchronizationTask.__init__', new=_split_task_init_mock) + def _segment_task_init_mock(self, api, storage, split_storage, period, event): + self._task = mocker.Mock() + self._worker_pool = mocker.Mock() + self._api = api + self._segment_storage = storage + self._split_storage = split_storage + self._period = period + self._event = event + event.set() + mocker.patch('splitio.client.factory.SegmentSynchronizationTask.__init__', new=_segment_task_init_mock) + + # Start factory and make assertions + factory = get_factory('some_api_key') + assert isinstance(factory._storages['splits'], inmemmory.InMemorySplitStorage) + assert isinstance(factory._storages['segments'], inmemmory.InMemorySegmentStorage) + assert isinstance(factory._storages['impressions'], inmemmory.InMemoryImpressionStorage) + assert factory._storages['impressions']._impressions.maxsize == 10000 + assert isinstance(factory._storages['events'], inmemmory.InMemoryEventStorage) + assert factory._storages['events']._events.maxsize == 10000 + assert isinstance(factory._storages['telemetry'], inmemmory.InMemoryTelemetryStorage) + + assert isinstance(factory._apis['splits'], SplitsAPI) + assert factory._apis['splits']._client._timeout == 1.5 + assert isinstance(factory._apis['segments'], SegmentsAPI) + assert factory._apis['segments']._client._timeout == 1.5 + assert isinstance(factory._apis['impressions'], ImpressionsAPI) + assert factory._apis['impressions']._client._timeout == 1.5 + assert isinstance(factory._apis['events'], EventsAPI) + assert factory._apis['events']._client._timeout == 1.5 + assert isinstance(factory._apis['telemetry'], TelemetryAPI) + assert factory._apis['telemetry']._client._timeout == 1.5 + + assert isinstance(factory._tasks['splits'], split_sync.SplitSynchronizationTask) + assert factory._tasks['splits']._period == DEFAULT_CONFIG['featuresRefreshRate'] + assert factory._tasks['splits']._storage == factory._storages['splits'] + assert factory._tasks['splits']._api == factory._apis['splits'] + assert isinstance(factory._tasks['segments'], segment_sync.SegmentSynchronizationTask) + assert factory._tasks['segments']._period == DEFAULT_CONFIG['segmentsRefreshRate'] + assert factory._tasks['segments']._segment_storage == factory._storages['segments'] + assert factory._tasks['segments']._split_storage == factory._storages['splits'] + assert factory._tasks['segments']._api == factory._apis['segments'] + assert isinstance(factory._tasks['impressions'], impressions_sync.ImpressionsSyncTask) + assert factory._tasks['impressions']._period == DEFAULT_CONFIG['impressionsRefreshRate'] + assert factory._tasks['impressions']._storage == factory._storages['impressions'] + assert factory._tasks['impressions']._impressions_api == factory._apis['impressions'] + assert isinstance(factory._tasks['events'], events_sync.EventsSyncTask) + assert factory._tasks['events']._period == DEFAULT_CONFIG['eventsPushRate'] + assert factory._tasks['events']._storage == factory._storages['events'] + assert factory._tasks['events']._events_api == factory._apis['events'] + assert isinstance(factory._tasks['telemetry'], telemetry_sync.TelemetrySynchronizationTask) + assert factory._tasks['telemetry']._period == DEFAULT_CONFIG['metricsRefreshRate'] + assert factory._tasks['telemetry']._storage == 
factory._storages['telemetry'] + assert factory._tasks['telemetry']._api == factory._apis['telemetry'] + assert factory._labels_enabled is True + + def test_redis_client_creation(self, mocker): + """Test that a client with redis storage is created correctly.""" + strict_redis_mock = mocker.Mock() + mocker.patch('splitio.storage.adapters.redis.StrictRedis', new=strict_redis_mock) + + config = { + 'labelsEnabled': False, + 'impressionListener': 123, + 'redisHost': 'some_host', + 'redisPort': 1234, + 'redisDb': 1, + 'redisPassword': 'some_password', + 'redisSocketTimeout': 123, + 'redisSocketConnectTimeout': 123, + 'redisSocketKeepalive': 123, + 'redisSocketKeepaliveOptions': False, + 'redisConnectionPool': False, + 'redisUnixSocketPath': '/some_path', + 'redisEncoding': 'ascii', + 'redisEncodingErrors': 'non-strict', + 'redisCharset': 'ascii', + 'redisErrors': True, + 'redisDecodeResponses': True, + 'redisRetryOnTimeout': True, + 'redisSsl': True, + 'redisSslKeyfile': 'some_file', + 'redisSslCertfile': 'some_cert_file', + 'redisSslCertReqs': 'some_cert_req', + 'redisSslCaCerts': 'some_ca_cert', + 'redisMaxConnections': 999, + } + factory = get_factory('some_api_key', config=config) + assert isinstance(factory._get_storage('splits'), redis.RedisSplitStorage) + assert isinstance(factory._get_storage('segments'), redis.RedisSegmentStorage) + assert isinstance(factory._get_storage('impressions'), redis.RedisImpressionsStorage) + assert isinstance(factory._get_storage('events'), redis.RedisEventsStorage) + assert isinstance(factory._get_storage('telemetry'), redis.RedisTelemetryStorage) + + assert factory._apis == {} + assert factory._tasks == {} + + adapter = factory._get_storage('splits')._redis + assert adapter == factory._get_storage('segments')._redis + assert adapter == factory._get_storage('impressions')._redis + assert adapter == factory._get_storage('events')._redis + assert adapter == factory._get_storage('telemetry')._redis + + assert strict_redis_mock.mock_calls == [mocker.call( + host='some_host', + port=1234, + db=1, + password='some_password', + socket_timeout=123, + socket_connect_timeout=123, + socket_keepalive=123, + socket_keepalive_options=False, + connection_pool=False, + unix_socket_path='/some_path', + encoding='ascii', + encoding_errors='non-strict', + charset='ascii', + errors=True, + decode_responses=True, + retry_on_timeout=True, + ssl=True, + ssl_keyfile='some_file', + ssl_certfile='some_cert_file', + ssl_cert_reqs='some_cert_req', + ssl_ca_certs='some_ca_cert', + max_connections=999 + )] + assert factory._labels_enabled is False + assert factory._impression_listener == 123 + + + def test_uwsgi_client_creation(self): + """Test that a client with uwsgi storage is created correctly.""" + factory = get_factory('some_api_key', config={'uwsgiCache': True, 'impressionListener': 123}) + assert isinstance(factory._get_storage('splits'), uwsgi.UWSGISplitStorage) + assert isinstance(factory._get_storage('segments'), uwsgi.UWSGISegmentStorage) + assert isinstance(factory._get_storage('impressions'), uwsgi.UWSGIImpressionStorage) + assert isinstance(factory._get_storage('events'), uwsgi.UWSGIEventStorage) + assert isinstance(factory._get_storage('telemetry'), uwsgi.UWSGITelemetryStorage) + assert factory._apis == {} + assert factory._tasks == {} + assert factory._labels_enabled is True + assert factory._impression_listener == 123 + + def test_destroy(self, mocker): + """Test that tasks are shut down and data is flushed when destroy is called.""" + def _split_task_init_mock(self, 
api, storage, period, event): + self._task = mocker.Mock() + self._api = api + self._storage = storage + self._period = period + self._event = event + event.set() + mocker.patch('splitio.client.factory.SplitSynchronizationTask.__init__', new=_split_task_init_mock) + + def _segment_task_init_mock(self, api, storage, split_storage, period, event): + self._task = mocker.Mock() + self._worker_pool = mocker.Mock() + self._api = api + self._segment_storage = storage + self._split_storage = split_storage + self._period = period + self._event = event + event.set() + mocker.patch('splitio.client.factory.SegmentSynchronizationTask.__init__', new=_segment_task_init_mock) + + imp_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + def _impression_task_init_mock(self, api, storage, refresh_rate, bulk_size): + self._logger = mocker.Mock() + self._impressions_api = api + self._storage = storage + self._period = refresh_rate + self._task = imp_async_task_mock + self._failed = mocker.Mock() + self._bulk_size = bulk_size + mocker.patch('splitio.client.factory.ImpressionsSyncTask.__init__', new=_impression_task_init_mock) + + evt_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + def _event_task_init_mock(self, api, storage, refresh_rate, bulk_size): + self._logger = mocker.Mock() + self._events_api = api + self._storage = storage + self._period = refresh_rate + self._task = evt_async_task_mock + self._failed = mocker.Mock() + self._bulk_size = bulk_size + mocker.patch('splitio.client.factory.EventsSyncTask.__init__', new=_event_task_init_mock) + + # Start factory and make assertions + factory = get_factory('some_api_key') + + assert factory.destroyed is False + + factory.destroy() + assert imp_async_task_mock.stop.mock_calls == [mocker.call(None)] + assert evt_async_task_mock.stop.mock_calls == [mocker.call(None)] + assert factory.destroyed is True + + def test_destroy_with_event(self, mocker): + """Test that tasks are shut down and data is flushed when destroy is called.""" + spl_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + def _split_task_init_mock(self, api, storage, period, event): + self._task = spl_async_task_mock + self._api = api + self._storage = storage + self._period = period + self._event = event + event.set() + mocker.patch('splitio.client.factory.SplitSynchronizationTask.__init__', new=_split_task_init_mock) + + sgm_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + worker_pool_mock = mocker.Mock(spec=workerpool.WorkerPool) + def _segment_task_init_mock(self, api, storage, split_storage, period, event): + self._task = sgm_async_task_mock + self._worker_pool = worker_pool_mock + self._api = api + self._segment_storage = storage + self._split_storage = split_storage + self._period = period + self._event = event + event.set() + mocker.patch('splitio.client.factory.SegmentSynchronizationTask.__init__', new=_segment_task_init_mock) + + imp_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + def _impression_task_init_mock(self, api, storage, refresh_rate, bulk_size): + self._logger = mocker.Mock() + self._impressions_api = api + self._storage = storage + self._period = refresh_rate + self._task = imp_async_task_mock + self._failed = mocker.Mock() + self._bulk_size = bulk_size + mocker.patch('splitio.client.factory.ImpressionsSyncTask.__init__', new=_impression_task_init_mock) + + evt_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + def _event_task_init_mock(self, api, storage, refresh_rate, bulk_size): + self._logger = mocker.Mock() + 
self._events_api = api + self._storage = storage + self._period = refresh_rate + self._task = evt_async_task_mock + self._failed = mocker.Mock() + self._bulk_size = bulk_size + mocker.patch('splitio.client.factory.EventsSyncTask.__init__', new=_event_task_init_mock) + + tmt_async_task_mock = mocker.Mock(spec=asynctask.AsyncTask) + def _telemetry_task_init_mock(self, api, storage, refresh_rate): + self._task = tmt_async_task_mock + self._logger = mocker.Mock() + self._api = api + self._storage = storage + self._period = refresh_rate + mocker.patch('splitio.client.factory.TelemetrySynchronizationTask.__init__', new=_telemetry_task_init_mock) + + # Start factory and make assertions + factory = get_factory('some_api_key') + + assert factory.destroyed is False + + event = threading.Event() + factory.destroy(event) + + # When destroy is called, an event is created and passed to each task when + # stop() is called. We extract those events, assert their type, and verify that + # setting them causes the main event to be set. + splits_event = spl_async_task_mock.stop.mock_calls[0][1][0] + segments_event = worker_pool_mock.stop.mock_calls[0][1][0] # Segment task stops when wp finishes. + impressions_event = imp_async_task_mock.stop.mock_calls[0][1][0] + events_event = evt_async_task_mock.stop.mock_calls[0][1][0] + telemetry_event = tmt_async_task_mock.stop.mock_calls[0][1][0] + + # python2 & 3 compatibility + try: + from threading import _Event as __EVENT_CLASS + except ImportError: + from threading import Event as __EVENT_CLASS + + assert isinstance(splits_event, __EVENT_CLASS) + assert isinstance(segments_event, __EVENT_CLASS) + assert isinstance(impressions_event, __EVENT_CLASS) + assert isinstance(events_event, __EVENT_CLASS) + assert isinstance(telemetry_event, __EVENT_CLASS) + assert not event.is_set() + + splits_event.set() + segments_event.set() + impressions_event.set() + events_event.set() + telemetry_event.set() + + time.sleep(1) # I/O wait to trigger context switch, to give the waiting thread + # a chance to run and set the main event. 
+ + assert event.is_set() diff --git a/tests/client/test_input_validator.py b/tests/client/test_input_validator.py new file mode 100644 index 00000000..fe64c2a8 --- /dev/null +++ b/tests/client/test_input_validator.py @@ -0,0 +1,597 @@ +"""Unit tests for the input_validator module.""" +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +from splitio.client.factory import SplitFactory +from splitio.client.client import CONTROL, Client +from splitio.client.manager import SplitManager +from splitio.client.key import Key +from splitio.storage import SplitStorage, EventStorage, ImpressionStorage, TelemetryStorage, \ + SegmentStorage +from splitio.models.splits import Split, SplitView +from splitio.models.grammar.condition import Condition +from splitio.models.grammar.partitions import Partition +from splitio.client import input_validator + + +class ClientInputValidationTests(object): + """Input validation test cases.""" + + def test_get_treatment(self, mocker): + """Test get_treatment validation.""" + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + storage_mock = mocker.Mock(spec=SplitStorage) + storage_mock.get.return_value = split_mock + + def _get_storage_mock(storage): + return { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + 'telemetry': mocker.Mock(spec=TelemetryStorage) + }[storage] + factory_mock = mocker.Mock(spec=SplitFactory) + factory_mock._get_storage.side_effect = _get_storage_mock + factory_destroyed = mocker.PropertyMock() + factory_destroyed.return_value = False + type(factory_mock).destroyed = factory_destroyed + + client = Client(factory_mock) + client._logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=client._logger) + + assert client.get_treatment(None, 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed a null key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment('', 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an empty key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + key = ''.join('a' for _ in range(0,255)) + assert client.get_treatment(key, 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: key too long - must be 250 characters or less.') + ] + + client._logger.reset_mock() + assert client.get_treatment(12345, 'some_feature') == 'default_treatment' + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment: key 12345 is not of type string, converting.') + ] + + client._logger.reset_mock() + assert client.get_treatment(float('nan'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(float('inf'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid key, key must be a 
non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(True, 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment([], 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment('some_key', None) == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed a null feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment('some_key', 123) == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment('some_key', True) == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment('some_key', []) == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment('some_key', '') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an empty feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment('some_key', 'some_feature') == 'default_treatment' + assert client._logger.error.mock_calls == [] + assert client._logger.warning.mock_calls == [] + + client._logger.reset_mock() + assert client.get_treatment(Key(None, 'bucketing_key'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed a null matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key('', 'bucketing_key'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an empty matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key(float('nan'), 'bucketing_key'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key(float('inf'), 'bucketing_key'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key(True, 'bucketing_key'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key([], 'bucketing_key'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid 
matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key(12345, 'bucketing_key'), 'some_feature') == 'default_treatment' + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment: matching_key 12345 is not of type string, converting.') + ] + + client._logger.reset_mock() + key = ''.join('a' for _ in range(0,255)) + assert client.get_treatment(Key(key, 'bucketing_key'), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: matching_key too long - must be 250 characters or less.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key('matching_key', None), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed a null bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key('matching_key', True), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key('matching_key', []), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an invalid bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key('matching_key', ''), 'some_feature') == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: you passed an empty bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment(Key('matching_key', 12345), 'some_feature') == 'default_treatment' + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment: bucketing_key 12345 is not of type string, converting.') + ] + + client._logger.reset_mock() + assert client.get_treatment('matching_key', 'some_feature', True) == CONTROL + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment: attributes must be of type dictionary.') + ] + + client._logger.reset_mock() + assert client.get_treatment('matching_key', 'some_feature', {'test': 'test'}) == 'default_treatment' + assert client._logger.error.mock_calls == [] + + client._logger.reset_mock() + assert client.get_treatment('matching_key', 'some_feature', None) == 'default_treatment' + assert client._logger.error.mock_calls == [] + + client._logger.reset_mock() + assert client.get_treatment('matching_key', ' some_feature ', None) == 'default_treatment' + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment: feature_name \' some_feature \' has extra whitespace, trimming.') + ] + + def test_track(self, mocker): + """Test the track() method.""" + events_storage_mock = mocker.Mock(spec=EventStorage) + events_storage_mock.put.return_value = True + factory_mock = mocker.Mock(spec=SplitFactory) + factory_destroyed = mocker.PropertyMock() + factory_destroyed.return_value = False + type(factory_mock).destroyed = factory_destroyed + + client = Client(factory_mock) + client._events_storage = mocker.Mock(spec=EventStorage) + client._events_storage.put.return_value = True + client._logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=client._logger) + + assert client.track(None, "traffic_type", "event_type", 1) == False + assert 
client._logger.error.mock_calls == [ + mocker.call("track: you passed a null key, key must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("", "traffic_type", "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an empty key, key must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track(12345, "traffic_type", "event_type", 1) == True + assert client._logger.warning.mock_calls == [ + mocker.call("track: key 12345 is not of type string, converting.") + ] + + client._logger.reset_mock() + assert client.track(True, "traffic_type", "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid key, key must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track([], "traffic_type", "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid key, key must be a non-empty string.") + ] + + client._logger.reset_mock() + key = ''.join('a' for _ in range(0,255)) + assert client.track(key, "traffic_type", "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: key too long - must be 250 characters or less.") + ] + + client._logger.reset_mock() + assert client.track("some_key", None, "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed a null traffic_type, traffic_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "", "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an empty traffic_type, traffic_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", 12345, "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", True, "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", [], "event_type", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "TRAFFIC_type", "event_type", 1) == True + assert client._logger.warning.mock_calls == [ + mocker.call("track: TRAFFIC_type should be all lowercase - converting string to lowercase.") + ] + + assert client.track("some_key", "traffic_type", None, 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed a null event_type, event_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an empty event_type, event_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", True, 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert 
client.track("some_key", "traffic_type", [], 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", 12345, 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "@@", 1) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: you passed @@, event_type must adhere to the regular " + "expression ^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$. This means " + "an event name must be alphanumeric, cannot be more than 80 " + "characters long, and can only include a dash, underscore, " + "period, or colon as separators of alphanumeric characters.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "event_type", None) == True + assert client._logger.error.mock_calls == [] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "event_type", 1) == True + assert client._logger.error.mock_calls == [] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "event_type", 1.23) == True + assert client._logger.error.mock_calls == [] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "event_type", "test") == False + assert client._logger.error.mock_calls == [ + mocker.call("track: value must be a number.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "event_type", True) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: value must be a number.") + ] + + client._logger.reset_mock() + assert client.track("some_key", "traffic_type", "event_type", []) == False + assert client._logger.error.mock_calls == [ + mocker.call("track: value must be a number.") + ] + + def test_get_treatments(self, mocker): + """Test the get_treatments() method.""" + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + + storage_mock = mocker.Mock(spec=SplitStorage) + storage_mock.get.return_value = split_mock + + factory_mock = mocker.Mock(spec=SplitFactory) + factory_mock._get_storage.return_value = storage_mock + factory_destroyed = mocker.PropertyMock() + factory_destroyed.return_value = False + type(factory_mock).destroyed = factory_destroyed + + client = Client(factory_mock) + client._logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=client._logger) + + assert client.get_treatments(None, ['some_feature']) == {'some_feature': CONTROL} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: you passed a null key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatments("", ['some_feature']) == {'some_feature': CONTROL} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: you passed an empty key, key must be a non-empty string.') + ] + + key = ''.join('a' for _ in range(0,255)) + client._logger.reset_mock() + assert client.get_treatments(key, ['some_feature']) == {'some_feature': 
CONTROL} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: key too long - must be 250 characters or less.') + ] + + client._logger.reset_mock() + assert client.get_treatments(12345, ['some_feature']) == {'some_feature': 'default_treatment'} + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatments: key 12345 is not of type string, converting.') + ] + + client._logger.reset_mock() + assert client.get_treatments(True, ['some_feature']) == {'some_feature': CONTROL} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatments([], ['some_feature']) == {'some_feature': CONTROL} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatments('some_key', None) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments('some_key', True) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments('some_key', 'some_string') == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments('some_key', []) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments('some_key', [None, None]) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments('some_key', [True]) == {} + assert mocker.call('get_treatments: feature_names must be a non-empty array.') in client._logger.error.mock_calls + + client._logger.reset_mock() + assert client.get_treatments('some_key', ['', '']) == {} + assert mocker.call('get_treatments: feature_names must be a non-empty array.') in client._logger.error.mock_calls + + client._logger.reset_mock() + assert client.get_treatments('some_key', ['some ']) == {'some': 'default_treatment'} + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatments: feature_name \'some \' has extra whitespace, trimming.') + ] + + +class ManagerInputValidationTests(object): + """Manager input validation test cases.""" + + def test_split(self, mocker): + """Test split input validation.""" + storage_mock = mocker.Mock(spec=SplitStorage) + split_mock = mocker.Mock(spec=Split) + storage_mock.get.return_value = split_mock + factory_mock = mocker.Mock(spec=SplitFactory) + factory_mock._get_storage.return_value = storage_mock + factory_destroyed = mocker.PropertyMock() + factory_destroyed.return_value = False + type(factory_mock).destroyed = factory_destroyed + + manager = SplitManager(factory_mock) + manager._logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=manager._logger) + + assert manager.split(None) == None + assert manager._logger.error.mock_calls == [ + mocker.call("split: you passed a null feature_name, feature_name must be a non-empty string.") + ] + + 
manager._logger.reset_mock() + assert manager.split("") == None + assert manager._logger.error.mock_calls == [ + mocker.call("split: you passed an empty feature_name, feature_name must be a non-empty string.") + ] + + manager._logger.reset_mock() + assert manager.split(True) == None + assert manager._logger.error.mock_calls == [ + mocker.call("split: you passed an invalid feature_name, feature_name must be a non-empty string.") + ] + + manager._logger.reset_mock() + assert manager.split([]) == None + assert manager._logger.error.mock_calls == [ + mocker.call("split: you passed an invalid feature_name, feature_name must be a non-empty string.") + ] + + manager._logger.reset_mock() + manager.split('some_split') + assert split_mock.to_split_view.mock_calls == [mocker.call()] + assert manager._logger.error.mock_calls == [] + + + +#class TestInputSanitizationFactory(TestCase): +# +# def setUp(self): +# input_validator._LOGGER.error = mock.MagicMock() +# self.logger_error = input_validator._LOGGER.error +# +# def test_factory_with_null_apikey(self): +# self.assertEqual(None, get_factory(None)) +# self.logger_error \ +# .assert_called_once_with("factory_instantiation: you passed a null apikey, apikey" + +# " must be a non-empty string.") +# +# def test_factory_with_empty_apikey(self): +# self.assertEqual(None, get_factory('')) +# self.logger_error \ +# .assert_called_once_with("factory_instantiation: you passed an empty apikey, apikey" + +# " must be a non-empty string.") +# +# def test_factory_with_invalid_apikey(self): +# self.assertEqual(None, get_factory(True)) +# self.logger_error \ +# .assert_called_once_with("factory_instantiation: you passed an invalid apikey, apikey" + +# " must be a non-empty string.") +# +# def test_factory_with_invalid_apikey_redis(self): +# config = { +# 'redisDb': 0, +# 'redisHost': 'localhost' +# } +# self.assertNotEqual(None, get_factory(True, config=config)) +# self.logger_error.assert_not_called() +# +# def test_factory_with_invalid_config(self): +# config = { +# 'some': 0 +# } +# self.assertEqual(None, get_factory("apikey", config=config)) +# self.logger_error \ +# .assert_called_once_with('no ready parameter has been set - incorrect control ' +# + 'treatments could be logged') +# +# def test_factory_with_invalid_null_ready(self): +# config = { +# 'ready': None +# } +# self.assertEqual(None, get_factory("apikey", config=config)) +# self.logger_error \ +# .assert_called_once_with('no ready parameter has been set - incorrect control ' +# + 'treatments could be logged') +# +# def test_factory_with_invalid_ready(self): +# config = { +# 'ready': True +# } +# self.assertEqual(None, get_factory("apikey", config=config)) +# self.logger_error \ +# .assert_called_once_with('no ready parameter has been set - incorrect control ' +# + 'treatments could be logged') diff --git a/splitio/tests/murmur3-custom-uuids.csv b/tests/engine/files/murmur3-custom-uuids.csv similarity index 100% rename from splitio/tests/murmur3-custom-uuids.csv rename to tests/engine/files/murmur3-custom-uuids.csv diff --git a/splitio/tests/murmur3-sample-data-non-alpha-numeric-v2.csv b/tests/engine/files/murmur3-sample-data-non-alpha-numeric-v2.csv similarity index 100% rename from splitio/tests/murmur3-sample-data-non-alpha-numeric-v2.csv rename to tests/engine/files/murmur3-sample-data-non-alpha-numeric-v2.csv diff --git a/splitio/tests/murmur3-sample-data-v2.csv b/tests/engine/files/murmur3-sample-data-v2.csv similarity index 100% rename from splitio/tests/murmur3-sample-data-v2.csv rename to 
tests/engine/files/murmur3-sample-data-v2.csv diff --git a/splitio/tests/sample-data-non-alpha-numeric.jsonl b/tests/engine/files/sample-data-non-alpha-numeric.jsonl similarity index 100% rename from splitio/tests/sample-data-non-alpha-numeric.jsonl rename to tests/engine/files/sample-data-non-alpha-numeric.jsonl diff --git a/splitio/tests/sample-data.jsonl b/tests/engine/files/sample-data.jsonl similarity index 100% rename from splitio/tests/sample-data.jsonl rename to tests/engine/files/sample-data.jsonl diff --git a/tests/engine/test_evaluator.py b/tests/engine/test_evaluator.py new file mode 100644 index 00000000..4c5e936b --- /dev/null +++ b/tests/engine/test_evaluator.py @@ -0,0 +1,122 @@ +"""Evaluator tests module.""" +import logging + +from splitio.models.splits import Split +from splitio.models.grammar.condition import Condition, ConditionType +from splitio.models.impressions import Label +from splitio.engine import evaluator, splitters +from splitio.storage import SplitStorage, SegmentStorage + +class EvaluatorTests(object): + """Test evaluator behavior.""" + + def _build_evaluator_with_mocks(self, mocker): + """Build an evaluator with mocked dependencies.""" + split_storage_mock = mocker.Mock(spec=SplitStorage) + splitter_mock = mocker.Mock(spec=splitters.Splitter) + segment_storage_mock = mocker.Mock(spec=SegmentStorage) + logger_mock = mocker.Mock(spec=logging.Logger) + e = evaluator.Evaluator(split_storage_mock, segment_storage_mock, splitter_mock) + e._logger = logger_mock + return e + + def test_evaluate_treatment_missing_split(self, mocker): + """Test that a missing split logs and returns CONTROL.""" + e = self._build_evaluator_with_mocks(mocker) + e._split_storage.get.return_value = None + result = e.evaluate_treatment('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) + assert result['treatment'] == evaluator.CONTROL + assert result['impression']['change_number'] == -1 + assert result['impression']['label'] == Label.SPLIT_NOT_FOUND + + def test_evaluate_treatment_killed_split(self, mocker): + """Test that a killed split returns the default treatment.""" + e = self._build_evaluator_with_mocks(mocker) + mocked_split = mocker.Mock(spec=Split) + mocked_split.default_treatment = 'off' + mocked_split.killed = True + mocked_split.change_number = 123 + e._split_storage.get.return_value = mocked_split + result = e.evaluate_treatment('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) + assert result['treatment'] == 'off' + assert result['impression']['change_number'] == 123 + assert result['impression']['label'] == Label.KILLED + + def test_evaluate_treatment_ok(self, mocker): + """Test that a non-killed split returns the appropriate treatment.""" + e = self._build_evaluator_with_mocks(mocker) + e._get_treatment_for_split = mocker.Mock() + e._get_treatment_for_split.return_value = ('on', 'some_label') + mocked_split = mocker.Mock(spec=Split) + mocked_split.default_treatment = 'off' + mocked_split.killed = False + mocked_split.change_number = 123 + e._split_storage.get.return_value = mocked_split + result = e.evaluate_treatment('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) + assert result['treatment'] == 'on' + assert result['impression']['change_number'] == 123 + assert result['impression']['label'] == 'some_label' + +
def test_get_treatment_for_split_no_condition_matches(self, mocker): + """Test no condition matches.""" + e = self._build_evaluator_with_mocks(mocker) + e._splitter.get_treatment.return_value = 'on' + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + mocked_split = mocker.Mock(spec=Split) + mocked_split.killed = False + type(mocked_split).conditions = conditions_mock + treatment, label = e._get_treatment_for_split(mocked_split, 'some_key', 'some_bucketing', {'attr1': 1}) + assert treatment is None + assert label is None + + def test_get_treatment_for_split_non_rollout(self, mocker): + """Test condition matches.""" + e = self._build_evaluator_with_mocks(mocker) + e._splitter.get_treatment.return_value = 'on' + mocked_condition_1 = mocker.Mock(spec=Condition) + mocked_condition_1.condition_type = ConditionType.WHITELIST + mocked_condition_1.label = 'some_label' + mocked_condition_1.matches.return_value = True + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [mocked_condition_1] + mocked_split = mocker.Mock(spec=Split) + mocked_split.killed = False + type(mocked_split).conditions = conditions_mock + treatment, label = e._get_treatment_for_split(mocked_split, 'some_key', 'some_bucketing', {'attr1': 1}) + assert treatment == 'on' + assert label == 'some_label' + + def test_get_treatment_for_split_rollout(self, mocker): + """Test rollout condition returns default treatment.""" + e = self._build_evaluator_with_mocks(mocker) + e._splitter.get_bucket.return_value = 60 + mocked_condition_1 = mocker.Mock(spec=Condition) + mocked_condition_1.condition_type = ConditionType.ROLLOUT + mocked_condition_1.label = 'some_label' + mocked_condition_1.matches.return_value = True + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [mocked_condition_1] + mocked_split = mocker.Mock(spec=Split) + mocked_split.traffic_allocation = 50 + mocked_split.default_treatment = 'almost-on' + mocked_split.killed = False + type(mocked_split).conditions = conditions_mock + treatment, label = e._get_treatment_for_split(mocked_split, 'some_key', 'some_bucketing', {'attr1': 1}) + assert treatment == 'almost-on' + assert label == Label.NOT_IN_SPLIT diff --git a/tests/engine/test_hashfns.py b/tests/engine/test_hashfns.py new file mode 100644 index 00000000..ce8ffd86 --- /dev/null +++ b/tests/engine/test_hashfns.py @@ -0,0 +1,107 @@ +"""Hash function test module.""" +#pylint: disable=no-self-use,protected-access +import io +import json +import os +import sys + +import pytest +from splitio.engine import hashfns, splitters +from splitio.models import splits + + +class HashFunctionsTests(object): + """Hash functions test cases.""" + + def test_get_hash_function(self): + """Test that the correct hash function is returned.""" + assert hashfns.get_hash_fn(splits.HashAlgorithm.LEGACY) == hashfns.legacy.legacy_hash + assert hashfns.get_hash_fn(splits.HashAlgorithm.MURMUR) == hashfns._murmur_hash + + def test_legacy_hash_ascii_data(self): + """Test legacy hash function 
against known results.""" + splitter = splitters.Splitter() + file_name = os.path.join(os.path.dirname(__file__), 'files', 'sample-data.jsonl') + with open(file_name, 'r') as flo: + lines = flo.read().split('\n') + + for line in lines: + if line is None or line == '': + continue + seed, key, hashed, bucket = json.loads(line) + assert hashfns.legacy.legacy_hash(key, seed) == hashed + assert splitter.get_bucket(key, seed, splits.HashAlgorithm.LEGACY) == bucket + + @pytest.mark.skipif(sys.version_info > (3, 0), reason='Should skip this on python3.') + def test_legacy_hash_non_ascii_data(self): + """Test legacy hash function against known results.""" + splitter = splitters.Splitter() + file_name = os.path.join( + os.path.dirname(__file__), + 'files', + 'sample-data-non-alpha-numeric.jsonl' + ) + with open(file_name, 'r') as flo: + lines = flo.read().split('\n') + + for line in lines: + if line is None or line == '': + continue + seed, key, hashed, bucket = json.loads(line) + assert hashfns.legacy.legacy_hash(key, seed) == hashed + assert splitter.get_bucket(key, seed, splits.HashAlgorithm.LEGACY) == bucket + + def test_murmur_hash_ascii_data(self): + """Test legacy hash function against known results.""" + splitter = splitters.Splitter() + file_name = os.path.join(os.path.dirname(__file__), 'files', 'murmur3-sample-data-v2.csv') + with open(file_name, 'r') as flo: + lines = flo.read().split('\n') + + for line in lines: + if line is None or line == '': + continue + seed, key, hashed, bucket = line.split(',') + seed = int(seed) + bucket = int(bucket) + hashed = int(hashed) + assert hashfns._murmur_hash(key, seed) == hashed + assert splitter.get_bucket(key, seed, splits.HashAlgorithm.MURMUR) == bucket + + def test_murmur_more_ascii_data(self): + """Test legacy hash function against known results.""" + splitter = splitters.Splitter() + file_name = os.path.join(os.path.dirname(__file__), 'files', 'murmur3-custom-uuids.csv') + with open(file_name, 'r') as flo: + lines = flo.read().split('\n') + + for line in lines: + if line is None or line == '': + continue + seed, key, hashed, bucket = line.split(',') + seed = int(seed) + bucket = int(bucket) + hashed = int(hashed) + assert hashfns._murmur_hash(key, seed) == hashed + assert splitter.get_bucket(key, seed, splits.HashAlgorithm.MURMUR) == bucket + + def test_murmur_hash_non_ascii_data(self): + """Test legacy hash function against known results.""" + splitter = splitters.Splitter() + file_name = os.path.join( + os.path.dirname(__file__), + 'files', + 'murmur3-sample-data-non-alpha-numeric-v2.csv' + ) + with io.open(file_name, 'r', encoding='utf-8') as flo: + lines = flo.read().split('\n') + + for line in lines: + if line is None or line == '': + continue + seed, key, hashed, bucket = line.split(',') + seed = int(seed) + bucket = int(bucket) + hashed = int(hashed) + assert hashfns._murmur_hash(key, seed) == hashed + assert splitter.get_bucket(key, seed, splits.HashAlgorithm.MURMUR) == bucket diff --git a/tests/engine/test_splitter.py b/tests/engine/test_splitter.py new file mode 100644 index 00000000..d4dc2f3b --- /dev/null +++ b/tests/engine/test_splitter.py @@ -0,0 +1,51 @@ +"""Splitter test module.""" + +from splitio.models.grammar.partitions import Partition +from splitio.engine.splitters import Splitter, CONTROL + + +class SplitterTests(object): + """Tests for engine/splitter.""" + + def test_get_treatment(self, mocker): + """Test get_treatment method on all possible outputs.""" + splitter = Splitter() + + # no partitions returns control + assert 
splitter.get_treatment('key', 123, [], 1) == CONTROL + # single partition returns that treatment + assert splitter.get_treatment('key', 123, [Partition('on', 100)], 1) == 'on' + # multiple partitions call hash_functions + splitter.get_treatment_for_bucket = lambda x,y: 'on' + partitions = [Partition('on', 50), Partition('off', 50)] + assert splitter.get_treatment('key', 123, partitions, 1) == 'on' + + def test_get_bucket(self, mocker): + """Test get_bucket method.""" + get_hash_fn_mock = mocker.Mock() + hash_fn = mocker.Mock() + hash_fn.return_value = 1 + get_hash_fn_mock.side_effect = lambda x: hash_fn + mocker.patch('splitio.engine.splitters.get_hash_fn', new=get_hash_fn_mock) + splitter = Splitter() + splitter.get_bucket(1, 123, 1) + assert get_hash_fn_mock.mock_calls == [mocker.call(1)] + assert hash_fn.mock_calls == [mocker.call(1, 123)] + + def test_treatment_for_bucket(self, mocker): + """Test treatment for bucket method.""" + splitter = Splitter() + assert splitter.get_treatment_for_bucket(0, []) == CONTROL + assert splitter.get_treatment_for_bucket(-1, []) == CONTROL + assert splitter.get_treatment_for_bucket(101, [Partition('a', 100)]) == CONTROL + assert splitter.get_treatment_for_bucket(1, [Partition('a', 100)]) == 'a' + assert splitter.get_treatment_for_bucket(100, [Partition('a', 100)]) == 'a' + assert splitter.get_treatment_for_bucket(50, [Partition('a', 50), Partition('b', 50)]) == 'a' + assert splitter.get_treatment_for_bucket(51, [Partition('a', 50), Partition('b', 50)]) == 'b' + + + + + + + diff --git a/splitio/tests/regex.txt b/tests/models/grammar/files/regex.txt similarity index 100% rename from splitio/tests/regex.txt rename to tests/models/grammar/files/regex.txt diff --git a/tests/models/grammar/test_conditions.py b/tests/models/grammar/test_conditions.py new file mode 100644 index 00000000..abdcf0a1 --- /dev/null +++ b/tests/models/grammar/test_conditions.py @@ -0,0 +1,78 @@ +"""Condition model tests module.""" + +from splitio.models.grammar import condition +from splitio.models.grammar import partitions +from splitio.models.grammar import matchers + +class ConditionTests(object): + """Test the condition object model.""" + + raw = { + 'partitions': [ + {'treatment': 'on', 'size': 50}, + {'treatment': 'off', 'size': 50} + ], + 'contitionType': 'WHITELIST', + 'label': 'some_label', + 'matcherGroup': { + 'matchers': [ + { + 'matcherType': 'ALL_KEYS', + 'negate': False, + } + ], + 'combiner': 'AND' + } + } + + def test_parse(self): + """Test parsing from raw dict.""" + parsed = condition.from_raw(self.raw) + assert isinstance(parsed, condition.Condition) + assert parsed.label == 'some_label' + assert parsed.condition_type == condition.ConditionType.WHITELIST + assert isinstance(parsed.matchers[0], matchers.AllKeysMatcher) + assert isinstance(parsed.partitions[0], partitions.Partition) + assert parsed.partitions[0].treatment == 'on' + assert parsed.partitions[0].size == 50 + assert parsed.partitions[1].treatment == 'off' + assert parsed.partitions[1].size == 50 + assert parsed._combiner == condition._MATCHER_COMBINERS['AND'] + + def test_segment_names(self, mocker): + """Test fetching segment_names.""" + matcher1 = mocker.Mock(spec=matchers.UserDefinedSegmentMatcher) + matcher2 = mocker.Mock(spec=matchers.UserDefinedSegmentMatcher) + matcher1._segment_name = 'segment1' + matcher2._segment_name = 'segment2' + cond = condition.Condition([matcher1, matcher2], condition._MATCHER_COMBINERS['AND'], [], 'some_label') + assert cond.get_segment_names() == ['segment1', 
'segment2'] + + def test_to_json(self): + """Test JSON serialization of a condition.""" + as_json = condition.from_raw(self.raw).to_json() + assert as_json['partitions'] == [ + {'treatment': 'on', 'size': 50}, + {'treatment': 'off', 'size': 50} + ] + assert as_json['conditionType'] == 'WHITELIST' + assert as_json['label'] == 'some_label' + assert as_json['matcherGroup']['matchers'][0]['matcherType'] == 'ALL_KEYS' + assert as_json['matcherGroup']['matchers'][0]['negate'] == False + assert as_json['matcherGroup']['combiner'] == 'AND' + + def test_matches(self, mocker): + """Test that matches works properly.""" + matcher1_mock = mocker.Mock(spec=matchers.base.Matcher) + matcher2_mock = mocker.Mock(spec=matchers.base.Matcher) + matcher1_mock.evaluate.return_value = True + matcher2_mock.evaluate.return_value = True + cond = condition.Condition( + [matcher1_mock, matcher2_mock], + condition._MATCHER_COMBINERS['AND'], + [partitions.Partition('on', 50), partitions.Partition('off', 50)], + 'some_label' + ) + assert cond.matches('some_key', {'a': 1}, {'some_context_option': 0}) == True + assert matcher1_mock.evaluate.mock_calls == [mocker.call('some_key', {'a': 1}, {'some_context_option': 0})] + assert matcher2_mock.evaluate.mock_calls == [mocker.call('some_key', {'a': 1}, {'some_context_option': 0})] diff --git a/tests/models/grammar/test_matchers.py b/tests/models/grammar/test_matchers.py new file mode 100644 index 00000000..a0ad3197 --- /dev/null +++ b/tests/models/grammar/test_matchers.py @@ -0,0 +1,886 @@ +"""Matchers tests module.""" +#pylint: disable=protected-access,line-too-long,unsubscriptable-object + +import abc +import calendar +import json +import os.path +import re + +from datetime import datetime + +from splitio.models.grammar import matchers +from splitio.storage import SegmentStorage +from splitio.engine.evaluator import Evaluator + + +class MatcherTestsBase(object): + """Abstract class to make sure we test all relevant methods.""" + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + pass + + @abc.abstractmethod + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + pass + + @abc.abstractmethod + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + pass + + +class AllKeysMatcherTests(MatcherTestsBase): + """Test AllKeys matcher methods.""" + + raw = { + 'matcherType': "ALL_KEYS", + 'negate': False + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.AllKeysMatcher) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = matchers.AllKeysMatcher(self.raw) + assert matcher.evaluate(None) is False + assert matcher.evaluate('asd') is True + assert matcher.evaluate('asd', {'a': 1}, {}) is True + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.AllKeysMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'ALL_KEYS' + + +class BetweenMatcherTests(MatcherTestsBase): + """Test in between matcher behaviour.""" + + raw_number = { + 'matcherType': 'BETWEEN', + 'negate': False, + 'betweenMatcherData': { + 'start': 1, + 'end': 3, + 'dataType': 'NUMBER' + } + } + + raw_date = { + 'matcherType': 'BETWEEN', + 'negate': False, + 'betweenMatcherData': { + 'start': int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000, + 
'end': int(calendar.timegm((datetime(2019, 12, 23, 9, 30, 45)).timetuple())) * 1000, + 'dataType': 'DATETIME' + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed_number = matchers.from_raw(self.raw_number) + assert isinstance(parsed_number, matchers.BetweenMatcher) + assert parsed_number._data_type == 'NUMBER' + assert parsed_number._negate is False + assert parsed_number._original_lower == 1 + assert parsed_number._original_upper == 3 + assert parsed_number._lower == 1 + assert parsed_number._upper == 3 + + parsed_date = matchers.from_raw(self.raw_date) + assert isinstance(parsed_number, matchers.BetweenMatcher) + assert parsed_date._data_type == 'DATETIME' + assert parsed_date._original_lower == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000 + assert parsed_date._original_upper == int(calendar.timegm((datetime(2019, 12, 23, 9, 30, 45)).timetuple())) * 1000 + assert parsed_date._lower == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 0)).timetuple())) + assert parsed_date._upper == int(calendar.timegm((datetime(2019, 12, 23, 9, 30, 0)).timetuple())) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed_number = matchers.BetweenMatcher(self.raw_number) + assert parsed_number.evaluate(0) is False + assert parsed_number.evaluate(1) is True + assert parsed_number.evaluate(2) is True + assert parsed_number.evaluate(3) is True + assert parsed_number.evaluate(4) is False + assert parsed_number.evaluate('a') is False + assert parsed_number.evaluate([]) is False + assert parsed_number.evaluate({}) is False + assert parsed_number.evaluate(True) is False + assert parsed_number.evaluate(object()) is False + + + parsed_date = matchers.BetweenMatcher(self.raw_date) + assert parsed_date.evaluate(int(calendar.timegm((datetime(2019, 12, 20, 9, 30)).timetuple()))) is False + assert parsed_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple()))) is True + assert parsed_date.evaluate(int(calendar.timegm((datetime(2019, 12, 22, 9, 30)).timetuple()))) is True + assert parsed_date.evaluate(int(calendar.timegm((datetime(2019, 12, 23, 9, 30, 45)).timetuple()))) is True + assert parsed_date.evaluate(int(calendar.timegm((datetime(2019, 12, 24, 9, 30)).timetuple()))) is False + assert parsed_date.evaluate('a') is False + assert parsed_date.evaluate([]) is False + assert parsed_date.evaluate({}) is False + assert parsed_date.evaluate(True) is False + assert parsed_date.evaluate(object()) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json_number = matchers.BetweenMatcher(self.raw_number).to_json() + assert as_json_number['betweenMatcherData']['start'] == 1 + assert as_json_number['betweenMatcherData']['end'] == 3 + assert as_json_number['betweenMatcherData']['dataType'] == 'NUMBER' + + as_json_date = matchers.BetweenMatcher(self.raw_date).to_json() + assert as_json_date['betweenMatcherData']['start'] == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple()) * 1000) + assert as_json_date['betweenMatcherData']['end'] == int(calendar.timegm((datetime(2019, 12, 23, 9, 30, 45)).timetuple()) * 1000) + assert as_json_date['betweenMatcherData']['dataType'] == 'DATETIME' + + +class EqualToMatcherTests(MatcherTestsBase): + """Test equal to matcher.""" + + raw_number = { + 'matcherType': 'EQUAL_TO', + 'negate': False, + 'unaryNumericMatcherData': { + 'value': 5, + 'dataType': 'NUMBER' + } + } + + raw_date = { + 
'matcherType': 'EQUAL_TO', + 'negate': False, + 'unaryNumericMatcherData': { + 'value': int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000, + 'dataType': 'DATETIME' + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed_number = matchers.from_raw(self.raw_number) + assert isinstance(parsed_number, matchers.EqualToMatcher) + assert parsed_number._data_type == 'NUMBER' + assert parsed_number._original_value == 5 + assert parsed_number._value == 5 + + parsed_date = matchers.from_raw(self.raw_date) + assert isinstance(parsed_date, matchers.EqualToMatcher) + assert parsed_date._data_type == 'DATETIME' + assert parsed_date._original_value == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45, 0)).timetuple())) * 1000 + assert parsed_date._value == int(calendar.timegm((datetime(2019, 12, 21, 0, 0, 0)).timetuple())) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher_number = matchers.EqualToMatcher(self.raw_number) + assert matcher_number.evaluate(4) is False + assert matcher_number.evaluate(5) is True + assert matcher_number.evaluate(6) is False + assert matcher_number.evaluate('a') is False + assert matcher_number.evaluate([]) is False + assert matcher_number.evaluate({}) is False + assert matcher_number.evaluate(True) is False + assert matcher_number.evaluate(object()) is False + + matcher_date = matchers.EqualToMatcher(self.raw_date) + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple()))) is True + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 0, 0, 0)).timetuple()))) is True + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 20, 0, 0, 0)).timetuple()))) is False + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 22, 0, 0, 0)).timetuple()))) is False + assert matcher_date.evaluate('a') is False + assert matcher_date.evaluate([]) is False + assert matcher_date.evaluate({}) is False + assert matcher_date.evaluate(True) is False + assert matcher_date.evaluate(object()) is False + + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json_number = matchers.from_raw(self.raw_number).to_json() + assert as_json_number['unaryNumericMatcherData']['dataType'] == 'NUMBER' + assert as_json_number['unaryNumericMatcherData']['value'] == 5 + assert as_json_number['matcherType'] == 'EQUAL_TO' + assert as_json_number['negate'] is False + + as_json_number = matchers.from_raw(self.raw_date).to_json() + assert as_json_number['unaryNumericMatcherData']['dataType'] == 'DATETIME' + assert as_json_number['unaryNumericMatcherData']['value'] == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000 + assert as_json_number['matcherType'] == 'EQUAL_TO' + assert as_json_number['negate'] is False + + +class GreaterOrEqualMatcherTests(MatcherTestsBase): + """Test greater or equal matcher.""" + + raw_number = { + 'matcherType': 'GREATER_THAN_OR_EQUAL_TO', + 'negate': False, + 'unaryNumericMatcherData': { + 'value': 5, + 'dataType': 'NUMBER' + } + } + + raw_date = { + 'matcherType': 'GREATER_THAN_OR_EQUAL_TO', + 'negate': False, + 'unaryNumericMatcherData': { + 'value': int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000, + 'dataType': 'DATETIME' + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed_number = matchers.from_raw(self.raw_number) + assert isinstance(parsed_number, 
matchers.GreaterThanOrEqualMatcher) + assert parsed_number._data_type == 'NUMBER' + assert parsed_number._original_value == 5 + assert parsed_number._value == 5 + assert parsed_number.evaluate('a') is False + assert parsed_number.evaluate([]) is False + assert parsed_number.evaluate({}) is False + assert parsed_number.evaluate(True) is False + assert parsed_number.evaluate(object()) is False + + parsed_date = matchers.from_raw(self.raw_date) + assert isinstance(parsed_date, matchers.GreaterThanOrEqualMatcher) + assert parsed_date._data_type == 'DATETIME' + assert parsed_date._original_value == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45, 0)).timetuple())) * 1000 + assert parsed_date._value == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 0)).timetuple())) + assert parsed_date.evaluate('a') is False + assert parsed_date.evaluate([]) is False + assert parsed_date.evaluate({}) is False + assert parsed_date.evaluate(True) is False + assert parsed_date.evaluate(object()) is False + + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher_number = matchers.GreaterThanOrEqualMatcher(self.raw_number) + assert matcher_number.evaluate(4) is False + assert matcher_number.evaluate(5) is True + assert matcher_number.evaluate(6) is True + assert matcher_number.evaluate('a') is False + assert matcher_number.evaluate([]) is False + assert matcher_number.evaluate({}) is False + assert matcher_number.evaluate(True) is False + assert matcher_number.evaluate(object()) is False + + matcher_date = matchers.GreaterThanOrEqualMatcher(self.raw_date) + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 20, 0, 0, 0)).timetuple()))) is False + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 0, 0, 0)).timetuple()))) is False + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 0)).timetuple()))) is True + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple()))) is True + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 22, 0, 0, 0)).timetuple()))) is True + assert matcher_date.evaluate('a') is False + assert matcher_date.evaluate([]) is False + assert matcher_date.evaluate({}) is False + assert matcher_date.evaluate(True) is False + assert matcher_date.evaluate(object()) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json_number = matchers.from_raw(self.raw_number).to_json() + assert as_json_number['unaryNumericMatcherData']['dataType'] == 'NUMBER' + assert as_json_number['unaryNumericMatcherData']['value'] == 5 + assert as_json_number['matcherType'] == 'GREATER_THAN_OR_EQUAL_TO' + assert as_json_number['negate'] is False + + as_json_number = matchers.from_raw(self.raw_date).to_json() + assert as_json_number['unaryNumericMatcherData']['dataType'] == 'DATETIME' + assert as_json_number['unaryNumericMatcherData']['value'] == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000 + assert as_json_number['matcherType'] == 'GREATER_THAN_OR_EQUAL_TO' + assert as_json_number['negate'] is False + + +class LessOrEqualMatcherTests(MatcherTestsBase): + """Test less than or equal matcher.""" + + raw_number = { + 'matcherType': 'LESS_THAN_OR_EQUAL_TO', + 'negate': False, + 'unaryNumericMatcherData': { + 'value': 5, + 'dataType': 'NUMBER' + } + } + + raw_date = { + 'matcherType': 'LESS_THAN_OR_EQUAL_TO', + 'negate': False, + 'unaryNumericMatcherData': { + 'value': 
int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000, + 'dataType': 'DATETIME' + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed_number = matchers.from_raw(self.raw_number) + assert isinstance(parsed_number, matchers.LessThanOrEqualMatcher) + assert parsed_number._data_type == 'NUMBER' + assert parsed_number._original_value == 5 + assert parsed_number._value == 5 + + parsed_date = matchers.from_raw(self.raw_date) + assert isinstance(parsed_date, matchers.LessThanOrEqualMatcher) + assert parsed_date._data_type == 'DATETIME' + assert parsed_date._original_value == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45, 0)).timetuple())) * 1000 + assert parsed_date._value == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 0)).timetuple())) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher_number = matchers.LessThanOrEqualMatcher(self.raw_number) + assert matcher_number.evaluate(4) is True + assert matcher_number.evaluate(5) is True + assert matcher_number.evaluate(6) is False + assert matcher_number.evaluate('a') is False + assert matcher_number.evaluate([]) is False + assert matcher_number.evaluate({}) is False + assert matcher_number.evaluate(True) is False + assert matcher_number.evaluate(object()) is False + + + matcher_date = matchers.LessThanOrEqualMatcher(self.raw_date) + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 20, 0, 0, 0)).timetuple()))) is True + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 0, 0, 0)).timetuple()))) is True + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 0)).timetuple()))) is True + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 21, 9, 31, 45)).timetuple()))) is False + assert matcher_date.evaluate(int(calendar.timegm((datetime(2019, 12, 22, 0, 0, 0)).timetuple()))) is False + assert matcher_date.evaluate('a') is False + assert matcher_date.evaluate([]) is False + assert matcher_date.evaluate({}) is False + assert matcher_date.evaluate(True) is False + assert matcher_date.evaluate(object()) is False + + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json_number = matchers.from_raw(self.raw_number).to_json() + assert as_json_number['unaryNumericMatcherData']['dataType'] == 'NUMBER' + assert as_json_number['unaryNumericMatcherData']['value'] == 5 + assert as_json_number['matcherType'] == 'LESS_THAN_OR_EQUAL_TO' + assert as_json_number['negate'] is False + + as_json_number = matchers.from_raw(self.raw_date).to_json() + assert as_json_number['unaryNumericMatcherData']['dataType'] == 'DATETIME' + assert as_json_number['unaryNumericMatcherData']['value'] == int(calendar.timegm((datetime(2019, 12, 21, 9, 30, 45)).timetuple())) * 1000 + assert as_json_number['matcherType'] == 'LESS_THAN_OR_EQUAL_TO' + assert as_json_number['negate'] is False + + +class UserDefinedSegmentMatcherTests(MatcherTestsBase): + """Test user defined segment matcher.""" + + raw = { + 'matcherType': 'IN_SEGMENT', + 'negate': False, + 'userDefinedSegmentMatcherData': { + 'segmentName': 'some_segment' + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.UserDefinedSegmentMatcher) + assert parsed._segment_name == 'some_segment' + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = 
matchers.UserDefinedSegmentMatcher(self.raw)
+        segment_storage = mocker.Mock(spec=SegmentStorage)
+
+        # Test that if the storage wrapper finds the key in the segment, it matches.
+        segment_storage.segment_contains.return_value = True
+        assert matcher.evaluate('some_key', {}, {'segment_storage': segment_storage}) is True
+
+        # Test that if the storage wrapper doesn't find the key in the segment, it fails.
+        segment_storage.segment_contains.return_value = False
+        assert matcher.evaluate('some_key', {}, {'segment_storage': segment_storage}) is False
+
+        assert segment_storage.segment_contains.mock_calls == [
+            mocker.call('some_segment', 'some_key'),
+            mocker.call('some_segment', 'some_key')
+        ]
+
+        assert matcher.evaluate([], {}, {'segment_storage': segment_storage}) is False
+        assert matcher.evaluate({}, {}, {'segment_storage': segment_storage}) is False
+        assert matcher.evaluate(123, {}, {'segment_storage': segment_storage}) is False
+        assert matcher.evaluate(True, {}, {'segment_storage': segment_storage}) is False
+        assert matcher.evaluate(False, {}, {'segment_storage': segment_storage}) is False
+
+    def test_to_json(self):
+        """Test that the object serializes to JSON properly."""
+        as_json = matchers.UserDefinedSegmentMatcher(self.raw).to_json()
+        assert as_json['userDefinedSegmentMatcherData']['segmentName'] == 'some_segment'
+        assert as_json['matcherType'] == 'IN_SEGMENT'
+        assert as_json['negate'] is False
+
+
+class WhitelistMatcherTests(MatcherTestsBase):
+    """Test whitelist matcher."""
+
+    raw = {
+        'matcherType': 'WHITELIST',
+        'negate': False,
+        'whitelistMatcherData': {
+            'whitelist': ['key1', 'key2', 'key3'],
+        }
+    }
+
+    def test_from_raw(self, mocker):
+        """Test parsing from raw json/dict."""
+        parsed = matchers.from_raw(self.raw)
+        assert isinstance(parsed, matchers.WhitelistMatcher)
+        assert parsed._whitelist == frozenset(['key1', 'key2', 'key3'])
+
+    def test_matcher_behaviour(self, mocker):
+        """Test if the matcher works properly."""
+        matcher = matchers.WhitelistMatcher(self.raw)
+        assert matcher.evaluate('key1') is True
+        assert matcher.evaluate('key2') is True
+        assert matcher.evaluate('key3') is True
+        assert matcher.evaluate('key4') is False
+        assert matcher.evaluate(None) is False
+
+        assert matcher.evaluate([]) is False
+        assert matcher.evaluate({}) is False
+        assert matcher.evaluate(123) is False
+        assert matcher.evaluate(True) is False
+        assert matcher.evaluate(False) is False
+
+    def test_to_json(self):
+        """Test that the object serializes to JSON properly."""
+        as_json = matchers.WhitelistMatcher(self.raw).to_json()
+        assert 'key1' in as_json['whitelistMatcherData']['whitelist']
+        assert 'key2' in as_json['whitelistMatcherData']['whitelist']
+        assert 'key3' in as_json['whitelistMatcherData']['whitelist']
+        assert as_json['matcherType'] == 'WHITELIST'
+        assert as_json['negate'] is False
+
+
+class StartsWithMatcherTests(MatcherTestsBase):
+    """Test StartsWith matcher."""
+
+    raw = {
+        'matcherType': 'STARTS_WITH',
+        'negate': False,
+        'whitelistMatcherData': {
+            'whitelist': ['key1', 'key2', 'key3'],
+        }
+    }
+
+    def test_from_raw(self, mocker):
+        """Test parsing from raw json/dict."""
+        parsed = matchers.from_raw(self.raw)
+        assert isinstance(parsed, matchers.StartsWithMatcher)
+        assert parsed._whitelist == frozenset(['key1', 'key2', 'key3'])
+
+    def test_matcher_behaviour(self, mocker):
+        """Test if the matcher works properly."""
+        matcher = matchers.StartsWithMatcher(self.raw)
+        assert matcher.evaluate('key1AA') is True
+        assert
matcher.evaluate('key2BB') is True + assert matcher.evaluate('key3CC') is True + assert matcher.evaluate('key4DD') is False + assert matcher.evaluate('Akey1A') is False + assert matcher.evaluate(None) is False + assert matcher.evaluate([]) is False + assert matcher.evaluate({}) is False + assert matcher.evaluate(123) is False + assert matcher.evaluate(True) is False + assert matcher.evaluate(False) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.StartsWithMatcher(self.raw).to_json() + assert 'key1' in as_json['whitelistMatcherData']['whitelist'] + assert 'key2' in as_json['whitelistMatcherData']['whitelist'] + assert 'key3' in as_json['whitelistMatcherData']['whitelist'] + assert as_json['matcherType'] == 'STARTS_WITH' + assert as_json['negate'] is False + + +class EndsWithMatcherTests(MatcherTestsBase): + """Test EndsWith matcher.""" + + raw = { + 'matcherType': 'ENDS_WITH', + 'negate': False, + 'whitelistMatcherData': { + 'whitelist': ['key1', 'key2', 'key3'], + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.EndsWithMatcher) + assert parsed._whitelist == frozenset(['key1', 'key2', 'key3']) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = matchers.EndsWithMatcher(self.raw) + assert matcher.evaluate('AAkey1') is True + assert matcher.evaluate('BBkey2') is True + assert matcher.evaluate('CCkey3') is True + assert matcher.evaluate('DDkey4') is False + assert matcher.evaluate('Akey1A') is False + assert matcher.evaluate(None) is False + assert matcher.evaluate([]) is False + assert matcher.evaluate({}) is False + assert matcher.evaluate(123) is False + assert matcher.evaluate(True) is False + assert matcher.evaluate(False) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.EndsWithMatcher(self.raw).to_json() + assert 'key1' in as_json['whitelistMatcherData']['whitelist'] + assert 'key2' in as_json['whitelistMatcherData']['whitelist'] + assert 'key3' in as_json['whitelistMatcherData']['whitelist'] + assert as_json['matcherType'] == 'ENDS_WITH' + assert as_json['negate'] is False + + +class ContainsStringMatcherTests(MatcherTestsBase): + """Test string matcher.""" + + raw = { + 'matcherType': 'CONTAINS_STRING', + 'negate': False, + 'whitelistMatcherData': { + 'whitelist': ['key1', 'key2', 'key3'], + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.ContainsStringMatcher) + assert parsed._whitelist == frozenset(['key1', 'key2', 'key3']) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = matchers.ContainsStringMatcher(self.raw) + assert matcher.evaluate('AAkey1') is True + assert matcher.evaluate('BBkey2') is True + assert matcher.evaluate('CCkey3') is True + assert matcher.evaluate('Akey1A') is True + assert matcher.evaluate('DDkey4') is False + assert matcher.evaluate('asdsad') is False + assert matcher.evaluate(None) is False + assert matcher.evaluate([]) is False + assert matcher.evaluate({}) is False + assert matcher.evaluate(123) is False + assert matcher.evaluate(True) is False + assert matcher.evaluate(False) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = 
matchers.ContainsStringMatcher(self.raw).to_json() + assert 'key1' in as_json['whitelistMatcherData']['whitelist'] + assert 'key2' in as_json['whitelistMatcherData']['whitelist'] + assert 'key3' in as_json['whitelistMatcherData']['whitelist'] + assert as_json['matcherType'] == 'CONTAINS_STRING' + assert as_json['negate'] is False + + +class AllOfSetMatcherTests(MatcherTestsBase): + """Test all of set matcher.""" + + raw = { + 'matcherType': 'CONTAINS_ALL_OF_SET', + 'negate': False, + 'whitelistMatcherData': { + 'whitelist': ['key1', 'key2', 'key3'], + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.ContainsAllOfSetMatcher) + assert parsed._whitelist == frozenset(['key1', 'key2', 'key3']) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = matchers.ContainsAllOfSetMatcher(self.raw) + assert matcher.evaluate(['key1', 'key2', 'key3']) is True + assert matcher.evaluate(['key1', 'key2', 'key3', 'key4']) is True + assert matcher.evaluate(['key4', 'key3', 'key1', 'key5', 'key2']) is True + assert matcher.evaluate(['key1', 'key2']) is False + assert matcher.evaluate([]) is False + assert matcher.evaluate('asdsad') is False + assert matcher.evaluate(3) is False + assert matcher.evaluate(None) is False + assert matcher.evaluate({}) is False + assert matcher.evaluate(object()) is False + assert matcher.evaluate(True) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.ContainsAllOfSetMatcher(self.raw).to_json() + assert 'key1' in as_json['whitelistMatcherData']['whitelist'] + assert 'key2' in as_json['whitelistMatcherData']['whitelist'] + assert 'key3' in as_json['whitelistMatcherData']['whitelist'] + assert as_json['matcherType'] == 'CONTAINS_ALL_OF_SET' + assert as_json['negate'] is False + + +class AnyOfSetMatcherTests(MatcherTestsBase): + """Test any of set matcher.""" + + raw = { + 'matcherType': 'CONTAINS_ANY_OF_SET', + 'negate': False, + 'whitelistMatcherData': { + 'whitelist': ['key1', 'key2', 'key3'], + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.ContainsAnyOfSetMatcher) + assert parsed._whitelist == frozenset(['key1', 'key2', 'key3']) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = matchers.ContainsAnyOfSetMatcher(self.raw) + assert matcher.evaluate(['key1', 'key2', 'key3']) is True + assert matcher.evaluate(['key1', 'key2', 'key3', 'key4']) is True + assert matcher.evaluate(['key4', 'key3', 'key1', 'key5', 'key2']) is True + assert matcher.evaluate(['key1', 'key2']) is True + assert matcher.evaluate([]) is False + assert matcher.evaluate('asdsad') is False + assert matcher.evaluate(3) is False + assert matcher.evaluate(None) is False + assert matcher.evaluate({}) is False + assert matcher.evaluate(object()) is False + assert matcher.evaluate(True) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.ContainsAnyOfSetMatcher(self.raw).to_json() + assert 'key1' in as_json['whitelistMatcherData']['whitelist'] + assert 'key2' in as_json['whitelistMatcherData']['whitelist'] + assert 'key3' in as_json['whitelistMatcherData']['whitelist'] + assert as_json['matcherType'] == 'CONTAINS_ANY_OF_SET' + assert as_json['negate'] is False + + +class 
EqualToSetMatcherTests(MatcherTestsBase): + """Test equal to set matcher.""" + + raw = { + 'matcherType': 'EQUAL_TO_SET', + 'negate': False, + 'whitelistMatcherData': { + 'whitelist': ['key1', 'key2', 'key3'], + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.EqualToSetMatcher) + assert parsed._whitelist == frozenset(['key1', 'key2', 'key3']) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = matchers.EqualToSetMatcher(self.raw) + assert matcher.evaluate(['key1', 'key2', 'key3']) is True + assert matcher.evaluate(['key3', 'key2', 'key1']) is True + assert matcher.evaluate(['key1', 'key2', 'key3', 'key4']) is False + assert matcher.evaluate(['key4', 'key3', 'key1', 'key5', 'key2']) is False + assert matcher.evaluate(['key1', 'key2']) is False + assert matcher.evaluate([]) is False + assert matcher.evaluate('asdsad') is False + assert matcher.evaluate(3) is False + assert matcher.evaluate(None) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.EqualToSetMatcher(self.raw).to_json() + assert 'key1' in as_json['whitelistMatcherData']['whitelist'] + assert 'key2' in as_json['whitelistMatcherData']['whitelist'] + assert 'key3' in as_json['whitelistMatcherData']['whitelist'] + assert as_json['matcherType'] == 'EQUAL_TO_SET' + assert as_json['negate'] is False + + +class PartOfSetMatcherTests(MatcherTestsBase): + """Test part of set matcher.""" + + raw = { + 'matcherType': 'PART_OF_SET', + 'negate': False, + 'whitelistMatcherData': { + 'whitelist': ['key1', 'key2', 'key3'], + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.PartOfSetMatcher) + assert parsed._whitelist == frozenset(['key1', 'key2', 'key3']) + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + matcher = matchers.PartOfSetMatcher(self.raw) + assert matcher.evaluate(['key1', 'key2', 'key3']) is True + assert matcher.evaluate(['key3', 'key2', 'key1']) is True + assert matcher.evaluate(['key1']) is True + assert matcher.evaluate(['key1', 'key2']) is True + assert matcher.evaluate(['key4', 'key3', 'key1', 'key5', 'key2']) is False + assert matcher.evaluate([]) is False + assert matcher.evaluate('asdsad') is False + assert matcher.evaluate(3) is False + assert matcher.evaluate(None) is False + assert matcher.evaluate({}) is False + assert matcher.evaluate(object()) is False + assert matcher.evaluate(True) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.PartOfSetMatcher(self.raw).to_json() + assert 'key1' in as_json['whitelistMatcherData']['whitelist'] + assert 'key2' in as_json['whitelistMatcherData']['whitelist'] + assert 'key3' in as_json['whitelistMatcherData']['whitelist'] + assert as_json['matcherType'] == 'PART_OF_SET' + assert as_json['negate'] is False + + +class DependencyMatcherTests(MatcherTestsBase): + """tests for dependency matcher.""" + + raw = { + 'matcherType': 'IN_SPLIT_TREATMENT', + 'negate': False, + 'dependencyMatcherData': { + 'split': 'some_split', + 'treatments': ['on', 'almost_on'] + } + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.DependencyMatcher) + assert 
parsed._split_name == 'some_split' + assert parsed._treatments == ['on', 'almost_on'] + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed = matchers.DependencyMatcher(self.raw) + evaluator = mocker.Mock(spec=Evaluator) + + evaluator.evaluate_treatment.return_value = {'treatment': 'on'} + assert parsed.evaluate('test1', {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is True + + evaluator.evaluate_treatment.return_value = {'treatment': 'off'} + assert parsed.evaluate('test1', {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False + + assert evaluator.evaluate_treatment.mock_calls == [ + mocker.call('some_split', 'test1', 'buck', {}), + mocker.call('some_split', 'test1', 'buck', {}) + ] + + assert parsed.evaluate([], {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False + assert parsed.evaluate({}, {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False + assert parsed.evaluate(123, {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False + assert parsed.evaluate(object(), {}, {'bucketing_key': 'buck', 'evaluator': evaluator}) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.DependencyMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'IN_SPLIT_TREATMENT' + assert as_json['dependencyMatcherData']['split'] == 'some_split' + assert as_json['dependencyMatcherData']['treatments'] == ['on', 'almost_on'] + + +class BooleanMatcherTests(MatcherTestsBase): + """Boolean matcher test cases.""" + + raw = { + 'negate': False, + 'matcherType': 'EQUAL_TO_BOOLEAN', + 'booleanMatcherData': True + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.BooleanMatcher) + assert parsed._data + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + parsed = matchers.BooleanMatcher(self.raw) + assert parsed.evaluate(True) is True + assert parsed.evaluate('true') is True + assert parsed.evaluate('True') is True + assert parsed.evaluate('tRUe') is True + assert parsed.evaluate('dasd') is False + assert parsed.evaluate(123) is False + assert parsed.evaluate(None) is False + + def test_to_json(self): + """Test that the object serializes to JSON properly.""" + as_json = matchers.BooleanMatcher(self.raw).to_json() + assert as_json['matcherType'] == 'EQUAL_TO_BOOLEAN' + assert as_json['booleanMatcherData'] + + +class RegexMatcherTests(MatcherTestsBase): + """Regex matcher test cases.""" + + raw = { + 'negate': False, + 'matcherType': 'MATCHES_STRING', + 'stringMatcherData': "^[a-z][A-Z][0-9]$" + } + + def test_from_raw(self, mocker): + """Test parsing from raw json/dict.""" + parsed = matchers.from_raw(self.raw) + assert isinstance(parsed, matchers.RegexMatcher) + assert parsed._data == "^[a-z][A-Z][0-9]$" + assert parsed._regex == re.compile("^[a-z][A-Z][0-9]$") + + def test_matcher_behaviour(self, mocker): + """Test if the matcher works properly.""" + filename = os.path.join(os.path.dirname(__file__), 'files', 'regex.txt') + with open(filename, 'r') as flo: + test_cases = flo.read().split('\n') + for test_case in test_cases: + if not test_case: + continue + + regex, string, should_match = test_case.split('#') + raw = { + 'negate': False, + 'matcherType': 'MATCHES_STRING', + 'stringMatcherData': regex + } + + parsed = matchers.RegexMatcher(raw) + assert parsed.evaluate(string) == json.loads(should_match) + + def 
test_to_json(self):
+        """Test that the object serializes to JSON properly."""
+        as_json = matchers.RegexMatcher(self.raw).to_json()
+        assert as_json['matcherType'] == 'MATCHES_STRING'
+        assert as_json['stringMatcherData'] == "^[a-z][A-Z][0-9]$"
diff --git a/tests/models/grammar/test_partitions.py b/tests/models/grammar/test_partitions.py
new file mode 100644
index 00000000..c3c49d12
--- /dev/null
+++ b/tests/models/grammar/test_partitions.py
@@ -0,0 +1,24 @@
+"""Partitions test module."""
+
+from splitio.models.grammar import partitions
+
+class PartitionTests(object):
+    """Partition model tests."""
+
+    raw = {
+        'treatment': 'on',
+        'size': 50
+    }
+
+    def test_parse(self):
+        """Test that the partition is parsed correctly."""
+        p = partitions.from_raw(self.raw)
+        assert isinstance(p, partitions.Partition)
+        assert p.treatment == 'on'
+        assert p.size == 50
+
+    def test_to_json(self):
+        """Test the JSON representation."""
+        as_json = partitions.from_raw(self.raw).to_json()
+        assert as_json['treatment'] == 'on'
+        assert as_json['size'] == 50
diff --git a/tests/models/test_splits.py b/tests/models/test_splits.py
new file mode 100644
index 00000000..1f08bf0b
--- /dev/null
+++ b/tests/models/test_splits.py
@@ -0,0 +1,112 @@
+"""Split model tests module."""
+
+from splitio.models import splits
+from splitio.models.grammar.condition import Condition
+
+
+class SplitTests(object):
+    """Split model tests."""
+
+    raw = {
+        'changeNumber': 123,
+        'trafficTypeName': 'user',
+        'name': 'some_name',
+        'trafficAllocation': 100,
+        'trafficAllocationSeed': 123456,
+        'seed': 321654,
+        'status': 'ACTIVE',
+        'killed': False,
+        'defaultTreatment': 'off',
+        'algo': 2,
+        'conditions': [
+            {
+                'partitions': [
+                    {'treatment': 'on', 'size': 50},
+                    {'treatment': 'off', 'size': 50}
+                ],
+                'conditionType': 'WHITELIST',
+                'label': 'some_label',
+                'matcherGroup': {
+                    'matchers': [
+                        {
+                            'matcherType': 'WHITELIST',
+                            'whitelistMatcherData': {
+                                'whitelist': ['k1', 'k2', 'k3']
+                            },
+                            'negate': False,
+                        }
+                    ],
+                    'combiner': 'AND'
+                }
+            },
+            {
+                'partitions': [
+                    {'treatment': 'on', 'size': 25},
+                    {'treatment': 'off', 'size': 75}
+                ],
+                'conditionType': 'WHITELIST',
+                'label': 'some_other_label',
+                'matcherGroup': {
+                    'matchers': [
+                        {
+                            'matcherType': 'ALL_KEYS',
+                            'negate': False,
+                        }
+                    ],
+                    'combiner': 'AND'
+                }
+            }
+        ]
+    }
+
+    def test_from_raw(self):
+        """Test split model parsing."""
+        parsed = splits.from_raw(self.raw)
+        assert isinstance(parsed, splits.Split)
+        assert parsed.change_number == 123
+        assert parsed.traffic_type_name == 'user'
+        assert parsed.name == 'some_name'
+        assert parsed.traffic_allocation == 100
+        assert parsed.traffic_allocation_seed == 123456
+        assert parsed.seed == 321654
+        assert parsed.status == splits.Status.ACTIVE
+        assert parsed.killed is False
+        assert parsed.default_treatment == 'off'
+        assert parsed.algo == splits.HashAlgorithm.MURMUR
+        assert len(parsed.conditions) == 2
+
+    def test_get_segment_names(self, mocker):
+        """Test fetching segment names."""
+        cond1 = mocker.Mock(spec=Condition)
+        cond2 = mocker.Mock(spec=Condition)
+        cond1.get_segment_names.return_value = ['segment1', 'segment2']
+        cond2.get_segment_names.return_value = ['segment3', 'segment4']
+        split1 = splits.Split('some_split', 123, False, 'off', 'user', 'ACTIVE', 123, [cond1, cond2])
+        assert split1.get_segment_names() == ['segment%d' % i for i in range(1, 5)]
+
+
+    def test_to_json(self):
+        """Test json serialization."""
+        as_json = splits.from_raw(self.raw).to_json()
+        assert isinstance(as_json, dict)
+        assert
as_json['changeNumber'] == 123 + assert as_json['trafficTypeName'] == 'user' + assert as_json['name'] == 'some_name' + assert as_json['trafficAllocation'] == 100 + assert as_json['trafficAllocationSeed'] == 123456 + assert as_json['seed'] == 321654 + assert as_json['status'] == 'ACTIVE' + assert as_json['killed'] is False + assert as_json['defaultTreatment'] == 'off' + assert as_json['algo'] == 2 + assert len(as_json['conditions']) == 2 + + def test_to_split_view(self): + """Test SplitView creation.""" + as_split_view = splits.from_raw(self.raw).to_split_view() + assert isinstance(as_split_view, splits.SplitView) + assert as_split_view.name == self.raw['name'] + assert as_split_view.change_number == self.raw['changeNumber'] + assert as_split_view.killed == self.raw['killed'] + assert as_split_view.traffic_type == self.raw['trafficTypeName'] + assert set(as_split_view.treatments) == set(['on', 'off']) diff --git a/tests/storage/adapters/test_redis_adapter.py b/tests/storage/adapters/test_redis_adapter.py new file mode 100644 index 00000000..9d5253f0 --- /dev/null +++ b/tests/storage/adapters/test_redis_adapter.py @@ -0,0 +1,171 @@ +"""Redis storage adapter test module.""" + + +from splitio.storage.adapters import redis +from redis import StrictRedis +from redis.sentinel import Sentinel + +class RedisStorageAdapterTests(object): + """Redis storage adapter test cases.""" + + def test_forwarding(self, mocker): + """Test that all redis functions forward prefix appropriately.""" + redis_mock = mocker.Mock(StrictRedis) + adapter = redis.RedisAdapter(redis_mock, 'some_prefix') + + redis_mock.keys.return_value = ['some_prefix.key1', 'some_prefix.key2'] + adapter.keys('*') + assert redis_mock.keys.mock_calls[0] == mocker.call('some_prefix.*') + + adapter.set('key1', 'value1') + assert redis_mock.set.mock_calls[0] == mocker.call('some_prefix.key1', 'value1') + + adapter.get('some_key') + assert redis_mock.get.mock_calls[0] == mocker.call('some_prefix.some_key') + + adapter.setex('some_key', 123, 'some_value') + assert redis_mock.setex.mock_calls[0] == mocker.call('some_prefix.some_key', 123, 'some_value') + + adapter.delete('some_key') + assert redis_mock.delete.mock_calls[0] == mocker.call('some_prefix.some_key') + + adapter.mget(['key1', 'key2', 'key3']) + assert redis_mock.mget.mock_calls[0] == mocker.call(['some_prefix.key1', 'some_prefix.key2', 'some_prefix.key3']) + + adapter.sadd('s1', 'value1', 'value2') + assert redis_mock.sadd.mock_calls[0] == mocker.call('some_prefix.s1', 'value1', 'value2') + + adapter.srem('s1', 'value1', 'value2') + assert redis_mock.srem.mock_calls[0] == mocker.call('some_prefix.s1', 'value1', 'value2') + + adapter.sismember('s1', 'value1') + assert redis_mock.sismember.mock_calls[0] == mocker.call('some_prefix.s1', 'value1') + + adapter.eval('script', 3, 'key1', 'key2', 'key3') + assert redis_mock.eval.mock_calls[0] == mocker.call('script', 3, 'some_prefix.key1', 'some_prefix.key2', 'some_prefix.key3') + + adapter.hset('key1', 'name', 'value') + assert redis_mock.hset.mock_calls[0] == mocker.call('some_prefix.key1', 'name', 'value') + + adapter.hget('key1', 'name') + assert redis_mock.hget.mock_calls[0] == mocker.call('some_prefix.key1', 'name') + + adapter.incr('key1') + assert redis_mock.incr.mock_calls[0] == mocker.call('some_prefix.key1', 1) + + adapter.getset('key1', 'new_value') + assert redis_mock.getset.mock_calls[0] == mocker.call('some_prefix.key1', 'new_value') + + adapter.rpush('key1', 'value1', 'value2') + assert redis_mock.rpush.mock_calls[0] == 
mocker.call('some_prefix.key1', 'value1', 'value2')
+
+        adapter.expire('key1', 10)
+        assert redis_mock.expire.mock_calls[0] == mocker.call('some_prefix.key1', 10)
+
+        adapter.rpop('key1')
+        assert redis_mock.rpop.mock_calls[0] == mocker.call('some_prefix.key1')
+
+        adapter.ttl('key1')
+        assert redis_mock.ttl.mock_calls[0] == mocker.call('some_prefix.key1')
+
+    def test_adapter_building(self, mocker):
+        """Test building different types of client according to parameters received."""
+        strict_redis_mock = mocker.Mock(spec=StrictRedis)
+        sentinel_mock = mocker.Mock(spec=Sentinel)
+        mocker.patch('splitio.storage.adapters.redis.StrictRedis', new=strict_redis_mock)
+        mocker.patch('splitio.storage.adapters.redis.Sentinel', new=sentinel_mock)
+
+        config = {
+            'redisHost': 'some_host',
+            'redisPort': 1234,
+            'redisDb': 0,
+            'redisPassword': 'some_password',
+            'redisSocketTimeout': 123,
+            'redisSocketConnectTimeout': 456,
+            'redisSocketKeepalive': 789,
+            'redisSocketKeepaliveOptions': 10,
+            'redisConnectionPool': 20,
+            'redisUnixSocketPath': '/tmp/socket',
+            'redisEncoding': 'utf-8',
+            'redisEncodingErrors': 'strict',
+            'redisCharset': 'ascii',
+            'redisErrors': 'abc',
+            'redisDecodeResponses': True,
+            'redisRetryOnTimeout': True,
+            'redisSsl': True,
+            'redisSslKeyfile': '/ssl.cert',
+            'redisSslCertfile': '/ssl2.cert',
+            'redisSslCertReqs': 'abc',
+            'redisSslCaCerts': 'def',
+            'redisMaxConnections': 5,
+            'redisPrefix': 'some_prefix'
+        }
+
+        redis.build(config)
+        assert strict_redis_mock.mock_calls[0] == mocker.call(
+            host='some_host',
+            port=1234,
+            db=0,
+            password='some_password',
+            socket_timeout=123,
+            socket_connect_timeout=456,
+            socket_keepalive=789,
+            socket_keepalive_options=10,
+            connection_pool=20,
+            unix_socket_path='/tmp/socket',
+            encoding='utf-8',
+            encoding_errors='strict',
+            charset='ascii',
+            errors='abc',
+            decode_responses=True,
+            retry_on_timeout=True,
+            ssl=True,
+            ssl_keyfile='/ssl.cert',
+            ssl_certfile='/ssl2.cert',
+            ssl_cert_reqs='abc',
+            ssl_ca_certs='def',
+            max_connections=5
+        )
+
+        config = {
+            'redisSentinels': [('123.123.123.123', 1), ('456.456.456.456', 2), ('789.789.789.789', 3)],
+            'redisMasterService': 'some_master',
+            'redisDb': 0,
+            'redisPassword': 'some_password',
+            'redisSocketTimeout': 123,
+            'redisSocketConnectTimeout': 456,
+            'redisSocketKeepalive': 789,
+            'redisSocketKeepaliveOptions': 10,
+            'redisConnectionPool': 20,
+            'redisUnixSocketPath': '/tmp/socket',
+            'redisEncoding': 'utf-8',
+            'redisEncodingErrors': 'strict',
+            'redisCharset': 'ascii',
+            'redisErrors': 'abc',
+            'redisDecodeResponses': True,
+            'redisRetryOnTimeout': True,
+            'redisSsl': True,
+            'redisSslKeyfile': '/ssl.cert',
+            'redisSslCertfile': '/ssl2.cert',
+            'redisSslCertReqs': 'abc',
+            'redisSslCaCerts': 'def',
+            'redisMaxConnections': 5,
+            'redisPrefix': 'some_prefix'
+        }
+
+        redis.build(config)
+        assert sentinel_mock.mock_calls[0] == mocker.call(
+            [('123.123.123.123', 1), ('456.456.456.456', 2), ('789.789.789.789', 3)],
+            db=0,
+            password='some_password',
+            socket_timeout=123,
+            socket_connect_timeout=456,
+            socket_keepalive=789,
+            socket_keepalive_options=10,
+            connection_pool=20,
+            encoding='utf-8',
+            encoding_errors='strict',
+            decode_responses=True,
+            retry_on_timeout=True,
+            max_connections=5
+        )
diff --git a/tests/storage/test_inmemory_storage.py b/tests/storage/test_inmemory_storage.py
new file mode 100644
index 00000000..f310f4b1
--- /dev/null
+++ b/tests/storage/test_inmemory_storage.py
@@ -0,0 +1,268 @@
+"""In-Memory storage test module."""
+#pylint: disable=no-self-use
+from splitio.models.splits import Split
+from splitio.models.segments import Segment
+from splitio.models.impressions import Impression
+from splitio.models.events import Event
+
+from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \
+    InMemoryImpressionStorage, InMemoryEventStorage, InMemoryTelemetryStorage
+
+
+class InMemorySplitStorageTests(object):
+    """In memory split storage test cases."""
+
+    def test_storing_retrieving_splits(self, mocker):
+        """Test storing and retrieving splits works."""
+        storage = InMemorySplitStorage()
+
+        split = mocker.Mock(spec=Split)
+        name_property = mocker.PropertyMock()
+        name_property.return_value = 'some_split'
+        type(split).name = name_property
+
+        storage.put(split)
+        assert storage.get('some_split') == split
+        assert storage.get_split_names() == ['some_split']
+        assert storage.get_all_splits() == [split]
+        assert storage.get('nonexistant_split') is None
+
+        storage.remove('some_split')
+        assert storage.get('some_split') is None
+
+    def test_store_get_changenumber(self):
+        """Test that storing and retrieving change numbers works."""
+        storage = InMemorySplitStorage()
+        assert storage.get_change_number() == -1
+        storage.set_change_number(5)
+        assert storage.get_change_number() == 5
+
+    def test_get_split_names(self, mocker):
+        """Test retrieving a list of all split names."""
+        split1 = mocker.Mock()
+        name1_prop = mocker.PropertyMock()
+        name1_prop.return_value = 'split1'
+        type(split1).name = name1_prop
+        split2 = mocker.Mock()
+        name2_prop = mocker.PropertyMock()
+        name2_prop.return_value = 'split2'
+        type(split2).name = name2_prop
+
+        storage = InMemorySplitStorage()
+        storage.put(split1)
+        storage.put(split2)
+
+        assert set(storage.get_split_names()) == set(['split1', 'split2'])
+
+    def test_get_all_splits(self, mocker):
+        """Test retrieving a list of all splits."""
+        split1 = mocker.Mock()
+        name1_prop = mocker.PropertyMock()
+        name1_prop.return_value = 'split1'
+        type(split1).name = name1_prop
+        split2 = mocker.Mock()
+        name2_prop = mocker.PropertyMock()
+        name2_prop.return_value = 'split2'
+        type(split2).name = name2_prop
+
+        storage = InMemorySplitStorage()
+        storage.put(split1)
+        storage.put(split2)
+
+        all_splits = storage.get_all_splits()
+        assert next(s for s in all_splits if s.name == 'split1')
+        assert next(s for s in all_splits if s.name == 'split2')
+
+
+class InMemorySegmentStorageTests(object):
+    """In memory segment storage tests."""
+
+    def test_segment_storage_retrieval(self, mocker):
+        """Test storing and retrieving segments."""
+        storage = InMemorySegmentStorage()
+        segment = mocker.Mock(spec=Segment)
+        name_property = mocker.PropertyMock()
+        name_property.return_value = 'some_segment'
+        type(segment).name = name_property
+
+        storage.put(segment)
+        assert storage.get('some_segment') == segment
+        assert storage.get('nonexistant-segment') is None
+
+    def test_change_number(self, mocker):
+        """Test storing and retrieving segment changeNumber."""
+        storage = InMemorySegmentStorage()
+        storage.set_change_number('some_segment', 123)
+        # Change number is not updated if segment doesn't exist
+        assert storage.get_change_number('some_segment') is None
+        assert storage.get_change_number('nonexistant-segment') is None
+
+        # Change number is updated if segment does exist.
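+        # (A fresh storage instance is built below so the ignored
+        # set_change_number() call above cannot leak into these assertions:
+        # put() first, then set_change_number(), is the sequence that makes
+        # a change number stick.)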
+ storage = InMemorySegmentStorage() + segment = mocker.Mock(spec=Segment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment).name = name_property + storage.put(segment) + storage.set_change_number('some_segment', 123) + assert storage.get_change_number('some_segment') == 123 + + def test_segment_contains(self, mocker): + """Test using storage to determine whether a key belongs to a segment.""" + storage = InMemorySegmentStorage() + segment = mocker.Mock(spec=Segment) + name_property = mocker.PropertyMock() + name_property.return_value = 'some_segment' + type(segment).name = name_property + storage.put(segment) + + storage.segment_contains('some_segment', 'abc') + assert segment.contains.mock_calls[0] == mocker.call('abc') + + def test_segment_update(self): + """Test updating a segment.""" + storage = InMemorySegmentStorage() + segment = Segment('some_segment', ['key1', 'key2', 'key3'], 123) + storage.put(segment) + assert storage.get('some_segment') == segment + + storage.update('some_segment', ['key4', 'key5'], ['key2', 'key3'], 456) + assert storage.segment_contains('some_segment', 'key1') + assert storage.segment_contains('some_segment', 'key4') + assert storage.segment_contains('some_segment', 'key5') + assert not storage.segment_contains('some_segment', 'key2') + assert not storage.segment_contains('some_segment', 'key3') + assert storage.get_change_number('some_segment') == 456 + + +class InMemoryImpressionsStorageTests(object): + """InMemory impressions storage test cases.""" + + def test_push_pop_impressions(self): + """Test pushing and retrieving impressions.""" + storage = InMemoryImpressionStorage(100) + storage.put([Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654)]) + storage.put([Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654)]) + storage.put([Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654)]) + + # Assert impressions are retrieved in the same order they are inserted. + assert storage.pop_many(1) == [ + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + ] + assert storage.pop_many(1) == [ + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + ] + assert storage.pop_many(1) == [ + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + ] + + # Assert inserting multiple impressions at once works and maintains order. + impressions = [ + Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654), + Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654), + Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + ] + assert storage.put(impressions) + + # Assert impressions are retrieved in the same order they are inserted. 
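+        # pop_many(1) should drain exactly one impression per call, in FIFO
+        # order, whether items were queued one at a time or via a bulk put().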
+        assert storage.pop_many(1) == [
+            Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654)
+        ]
+        assert storage.pop_many(1) == [
+            Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654)
+        ]
+        assert storage.pop_many(1) == [
+            Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654)
+        ]
+
+    def test_queue_full_hook(self, mocker):
+        """Test queue_full_hook is executed when the queue is full."""
+        storage = InMemoryImpressionStorage(100)
+        queue_full_hook = mocker.Mock()
+        storage.set_queue_full_hook(queue_full_hook)
+        impressions = [
+            Impression('key%d' % i, 'feature1', 'on', 'l1', 123456, 'b1', 321654)
+            for i in range(0, 101)
+        ]
+        storage.put(impressions)
+        assert queue_full_hook.mock_calls == mocker.call()
+
+
+class InMemoryEventsStorageTests(object):
+    """InMemory events storage test cases."""
+
+    def test_push_pop_events(self):
+        """Test pushing and retrieving events."""
+        storage = InMemoryEventStorage(100)
+        storage.put([Event('key1', 'user', 'purchase', 3.5, 123456)])
+        storage.put([Event('key2', 'user', 'purchase', 3.5, 123456)])
+        storage.put([Event('key3', 'user', 'purchase', 3.5, 123456)])
+
+        # Assert events are retrieved in the same order they are inserted.
+        assert storage.pop_many(1) == [Event('key1', 'user', 'purchase', 3.5, 123456)]
+        assert storage.pop_many(1) == [Event('key2', 'user', 'purchase', 3.5, 123456)]
+        assert storage.pop_many(1) == [Event('key3', 'user', 'purchase', 3.5, 123456)]
+
+        # Assert inserting multiple events at once works and maintains order.
+        events = [
+            Event('key1', 'user', 'purchase', 3.5, 123456),
+            Event('key2', 'user', 'purchase', 3.5, 123456),
+            Event('key3', 'user', 'purchase', 3.5, 123456),
+        ]
+        assert storage.put(events)
+
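The FIFO ordering and queue-full-hook behavior these cases pin down can be sketched with a bounded queue whose overflow fires a caller-provided hook. This is a hypothetical stand-in, not the real InMemoryImpressionStorage/InMemoryEventStorage:

    from six.moves import queue

    class BoundedFifoStorage(object):
        """Illustrative only; names and return values are assumptions."""

        def __init__(self, size):
            self._items = queue.Queue(maxsize=size)
            self._hook = None

        def set_queue_full_hook(self, hook):
            self._hook = hook

        def put(self, to_add):
            try:
                for item in to_add:
                    self._items.put(item, False)
                return True
            except queue.Full:
                if callable(self._hook):
                    self._hook()  # fired when capacity is exceeded
                return False

        def pop_many(self, count):
            popped = []
            while len(popped) < count and not self._items.empty():
                popped.append(self._items.get(False))
            return popped

+        # Assert events are retrieved in the same order they are inserted.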
+ assert storage.pop_many(1) == [Event('key1', 'user', 'purchase', 3.5, 123456)] + assert storage.pop_many(1) == [Event('key2', 'user', 'purchase', 3.5, 123456)] + assert storage.pop_many(1) == [Event('key3', 'user', 'purchase', 3.5, 123456)] + + def test_queue_full_hook(self, mocker): + """Test queue_full_hook is executed when the queue is full.""" + storage = InMemoryEventStorage(100) + queue_full_hook = mocker.Mock() + storage.set_queue_full_hook(queue_full_hook) + events = [Event('key%d' % i, 'user', 'purchase', 12.5, 321654) for i in range(0, 101)] + storage.put(events) + assert queue_full_hook.mock_calls == mocker.call() + + +class InMemoryTelemetryStorageTests(object): + """In-Memory telemetry storage unit tests.""" + + def test_latencies(self): + """Test storing and retrieving latencies.""" + storage = InMemoryTelemetryStorage() + storage.inc_latency('sdk.get_treatment', -1) + storage.inc_latency('sdk.get_treatment', 0) + storage.inc_latency('sdk.get_treatment', 1) + storage.inc_latency('sdk.get_treatment', 5) + storage.inc_latency('sdk.get_treatment', 5) + storage.inc_latency('sdk.get_treatment', 22) + latencies = storage.pop_latencies() + assert latencies['sdk.get_treatment'][0] == 1 + assert latencies['sdk.get_treatment'][1] == 1 + assert latencies['sdk.get_treatment'][5] == 2 + assert len(latencies['sdk.get_treatment']) == 22 + assert storage.pop_latencies() == {} + + def test_counters(self): + """Test storing and retrieving counters.""" + storage = InMemoryTelemetryStorage() + storage.inc_counter('some_counter_1') + storage.inc_counter('some_counter_1') + storage.inc_counter('some_counter_1') + storage.inc_counter('some_counter_2') + counters = storage.pop_counters() + assert counters['some_counter_1'] == 3 + assert counters['some_counter_2'] == 1 + assert storage.pop_counters() == {} + + def test_gauges(self): + """Test storing and retrieving gauges.""" + storage = InMemoryTelemetryStorage() + storage.put_gauge('some_gauge_1', 321) + storage.put_gauge('some_gauge_2', 654) + gauges = storage.pop_gauges() + assert gauges['some_gauge_1'] == 321 + assert gauges['some_gauge_2'] == 654 + assert storage.pop_gauges() == {} diff --git a/tests/storage/test_redis.py b/tests/storage/test_redis.py new file mode 100644 index 00000000..d84db121 --- /dev/null +++ b/tests/storage/test_redis.py @@ -0,0 +1,300 @@ +"""Redis storage test module.""" +#pylint: disable=no-self-use +import json +from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, \ + RedisSegmentStorage, RedisSplitStorage, RedisTelemetryStorage + +from splitio.models.segments import Segment +from splitio.models.impressions import Impression +from splitio.models.events import Event +from splitio.storage.adapters.redis import RedisAdapter, RedisAdapterException + + +class RedisSplitStorageTests(object): + """Redis split storage test cases.""" + + def test_get_split(self, mocker): + """Test retrieving a split works.""" + adapter = mocker.Mock(spec=RedisAdapter) + adapter.get.return_value = '{"name": "some_split"}' + from_raw = mocker.Mock() + mocker.patch('splitio.models.splits.from_raw', new=from_raw) + + storage = RedisSplitStorage(adapter) + storage.get('some_split') + + assert adapter.get.mock_calls == [mocker.call('SPLITIO.split.some_split')] + assert from_raw.mock_calls == [mocker.call({"name": "some_split"})] + + # Test that a missing split returns None and doesn't call from_raw + adapter.reset_mock() + from_raw.reset_mock() + adapter.get.return_value = None + + result = storage.get('some_split') + 
assert result is None
+        assert adapter.get.mock_calls == [mocker.call('SPLITIO.split.some_split')]
+        assert not from_raw.mock_calls
+
+
+    def test_get_changenumber(self, mocker):
+        """Test fetching changenumber."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        storage = RedisSplitStorage(adapter)
+        adapter.get.return_value = -1
+        assert storage.get_change_number() == -1
+        assert adapter.get.mock_calls == [mocker.call('SPLITIO.splits.till')]
+
+    def test_get_all_splits(self, mocker):
+        """Test fetching all splits."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        storage = RedisSplitStorage(adapter)
+        from_raw = mocker.Mock()
+        mocker.patch('splitio.models.splits.from_raw', new=from_raw)
+
+        adapter.keys.return_value = [
+            'SPLITIO.split.split1',
+            'SPLITIO.split.split2',
+            'SPLITIO.split.split3'
+        ]
+        def _mget_mock(*_):
+            return ['{"name": split1}', '{"name": split2}', '{"name": split3}']
+        adapter.mget.side_effect = _mget_mock
+
+        storage.get_all_splits()
+
+        assert adapter.keys.mock_calls == [mocker.call('SPLITIO.split.*')]
+        assert adapter.mget.mock_calls == [
+            mocker.call(['SPLITIO.split.split1', 'SPLITIO.split.split2', 'SPLITIO.split.split3'])
+        ]
+
+        assert len(from_raw.mock_calls) == 3
+        assert mocker.call('{"name": split1}') in from_raw.mock_calls
+        assert mocker.call('{"name": split2}') in from_raw.mock_calls
+        assert mocker.call('{"name": split3}') in from_raw.mock_calls
+
+    def test_get_split_names(self, mocker):
+        """Test fetching split names."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        storage = RedisSplitStorage(adapter)
+        adapter.keys.return_value = [
+            'SPLITIO.split.split1',
+            'SPLITIO.split.split2',
+            'SPLITIO.split.split3'
+        ]
+        assert storage.get_split_names() == ['split1', 'split2', 'split3']
+
+
+class RedisSegmentStorageTests(object):
+    """Redis segment storage test cases."""
+
+    def test_fetch_segment(self, mocker):
+        """Test fetching a whole segment."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        adapter.smembers.return_value = ["key1", "key2", "key3"]
+        adapter.get.return_value = 100
+        from_raw = mocker.Mock()
+        mocker.patch('splitio.models.segments.from_raw', new=from_raw)
+
+        storage = RedisSegmentStorage(adapter)
+        result = storage.get('some_segment')
+        assert isinstance(result, Segment)
+        assert result.name == 'some_segment'
+        assert result.contains('key1')
+        assert result.contains('key2')
+        assert result.contains('key3')
+        assert result.change_number == 100
+        assert adapter.smembers.mock_calls == [mocker.call('SPLITIO.segment.some_segment')]
+        assert adapter.get.mock_calls == [mocker.call('SPLITIO.segment.some_segment.till')]
+
+        # Assert that if segment doesn't exist, None is returned
+        adapter.reset_mock()
+        from_raw.reset_mock()
+        adapter.smembers.return_value = None
+        assert storage.get('some_segment') is None
+        assert adapter.smembers.mock_calls == [mocker.call('SPLITIO.segment.some_segment')]
+        assert adapter.get.mock_calls == [mocker.call('SPLITIO.segment.some_segment.till')]
+
+    def test_fetch_change_number(self, mocker):
+        """Test fetching change number."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        adapter.get.return_value = 100
+
+        storage = RedisSegmentStorage(adapter)
+        result = storage.get_change_number('some_segment')
+        assert result == 100
+        assert adapter.get.mock_calls == [mocker.call('SPLITIO.segment.some_segment.till')]
+
+    def test_segment_contains(self, mocker):
+        """Test segment contains functionality."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        storage = RedisSegmentStorage(adapter)
+        adapter.sismember.return_value = True
+        assert storage.segment_contains('some_segment', 'some_key') is True
+        assert adapter.sismember.mock_calls == [
+            mocker.call('SPLITIO.segment.some_segment', 'some_key')
+        ]
+
+
+class RedisImpressionsStorageTests(object): #pylint: disable=too-few-public-methods
+    """Redis impressions storage test cases."""
+
+    def test_add_impressions(self, mocker):
+        """Test that adding impressions to storage works."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        sdk_metadata = {
+            'sdk-language-version': 'python-1.2.3',
+            'instance-id': 'some_instance_id',
+            'ip-address': '123.123.123.123'
+        }
+        storage = RedisImpressionsStorage(adapter, sdk_metadata)
+
+        impressions = [
+            Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654),
+            Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654),
+            Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654),
+            Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654)
+        ]
+
+        assert storage.put(impressions) is True
+
+        to_validate = [json.dumps({
+            'm': { # METADATA PORTION
+                's': sdk_metadata['sdk-language-version'],
+                'n': sdk_metadata['instance-id'],
+                'i': sdk_metadata['ip-address'],
+            },
+            'i': { # IMPRESSION PORTION
+                'k': impression.matching_key,
+                'b': impression.bucketing_key,
+                'f': impression.feature_name,
+                't': impression.treatment,
+                'r': impression.label,
+                'c': impression.change_number,
+                'm': impression.time,
+            }
+        }) for impression in impressions]
+
+        assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.impressions', to_validate)]
+
+        # Assert that if an exception is thrown it's caught and False is returned
+        adapter.reset_mock()
+        def _raise_exc(*_):
+            raise RedisAdapterException('something')
+        adapter.rpush.side_effect = _raise_exc
+        assert storage.put(impressions) is False
+
+
+class RedisEventsStorageTests(object): #pylint: disable=too-few-public-methods
+    """Redis events storage test cases."""
+
+    def test_add_events(self, mocker):
+        """Test that adding events to storage works."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        sdk_metadata = {
+            'sdk-language-version': 'python-1.2.3',
+            'instance-id': 'some_instance_id',
+            'ip-address': '123.123.123.123'
+        }
+        storage = RedisEventsStorage(adapter, sdk_metadata)
+
+        events = [
+            Event('key1', 'user', 'purchase', 10, 123456),
+            Event('key2', 'user', 'purchase', 10, 123456),
+            Event('key3', 'user', 'purchase', 10, 123456),
+            Event('key4', 'user', 'purchase', 10, 123456),
+        ]
+        assert storage.put(events) is True
+
+        list_of_raw_events = [json.dumps({
+            'm': { # METADATA PORTION
+                's': sdk_metadata['sdk-language-version'],
+                'n': sdk_metadata['instance-id'],
+                'i': sdk_metadata['ip-address'],
+            },
+            'i': { # EVENT PORTION
+                'key': event.key,
+                'trafficTypeName': event.traffic_type_name,
+                'eventTypeId': event.event_type_id,
+                'value': event.value,
+                'timestamp': event.timestamp,
+            }
+        }) for event in events]
+
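Because dict key order in json.dumps differs between Python 2 and 3, the assertions below compare decoded objects instead of raw strings. The same idea as a reusable helper (hypothetical name, shown only to make the technique explicit):

    import json

    def same_json_records(raw_a, raw_b):
        """Compare two lists of JSON strings by value, ignoring key order."""
        decoded_a = [json.loads(item) for item in raw_a]
        decoded_b = [json.loads(item) for item in raw_b]
        return all(item in decoded_b for item in decoded_a) and \
            all(item in decoded_a for item in decoded_b)

+        # To deal with python2 & 3 differences in hashing/order when dumping json.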
+        list_of_raw_json_strings_called = adapter.rpush.mock_calls[0][1][1]
+        list_of_events_called = [json.loads(event) for event in list_of_raw_json_strings_called]
+        list_of_events_sent = [json.loads(event) for event in list_of_raw_events]
+        for item in list_of_events_sent:
+            assert item in list_of_events_called
+
+#        assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.events', to_validate)]
+        # Assert that if an exception is thrown it's caught and False is returned
+        adapter.reset_mock()
+        def _raise_exc(*_):
+            raise RedisAdapterException('something')
+        adapter.rpush.side_effect = _raise_exc
+        assert storage.put(events) is False
+
+
+class RedisTelemetryStorageTests(object):
+    """Redis-based telemetry storage test cases."""
+
+    def test_inc_latency(self, mocker):
+        """Test incrementing latency."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        sdk_metadata = {
+            'sdk-language-version': 'python-1.2.3',
+            'instance-id': 'some_instance_id',
+            'ip-address': '123.123.123.123'
+        }
+        storage = RedisTelemetryStorage(adapter, sdk_metadata)
+        storage.inc_latency('some_latency', 0)
+        storage.inc_latency('some_latency', 1)
+        storage.inc_latency('some_latency', 5)
+        storage.inc_latency('some_latency', 5)
+        storage.inc_latency('some_latency', 22)
+        assert adapter.incr.mock_calls == [
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.0'),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.1'),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.5'),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.5')
+        ]
+
+    def test_inc_counter(self, mocker):
+        """Test incrementing counters."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        sdk_metadata = {
+            'sdk-language-version': 'python-1.2.3',
+            'instance-id': 'some_instance_id',
+            'ip-address': '123.123.123.123'
+        }
+        storage = RedisTelemetryStorage(adapter, sdk_metadata)
+        storage.inc_counter('some_counter_1')
+        storage.inc_counter('some_counter_1')
+        storage.inc_counter('some_counter_1')
+        storage.inc_counter('some_counter_2')
+        storage.inc_counter('some_counter_2')
+        assert adapter.incr.mock_calls == [
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_1'),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_1'),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_1'),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_2'),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_2')
+        ]
+
+    def test_inc_gauge(self, mocker):
+        """Test setting gauges."""
+        adapter = mocker.Mock(spec=RedisAdapter)
+        sdk_metadata = {
+            'sdk-language-version': 'python-1.2.3',
+            'instance-id': 'some_instance_id',
+            'ip-address': '123.123.123.123'
+        }
+        storage = RedisTelemetryStorage(adapter, sdk_metadata)
+        storage.put_gauge('gauge1', 123)
+        storage.put_gauge('gauge2', 456)
+        assert adapter.set.mock_calls == [
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/gauge.gauge1', 123),
+            mocker.call('SPLITIO/python-1.2.3/some_instance_id/gauge.gauge2', 456)
+        ]
diff --git a/tests/storage/test_uwsgi.py b/tests/storage/test_uwsgi.py
new file mode 100644
index 00000000..84fe26c3
--- /dev/null
+++ b/tests/storage/test_uwsgi.py
@@ -0,0 +1,244 @@
+"""UWSGI Storage unit tests."""
+#pylint: disable=no-self-use
+import json
+
+from splitio.storage.uwsgi import UWSGIEventStorage, UWSGIImpressionStorage, \
+    UWSGISegmentStorage, UWSGISplitStorage, UWSGITelemetryStorage
+
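Before moving on to the uWSGI storages: the Redis telemetry keys asserted above all follow one naming scheme, summarized by this hypothetical helper (the real RedisTelemetryStorage may build its keys differently):

    def telemetry_key(metadata, kind, name, bucket=None):
        """Sketch of the 'SPLITIO/<sdk-version>/<instance-id>/...' key scheme."""
        prefix = 'SPLITIO/%s/%s' % (metadata['sdk-language-version'], metadata['instance-id'])
        if kind == 'latency':
            return '%s/latency.%s.bucket.%d' % (prefix, name, bucket)
        return '%s/%s.%s' % (prefix, kind, name)  # kind in ('count', 'gauge')

    # telemetry_key(meta, 'latency', 'some_latency', 0)
    #   -> 'SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.0'
    # telemetry_key(meta, 'count', 'some_counter_1')
    #   -> 'SPLITIO/python-1.2.3/some_instance_id/count.some_counter_1'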
+from splitio.models.splits import Split +from splitio.models.segments import Segment +from splitio.models.impressions import Impression +from splitio.models.events import Event + +from splitio.storage.adapters.uwsgi_cache import get_uwsgi + + +class UWSGISplitStorageTests(object): + """UWSGI Split Storage test cases.""" + + def test_store_retrieve_split(self, mocker): + """Test storing and retrieving splits.""" + uwsgi = get_uwsgi(True) + storage = UWSGISplitStorage(uwsgi) + split = mocker.Mock(spec=Split) + split.to_json.return_value = '{}' + split_name = mocker.PropertyMock() + split_name.return_value = 'some_split' + type(split).name = split_name + storage.put(split) + + from_raw_mock = mocker.Mock() + from_raw_mock.return_value = 'ok' + mocker.patch('splitio.models.splits.from_raw', new=from_raw_mock) + retrieved = storage.get('some_split') + + assert retrieved == 'ok' + assert from_raw_mock.mock_calls == [mocker.call('{}')] + assert split.to_json.mock_calls == [mocker.call()] + + assert storage.get('nonexistant_split') is None + + storage.remove('some_split') + assert storage.get('some_split') == None + + def test_set_get_changenumber(self, mocker): + """Test setting and retrieving changenumber.""" + uwsgi = get_uwsgi(True) + storage = UWSGISplitStorage(uwsgi) + + assert storage.get_change_number() == None + storage.set_change_number(123) + assert storage.get_change_number() == 123 + + def test_get_split_names(self, mocker): + """Test getting all split names.""" + uwsgi = get_uwsgi(True) + storage = UWSGISplitStorage(uwsgi) + split_1 = mocker.Mock(spec=Split) + split_1.to_json.return_value = '{"name": "split1"}' + split_name_1 = mocker.PropertyMock() + split_name_1.return_value = 'some_split_1' + type(split_1).name = split_name_1 + split_2 = mocker.Mock(spec=Split) + split_2.to_json.return_value = '{"name": "split2"}' + split_name_2 = mocker.PropertyMock() + split_name_2.return_value = 'some_split_2' + type(split_2).name = split_name_2 + storage.put(split_1) + storage.put(split_2) + assert set(storage.get_split_names()) == set(['some_split_1', 'some_split_2']) + storage.remove('some_split_1') + assert storage.get_split_names() == ['some_split_2'] + + def test_get_all_splits(self, mocker): + """Test fetching all splits.""" + uwsgi = get_uwsgi(True) + storage = UWSGISplitStorage(uwsgi) + split_1 = mocker.Mock(spec=Split) + split_1.to_json.return_value = '{"name": "some_split_1"}' + split_name_1 = mocker.PropertyMock() + split_name_1.return_value = 'some_split_1' + type(split_1).name = split_name_1 + split_2 = mocker.Mock(spec=Split) + split_2.to_json.return_value = '{"name": "some_split_2"}' + split_name_2 = mocker.PropertyMock() + split_name_2.return_value = 'some_split_2' + type(split_2).name = split_name_2 + + def _from_raw_mock(split_json): + split_mock = mocker.Mock(spec=Split) + name = mocker.PropertyMock() + name.return_value = json.loads(split_json)['name'] + type(split_mock).name = name + return split_mock + mocker.patch('splitio.storage.uwsgi.splits.from_raw', new=_from_raw_mock) + + storage.put(split_1) + storage.put(split_2) + + splits = storage.get_all_splits() + s1 = next(split for split in splits if split.name == 'some_split_1') + s2 = next(split for split in splits if split.name == 'some_split_2') + + + +class UWSGISegmentStorageTests(object): + """UWSGI Segment storage test cases.""" + + def test_store_retrieve_segment(self, mocker): + """Test storing and fetching segments.""" + uwsgi = get_uwsgi(True) + storage = UWSGISegmentStorage(uwsgi) + segment = 
mocker.Mock(spec=Segment) + segment_keys = mocker.PropertyMock() + segment_keys.return_value = ['abc'] + type(segment).keys = segment_keys + segment.to_json = {} + segment_name = mocker.PropertyMock() + segment_name.return_value = 'some_segment' + segment_change_number = mocker.PropertyMock() + segment_change_number.return_value = 123 + type(segment).name = segment_name + type(segment).change_number = segment_change_number + from_raw_mock = mocker.Mock() + from_raw_mock.return_value = 'ok' + mocker.patch('splitio.models.segments.from_raw', new=from_raw_mock) + + storage.put(segment) + assert storage.get('some_segment') == 'ok' + assert from_raw_mock.mock_calls == [mocker.call({'till': 123, 'removed': [], 'added': [u'abc'], 'name': 'some_segment'})] + assert storage.get('nonexistant-segment') is None + + def test_get_set_change_number(self, mocker): + """Test setting and getting change number.""" + uwsgi = get_uwsgi(True) + storage = UWSGISegmentStorage(uwsgi) + assert storage.get_change_number('some_segment') is None + storage.set_change_number('some_segment', 123) + assert storage.get_change_number('some_segment') == 123 + + def test_segment_contains(self, mocker): + """Test that segment contains works properly.""" + uwsgi = get_uwsgi(True) + storage = UWSGISegmentStorage(uwsgi) + + from_raw_mock = mocker.Mock() + from_raw_mock.return_value = Segment('some_segment', ['abc'], 123) + mocker.patch('splitio.models.segments.from_raw', new=from_raw_mock) + segment = mocker.Mock(spec=Segment) + segment_keys = mocker.PropertyMock() + segment_keys.return_value = ['abc'] + type(segment).keys = segment_keys + segment.to_json = {} + segment_name = mocker.PropertyMock() + segment_name.return_value = 'some_segment' + segment_change_number = mocker.PropertyMock() + segment_change_number.return_value = 123 + type(segment).name = segment_name + type(segment).change_number = segment_change_number + storage.put(segment) + + assert storage.segment_contains('some_segment', 'abc') + assert not storage.segment_contains('some_segment', 'qwe') + + + +class UWSGIImpressionsStorageTests(object): + """UWSGI Impressions storage test cases.""" + + def test_put_pop_impressions(self, mocker): + """Test storing and fetching impressions.""" + uwsgi = get_uwsgi(True) + storage = UWSGIImpressionStorage(uwsgi) + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654) + ] + storage.put(impressions) + res = storage.pop_many(10) + assert res == impressions + + def test_flush(self): + """Test requesting, querying and acknowledging a flush.""" + uwsgi = get_uwsgi(True) + storage = UWSGIImpressionStorage(uwsgi) + assert storage.should_flush() is False + storage.request_flush() + assert storage.should_flush() is True + storage.acknowledge_flush() + assert storage.should_flush() is False + + + + +class UWSGIEventsStorageTests(object): + """UWSGI Events storage test cases.""" + + def test_put_pop_events(self, mocker): + """Test storing and fetching events.""" + uwsgi = get_uwsgi(True) + storage = UWSGIEventStorage(uwsgi) + events = [ + Event('key1', 'user', 'purchase', 10, 123456), + Event('key2', 'user', 'purchase', 10, 123456), + Event('key3', 'user', 'purchase', 10, 123456), + Event('key4', 'user', 'purchase', 10, 123456), + ] + + storage.put(events) + res = 
storage.pop_many(10)
+        assert res == events
+
+class UWSGITelemetryStorageTests(object):
+    """UWSGI-based telemetry storage test cases."""
+
+    def test_latencies(self):
+        """Test storing and popping latencies."""
+        storage = UWSGITelemetryStorage(get_uwsgi(True))
+        storage.inc_latency('some_latency', 2)
+        storage.inc_latency('some_latency', 2)
+        storage.inc_latency('some_latency', 2)
+        assert storage.pop_latencies() == {
+            'some_latency': [0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        }
+        assert storage.pop_latencies() == {}
+
+    def test_counters(self):
+        """Test storing and popping counters."""
+        storage = UWSGITelemetryStorage(get_uwsgi(True))
+        storage.inc_counter('some_counter')
+        storage.inc_counter('some_counter')
+        storage.inc_counter('some_counter')
+        assert storage.pop_counters() == {'some_counter': 3}
+        assert storage.pop_counters() == {}
+
+    def test_gauges(self):
+        """Test storing and popping gauges."""
+        storage = UWSGITelemetryStorage(get_uwsgi(True))
+        storage.put_gauge('some_gauge1', 123)
+        storage.put_gauge('some_gauge2', 456)
+        assert storage.pop_gauges() == {'some_gauge1': 123, 'some_gauge2': 456}
+        assert storage.pop_gauges() == {}
+
diff --git a/tests/tasks/test_events_sync.py b/tests/tasks/test_events_sync.py
new file mode 100644
index 00000000..b58c74cb
--- /dev/null
+++ b/tests/tasks/test_events_sync.py
@@ -0,0 +1,40 @@
+"""Events synchronization task test module."""
+
+import threading
+import time
+from splitio.api.client import HttpResponse
+from splitio.tasks import events_sync
+from splitio.storage import EventStorage
+from splitio.models.events import Event
+from splitio.api.events import EventsAPI
+
+
+class EventsSyncTests(object):
+    """Events synchronization task test cases."""
+
+    def test_normal_operation(self, mocker):
+        """Test that the task works properly under normal circumstances."""
+        storage = mocker.Mock(spec=EventStorage)
+        events = [
+            Event('key1', 'user', 'purchase', 5.3, 123456),
+            Event('key2', 'user', 'purchase', 5.3, 123456),
+            Event('key3', 'user', 'purchase', 5.3, 123456),
+            Event('key4', 'user', 'purchase', 5.3, 123456),
+            Event('key5', 'user', 'purchase', 5.3, 123456),
+        ]
+
+        storage.pop_many.return_value = events
+        api = mocker.Mock(spec=EventsAPI)
+        api.flush_events.return_value = HttpResponse(200, '')
+        task = events_sync.EventsSyncTask(api, storage, 1, 5)
+        task.start()
+        time.sleep(2)
+        assert task.is_running()
+        assert storage.pop_many.mock_calls[0] == mocker.call(5)
+        assert api.flush_events.mock_calls[0] == mocker.call(events)
+        stop_event = threading.Event()
+        calls_now = len(api.flush_events.mock_calls)
+        task.stop(stop_event)
+        stop_event.wait(5)
+        assert stop_event.is_set()
+        assert len(api.flush_events.mock_calls) > calls_now
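This test and the impressions one that follows exercise the same pattern: a periodic task drains the storage in bulks and posts them, with one extra flush on stop (which is why flush_events gets another call after task.stop above). A rough sketch under those assumptions — the real EventsSyncTask may be structured differently:

    from splitio.tasks.util.asynctask import AsyncTask  # path taken from this patch

    class SketchEventsSyncTask(object):
        """Hypothetical reimplementation of the behavior under test."""

        def __init__(self, api, storage, period, bulk_size):
            self._api = api
            self._storage = storage
            self._bulk_size = bulk_size
            # AsyncTask(main, period, on_init, on_stop), as used in the asynctask
            # tests later in this patch; the no-op on_init is an assumption.
            self._task = AsyncTask(self._flush, period, lambda: True, self._flush)

        def _flush(self):
            events = self._storage.pop_many(self._bulk_size)  # e.g. 5 per iteration here
            if events:
                self._api.flush_events(events)

        def start(self):
            self._task.start()

        def is_running(self):
            return self._task.running()

        def stop(self, event=None):
            # on_stop triggers one final flush before the event is signalled.
            self._task.stop(event)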
diff --git a/tests/tasks/test_impressions_sync.py b/tests/tasks/test_impressions_sync.py
new file mode 100644
index 00000000..4851abf9
--- /dev/null
+++ b/tests/tasks/test_impressions_sync.py
@@ -0,0 +1,38 @@
+"""Impressions synchronization task test module."""
+
+import threading
+import time
+from splitio.api.client import HttpResponse
+from splitio.tasks import impressions_sync
+from splitio.storage import ImpressionStorage
+from splitio.models.impressions import Impression
+from splitio.api.impressions import ImpressionsAPI
+
+class ImpressionsSyncTests(object):
+    """Impressions synchronization task test cases."""
+
+    def test_normal_operation(self, mocker):
+        """Test that the task works properly under normal circumstances."""
+        storage = mocker.Mock(spec=ImpressionStorage)
+        impressions = [
+            Impression('key1', 'split1', 'on', 'l1', 123456, 'b1', 321654),
+            Impression('key2', 'split1', 'on', 'l1', 123456, 'b1', 321654),
+            Impression('key3', 'split2', 'off', 'l1', 123456, 'b1', 321654),
+            Impression('key4', 'split2', 'on', 'l1', 123456, 'b1', 321654),
+            Impression('key5', 'split3', 'off', 'l1', 123456, 'b1', 321654)
+        ]
+        storage.pop_many.return_value = impressions
+        api = mocker.Mock(spec=ImpressionsAPI)
+        api.flush_impressions.return_value = HttpResponse(200, '')
+        task = impressions_sync.ImpressionsSyncTask(api, storage, 1, 5)
+        task.start()
+        time.sleep(2)
+        assert task.is_running()
+        assert storage.pop_many.mock_calls[0] == mocker.call(5)
+        assert api.flush_impressions.mock_calls[0] == mocker.call(impressions)
+        stop_event = threading.Event()
+        calls_now = len(api.flush_impressions.mock_calls)
+        task.stop(stop_event)
+        stop_event.wait(5)
+        assert stop_event.is_set()
+        assert len(api.flush_impressions.mock_calls) > calls_now
diff --git a/tests/tasks/test_segment_sync.py b/tests/tasks/test_segment_sync.py
new file mode 100644
index 00000000..025bca79
--- /dev/null
+++ b/tests/tasks/test_segment_sync.py
@@ -0,0 +1,92 @@
+"""Segment synchronization task test module."""
+
+import threading
+import time
+from splitio.api import APIException
+from splitio.tasks import segment_sync
+from splitio.storage import SegmentStorage, SplitStorage
+from splitio.models.splits import Split
+from splitio.models.segments import Segment
+from splitio.models.grammar.condition import Condition
+from splitio.models.grammar.matchers import UserDefinedSegmentMatcher
+
+
+class SegmentSynchronizationTests(object):
+    """Segment synchronization task test cases."""
+
+    def test_normal_operation(self, mocker):
+        """Test the normal operation flow."""
+        split_storage = mocker.Mock(spec=SplitStorage)
+        split_storage.get_segment_names.return_value = ['segmentA', 'segmentB', 'segmentC']
+
+        # Setup a mocked segment storage whose changenumber returns -1 on first fetch and
+        # 123 afterwards.
+        storage = mocker.Mock(spec=SegmentStorage)
+        def change_number_mock(segment_name):
+            if segment_name == 'segmentA' and change_number_mock._count_a == 0:
+                change_number_mock._count_a = 1
+                return -1
+            if segment_name == 'segmentB' and change_number_mock._count_b == 0:
+                change_number_mock._count_b = 1
+                return -1
+            if segment_name == 'segmentC' and change_number_mock._count_c == 0:
+                change_number_mock._count_c = 1
+                return -1
+            return 123
+        change_number_mock._count_a = 0
+        change_number_mock._count_b = 0
+        change_number_mock._count_c = 0
+        storage.get_change_number.side_effect = change_number_mock
+
+        # Setup a mocked segment api to return segments mentioned before.
+        def fetch_segment_mock(segment_name, change_number):
+            if segment_name == 'segmentA' and fetch_segment_mock._count_a == 0:
+                fetch_segment_mock._count_a = 1
+                return {'name': 'segmentA', 'added': ['key1', 'key2', 'key3'], 'removed': [], 'since': -1, 'till': 123}
+            if segment_name == 'segmentB' and fetch_segment_mock._count_b == 0:
+                fetch_segment_mock._count_b = 1
+                return {'name': 'segmentB', 'added': ['key4', 'key5', 'key6'], 'removed': [], 'since': -1, 'till': 123}
+            if segment_name == 'segmentC' and fetch_segment_mock._count_c == 0:
+                fetch_segment_mock._count_c = 1
+                return {'name': 'segmentC', 'added': ['key7', 'key8', 'key9'], 'removed': [], 'since': -1, 'till': 123}
+            return {'added': [], 'removed': [], 'since': 123, 'till': 123}
+        fetch_segment_mock._count_a = 0
+        fetch_segment_mock._count_b = 0
+        fetch_segment_mock._count_c = 0
+
+        api = mocker.Mock()
+        api.fetch_segment.side_effect = fetch_segment_mock
+
+        segment_ready_event = threading.Event()
+        task = segment_sync.SegmentSynchronizationTask(api, storage, split_storage, 1, segment_ready_event)
+        task.start()
+
+        segment_ready_event.wait(5)
+        assert task.is_running()
+
+        stop_event = threading.Event()
+        task.stop(stop_event)
+        stop_event.wait()
+        assert segment_ready_event.is_set()
+        assert not task.is_running()
+
+        api_calls = [call for call in api.fetch_segment.mock_calls]
+        assert mocker.call('segmentA', -1) in api_calls
+        assert mocker.call('segmentB', -1) in api_calls
+        assert mocker.call('segmentC', -1) in api_calls
+        assert mocker.call('segmentA', 123) in api_calls
+        assert mocker.call('segmentB', 123) in api_calls
+        assert mocker.call('segmentC', 123) in api_calls
+
+        segment_put_calls = storage.put.mock_calls
+        segments_to_validate = set(['segmentA', 'segmentB', 'segmentC'])
+        for call in segment_put_calls:
+            func_name, positional_args, keyword_args = call
+            segment = positional_args[0]
+            assert isinstance(segment, Segment)
+            assert segment.name in segments_to_validate
+            segments_to_validate.remove(segment.name)
+
+    def test_that_errors_dont_stop_task(self, mocker):
+        """Test that if fetching segments fails at some point, the task will continue running."""
+        # TODO!
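The mocks above (and the split-sync test that follows) encode the same polling protocol: fetch with the last known change number until the server reports no newer data. A condensed sketch, assuming the API shape used by the mocks — the real task also builds and `put`s a Segment object on the first fetch, which this sketch omits:

    def sync_segment(api, storage, segment_name):
        """Hypothetical worker body mirroring the mocked since/till protocol."""
        while True:
            since = storage.get_change_number(segment_name) or -1
            changes = api.fetch_segment(segment_name, since)
            # Apply the delta and persist the new change number
            # (signature matches the in-memory storage tests earlier).
            storage.update(segment_name, changes['added'], changes['removed'], changes['till'])
            if changes['till'] == since:
                return  # caught up: the server returned no newer data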
diff --git a/tests/tasks/test_split_sync.py b/tests/tasks/test_split_sync.py new file mode 100644 index 00000000..31f70a02 --- /dev/null +++ b/tests/tasks/test_split_sync.py @@ -0,0 +1,119 @@ +"""Split syncrhonization task test module.""" + +import threading +import time +from splitio.api import APIException +from splitio.tasks import split_sync +from splitio.storage import SplitStorage +from splitio.models.splits import Split + + +class SplitSynchronizationTests(object): + """Split synchronization task test cases.""" + + def test_normal_operation(self, mocker): + """Test the normal operation flow.""" + storage = mocker.Mock(spec=SplitStorage) + def change_number_mock(): + change_number_mock._calls += 1 + if change_number_mock._calls == 1: + return -1 + return 123 + change_number_mock._calls = 0 + storage.get_change_number.side_effect = change_number_mock + + api = mocker.Mock() + splits = [{ + 'changeNumber': 123, + 'trafficTypeName': 'user', + 'name': 'some_name', + 'trafficAllocation': 100, + 'trafficAllocationSeed': 123456, + 'seed': 321654, + 'status': 'ACTIVE', + 'killed': False, + 'defaultTreatment': 'off', + 'algo': 2, + 'conditions': [ + { + 'partitions': [ + {'treatment': 'on', 'size': 50}, + {'treatment': 'off', 'size': 50} + ], + 'contitionType': 'WHITELIST', + 'label': 'some_label', + 'matcherGroup': { + 'matchers': [ + { + 'matcherType': 'WHITELIST', + 'whitelistMatcherData': { + 'whitelist': ['k1', 'k2', 'k3'] + }, + 'negate': False, + } + ], + 'combiner': 'AND' + } + } + ] + }] + + def get_changes(*args, **kwargs): + get_changes.called += 1 + + if get_changes.called == 1: + return { + 'splits': splits, + 'since': -1, + 'till': 123 + } + else: + return { + 'splits': [], + 'since': 123, + 'till': 123 + } + get_changes.called = 0 + + api.fetch_splits.side_effect = get_changes + splits_ready_event = threading.Event() + task = split_sync.SplitSynchronizationTask(api, storage, 1, splits_ready_event) + task.start() + splits_ready_event.wait(5) + assert task.is_running() + stop_event = threading.Event() + task.stop(stop_event) + stop_event.wait() + assert splits_ready_event.is_set() + assert not task.is_running() + api_calls = api.fetch_splits.mock_calls + assert mocker.call(-1) in api.fetch_splits.mock_calls + assert mocker.call(123) in api.fetch_splits.mock_calls + + inserted_split = storage.put.mock_calls[0][1][0] + assert isinstance(inserted_split, Split) + assert inserted_split.name == 'some_name' + + def test_that_errors_dont_stop_task(self, mocker): + """Test that if fetching splits fails at some_point, the task will continue running.""" + storage = mocker.Mock(spec=SplitStorage) + api = mocker.Mock() + def run(x): + run._calls +=1 + if run._calls == 1: + return {'splits': [], 'since': -1, 'till': -1} + if run._calls == 2: + return {'splits': [], 'since': -1, 'till': -1} + raise APIException("something broke") + run._calls = 0 + api.fetch_splits.side_effect = run + storage.get_change_number.return_value = -1 + + splits_ready_event = threading.Event() + task = split_sync.SplitSynchronizationTask(api, storage, 0.5, splits_ready_event) + task.start() + splits_ready_event.wait(5) + assert task.is_running() + time.sleep(1) + assert task.is_running() + task.stop() diff --git a/tests/tasks/test_telemetry_sync.py b/tests/tasks/test_telemetry_sync.py new file mode 100644 index 00000000..ab67463d --- /dev/null +++ b/tests/tasks/test_telemetry_sync.py @@ -0,0 +1,57 @@ +"""Telemetry synchronization task unit test module.""" +#pylint: disable=no-self-use +import time +import threading 
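For context on the latency vectors used in the telemetry tests below: each metric name maps to a fixed vector of per-bucket counters, out-of-range bucket indexes are dropped, and popping drains the map. A sketch consistent with the in-memory storage tests earlier (names hypothetical):

    BUCKETS = 22  # vector length asserted in the in-memory telemetry tests

    class SketchLatencyTracker(object):
        def __init__(self):
            self._latencies = {}

        def inc_latency(self, name, bucket):
            if not 0 <= bucket < BUCKETS:
                return  # e.g. the -1 and 22 calls in those tests are ignored
            self._latencies.setdefault(name, [0] * BUCKETS)[bucket] += 1

        def pop_latencies(self):
            popped, self._latencies = self._latencies, {}
            return popped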
+from splitio.storage import TelemetryStorage
+from splitio.api.telemetry import TelemetryAPI
+from splitio.tasks.telemetry_sync import TelemetrySynchronizationTask
+
+
+class TelemetrySyncTests(object): #pylint: disable=too-few-public-methods
+    """Telemetry synchronization task test cases."""
+
+    def test_normal_operation(self, mocker):
+        """Test normal behaviour of sync task."""
+        api = mocker.Mock(spec=TelemetryAPI)
+        storage = mocker.Mock(spec=TelemetryStorage)
+        storage.pop_latencies.return_value = {
+            'some_latency1': [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+            'some_latency2': [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        }
+        storage.pop_gauges.return_value = {
+            'gauge1': 123,
+            'gauge2': 456
+        }
+        storage.pop_counters.return_value = {
+            'counter1': 1,
+            'counter2': 5
+        }
+        task = TelemetrySynchronizationTask(api, storage, 1)
+        task.start()
+        time.sleep(2)
+        assert task.is_running()
+
+        stop_event = threading.Event()
+        task.stop(stop_event)
+        stop_event.wait()
+
+        assert stop_event.is_set()
+        assert not task.is_running()
+        assert mocker.call() in storage.pop_latencies.mock_calls
+        assert mocker.call() in storage.pop_counters.mock_calls
+        assert mocker.call() in storage.pop_gauges.mock_calls
+
+        assert mocker.call({
+            'some_latency1': [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+            'some_latency2': [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        }) in api.flush_latencies.mock_calls
+
+        assert mocker.call({
+            'gauge1': 123,
+            'gauge2': 456
+        }) in api.flush_gauges.mock_calls
+
+        assert mocker.call({
+            'counter1': 1,
+            'counter2': 5
+        }) in api.flush_counters.mock_calls
diff --git a/tests/tasks/util/test_asynctask.py b/tests/tasks/util/test_asynctask.py
new file mode 100644
index 00000000..a22b4b45
--- /dev/null
+++ b/tests/tasks/util/test_asynctask.py
@@ -0,0 +1,118 @@
+"""Asynctask test module."""
+
+import time
+import threading
+from splitio.tasks.util import asynctask
+
+
+class AsyncTaskTests(object):
+    """AsyncTask test cases."""
+
+    def test_default_task_flow(self, mocker):
+        """Test the default execution flow of an asynctask."""
+        main_func = mocker.Mock()
+        on_init = mocker.Mock()
+        on_stop = mocker.Mock()
+        on_stop_event = threading.Event()
+
+        task = asynctask.AsyncTask(main_func, 0.5, on_init, on_stop)
+        task.start()
+        time.sleep(1)
+        assert task.running()
+        task.stop(on_stop_event)
+        on_stop_event.wait()
+
+        assert on_stop_event.is_set()
+        assert 0 < len(main_func.mock_calls) <= 2
+        assert len(on_init.mock_calls) == 1
+        assert len(on_stop.mock_calls) == 1
+        assert not task.running()
+
+    def test_main_exception_skips_iteration(self, mocker):
+        """Test that an exception in the main func only skips current iteration."""
+        def raise_exception():
+            raise Exception('something')
+        main_func = mocker.Mock()
+        main_func.side_effect = raise_exception
+        on_init = mocker.Mock()
+        on_stop = mocker.Mock()
+        on_stop_event = threading.Event()
+
+        task = asynctask.AsyncTask(main_func, 0.1, on_init, on_stop)
+        task.start()
+        time.sleep(1)
+        assert task.running()
+        task.stop(on_stop_event)
+        on_stop_event.wait()
+
+        assert on_stop_event.is_set()
+        assert 9 <= len(main_func.mock_calls) <= 10
+        assert len(on_init.mock_calls) == 1
+        assert len(on_stop.mock_calls) == 1
+        assert not task.running()
+
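Taken together, these cases pin down the task lifecycle: on_init runs once and aborts the loop if it raises; main_func runs every `period` seconds and a failure only skips that iteration; on_stop always runs; and the stop event is set even when the task already died. A compact thread-based skeleton consistent with them — purely illustrative, and omitting force_execution (covered below):

    import threading

    class SketchAsyncTask(object):
        def __init__(self, main, period, on_init=None, on_stop=None):
            self._main = main
            self._period = period
            self._on_init = on_init
            self._on_stop = on_stop
            self._stop_requested = threading.Event()
            self._done_event = None
            self._running = False
            self._thread = None

        def _loop(self):
            try:
                if self._on_init is not None:
                    self._on_init()  # an exception here aborts before main ever runs
            except Exception:  # pylint: disable=broad-except
                pass
            else:
                self._running = True
                while not self._stop_requested.wait(self._period):
                    try:
                        self._main()
                    except Exception:  # pylint: disable=broad-except
                        pass  # a failing iteration is skipped, not fatal
            self._running = False
            try:
                if self._on_stop is not None:
                    self._on_stop()  # runs even after an on_init failure
            except Exception:  # pylint: disable=broad-except
                pass  # still end gracefully
            if self._done_event is not None:
                self._done_event.set()

        def start(self):
            self._thread = threading.Thread(target=self._loop)
            self._thread.daemon = True
            self._thread.start()

        def running(self):
            return self._running

        def stop(self, event=None):
            if self._thread is None or not self._thread.is_alive():
                if event is not None:
                    event.set()  # already stopped: signal immediately
                return
            self._done_event = event
            self._stop_requested.set()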
+    def test_on_init_failure_aborts_task(self, mocker):
+        """Test that if the on_init callback fails, the task never runs."""
+        def raise_exception():
+            raise Exception('something')
+        main_func = mocker.Mock()
+        on_init = mocker.Mock()
+        on_init.side_effect = raise_exception
+        on_stop = mocker.Mock()
+        on_stop_event = threading.Event()
+
+        task = asynctask.AsyncTask(main_func, 0.1, on_init, on_stop)
+        task.start()
+        time.sleep(0.5)
+        assert not task.running()  # Since on_init fails, task never starts
+        task.stop(on_stop_event)
+        on_stop_event.wait(1)
+
+        assert on_stop_event.is_set()
+        assert on_init.mock_calls == [mocker.call()]
+        assert on_stop.mock_calls == [mocker.call()]
+        assert main_func.mock_calls == []
+        assert not task.running()
+
+    def test_on_stop_failure_ends_gracefully(self, mocker):
+        """Test that if the on_stop callback fails, the task still ends gracefully."""
+        def raise_exception():
+            raise Exception('something')
+        main_func = mocker.Mock()
+        on_init = mocker.Mock()
+        on_stop = mocker.Mock()
+        on_stop.side_effect = raise_exception
+        on_stop_event = threading.Event()
+
+        task = asynctask.AsyncTask(main_func, 0.1, on_init, on_stop)
+        task.start()
+        time.sleep(1)
+        task.stop(on_stop_event)
+        on_stop_event.wait(1)
+
+        assert on_stop_event.is_set()
+        assert on_init.mock_calls == [mocker.call()]
+        assert on_stop.mock_calls == [mocker.call()]
+        assert 9 <= len(main_func.mock_calls) <= 10
+
+    def test_force_run(self, mocker):
+        """Test that force_execution triggers an immediate run of the main function."""
+        main_func = mocker.Mock()
+        on_init = mocker.Mock()
+        on_stop = mocker.Mock()
+        on_stop_event = threading.Event()
+
+        task = asynctask.AsyncTask(main_func, 5, on_init, on_stop)
+        task.start()
+        time.sleep(1)
+        assert task.running()
+        task.force_execution()
+        task.force_execution()
+        task.stop(on_stop_event)
+        on_stop_event.wait(1)
+
+        assert on_stop_event.is_set()
+        assert on_init.mock_calls == [mocker.call()]
+        assert on_stop.mock_calls == [mocker.call()]
+        assert len(main_func.mock_calls) == 2
+        assert not task.running()
diff --git a/tests/tasks/util/test_workerpool.py b/tests/tasks/util/test_workerpool.py
new file mode 100644
index 00000000..c7811ad0
--- /dev/null
+++ b/tests/tasks/util/test_workerpool.py
@@ -0,0 +1,53 @@
+"""Workerpool test module."""
+import time
+import threading
+from splitio.tasks.util import workerpool
+
+
+class WorkerPoolTests(object):
+    """Worker pool test cases."""
+
+    def test_normal_operation(self, mocker):
+        """Test normal operation works properly."""
+        worker_func = mocker.Mock()
+        wp = workerpool.WorkerPool(10, worker_func)
+        wp.start()
+        for x in range(0, 100):
+            wp.submit_work(str(x))
+
+        stop_event = threading.Event()
+        wp.stop(stop_event)
+        stop_event.wait(5)
+        assert stop_event.is_set()
+
+        calls = worker_func.mock_calls
+        for x in range(0, 100):
+            assert mocker.call(str(x)) in calls
+
+    def test_failure_in_message_doesnt_break(self, mocker):
+        """Test that if processing a message fails, it is ignored and others are processed."""
+        class Worker(object):
+            def __init__(self):
+                self._worked = set()
+
+            def do_work(self, w):
+                if w == '55':
+                    raise Exception('something')
+                self._worked.add(w)
+
+        worker = Worker()
+        wp = workerpool.WorkerPool(50, worker.do_work)
+        wp.start()
+        for x in range(0, 100):
+            wp.submit_work(str(x))
+
+        stop_event = threading.Event()
+        wp.stop(stop_event)
+        stop_event.wait(5)
+        assert stop_event.is_set()
+
+        for x in range(0, 100):
+            if x != 55:
+                assert str(x) in worker._worked
+            else:
+                assert str(x) not in worker._worked

From 5b8b7abd59a17272a211924ce72b32102ea7836f Mon Sep 17 00:00:00 2001
From: Martin Redolatti
Date: Fri, 12 Apr 2019 09:18:40 -0300
Subject: [PATCH 02/38] add configurations to split model

---
 splitio/models/splits.py | 15 ++++++++++++---
tests/models/test_splits.py | 7 ++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/splitio/models/splits.py b/splitio/models/splits.py index 170fe285..8a652fab 100644 --- a/splitio/models/splits.py +++ b/splitio/models/splits.py @@ -43,7 +43,8 @@ def __init__( #pylint: disable=too-many-arguments conditions=None, algo=None, traffic_allocation=None, - traffic_allocation_seed=None + traffic_allocation_seed=None, + configurations=None ): """ Class constructor. @@ -91,6 +92,8 @@ def __init__( #pylint: disable=too-many-arguments except ValueError: self._algo = HashAlgorithm.LEGACY + self._configurations = configurations + @property def name(self): """Return name.""" @@ -146,6 +149,10 @@ def traffic_allocation_seed(self): """Return the traffic allocation seed of the split.""" return self._traffic_allocation_seed + def get_configurations_for(self, treatment): + """Return the mapping of treatments to configurations.""" + return self._configurations.get(treatment) if self._configurations else None + def get_segment_names(self): """ Return a list of segment names referenced in all matchers from this split. @@ -168,7 +175,8 @@ def to_json(self): 'killed': self.killed, 'defaultTreatment': self.default_treatment, 'algo': self.algo.value, - 'conditions': [c.to_json() for c in self.conditions] + 'conditions': [c.to_json() for c in self.conditions], + 'configurations': self._configurations } def to_split_view(self): @@ -219,5 +227,6 @@ def from_raw(raw_split): [condition.from_raw(c) for c in raw_split['conditions']], raw_split.get('algo'), traffic_allocation=raw_split.get('trafficAllocation'), - traffic_allocation_seed=raw_split.get('trafficAllocationSeed') + traffic_allocation_seed=raw_split.get('trafficAllocationSeed'), + configurations=raw_split.get('configurations') ) diff --git a/tests/models/test_splits.py b/tests/models/test_splits.py index 1f08bf0b..847448b0 100644 --- a/tests/models/test_splits.py +++ b/tests/models/test_splits.py @@ -56,7 +56,10 @@ class SplitTests(object): 'combiner': 'AND' } } - ] + ], + 'configurations': { + 'on': '{"color": "blue", "size": 13}' + }, } def test_from_raw(self): @@ -74,6 +77,8 @@ def test_from_raw(self): assert parsed.default_treatment == 'off' assert parsed.algo == splits.HashAlgorithm.MURMUR assert len(parsed.conditions) == 2 + assert parsed.get_configurations_for('on') == '{"color": "blue", "size": 13}' + assert parsed._configurations == {'on': '{"color": "blue", "size": 13}'} def test_get_segment_names(self, mocker): """Test fetching segment names.""" From 94a941b66081d5d77f9f85a7d4a85ff2e33e488b Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Fri, 12 Apr 2019 09:30:31 -0300 Subject: [PATCH 03/38] return configurations from evaluator --- splitio/engine/evaluator.py | 1 + tests/engine/test_evaluator.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/splitio/engine/evaluator.py b/splitio/engine/evaluator.py index 881ffcac..0abcafa9 100644 --- a/splitio/engine/evaluator.py +++ b/splitio/engine/evaluator.py @@ -76,6 +76,7 @@ def evaluate_treatment(self, feature, matching_key, return { 'treatment': _treatment, + 'configurations': split.get_configurations_for(_treatment) if split else None, 'impression': { 'label': label, 'change_number': _change_number diff --git a/tests/engine/test_evaluator.py b/tests/engine/test_evaluator.py index 4c5e936b..0982be42 100644 --- a/tests/engine/test_evaluator.py +++ b/tests/engine/test_evaluator.py @@ -25,6 +25,7 @@ def test_evaluate_treatment_missing_split(self, 
mocker): e = self._build_evaluator_with_mocks(mocker) e._split_storage.get.return_value = None result = e.evaluate_treatment('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) + assert result['configurations'] == None assert result['treatment'] == evaluator.CONTROL assert result['impression']['change_number'] == -1 assert result['impression']['label'] == Label.SPLIT_NOT_FOUND @@ -36,11 +37,14 @@ def test_evaluate_treatment_killed_split(self, mocker): mocked_split.default_treatment = 'off' mocked_split.killed = True mocked_split.change_number = 123 + mocked_split.get_configurations_for.return_value = '{"some_property": 123}' e._split_storage.get.return_value = mocked_split result = e.evaluate_treatment('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) assert result['treatment'] == 'off' + assert result['configurations'] == '{"some_property": 123}' assert result['impression']['change_number'] == 123 assert result['impression']['label'] == Label.KILLED + assert mocked_split.get_configurations_for.mock_calls == [mocker.call('off')] def test_evaluate_treatment_ok(self, mocker): """Test that a non-killed split returns the appropriate treatment.""" @@ -51,13 +55,16 @@ def test_evaluate_treatment_ok(self, mocker): mocked_split.default_treatment = 'off' mocked_split.killed = False mocked_split.change_number = 123 + mocked_split.get_configurations_for.return_value = '{"some_property": 123}' e._split_storage.get.return_value = mocked_split result = e.evaluate_treatment('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) assert result['treatment'] == 'on' + assert result['configurations'] == '{"some_property": 123}' assert result['impression']['change_number'] == 123 assert result['impression']['label'] == 'some_label' + assert mocked_split.get_configurations_for.mock_calls == [mocker.call('on')] - def test_evaluate_treatment_ok(self, mocker): + def test_evaluate_treatment_ok_no_config(self, mocker): """Test that a killed split returns the default treatment.""" e = self._build_evaluator_with_mocks(mocker) e._get_treatment_for_split = mocker.Mock() @@ -66,11 +73,14 @@ def test_evaluate_treatment_ok(self, mocker): mocked_split.default_treatment = 'off' mocked_split.killed = False mocked_split.change_number = 123 + mocked_split.get_configurations_for.return_value = None e._split_storage.get.return_value = mocked_split result = e.evaluate_treatment('feature1', 'some_key', 'some_bucketing_key', {'attr1': 1}) assert result['treatment'] == 'on' + assert result['configurations'] == None assert result['impression']['change_number'] == 123 assert result['impression']['label'] == 'some_label' + assert mocked_split.get_configurations_for.mock_calls == [mocker.call('on')] def test_get_gtreatment_for_split_no_condition_matches(self, mocker): """Test no condition matches.""" @@ -102,7 +112,7 @@ def test_get_gtreatment_for_split_non_rollout(self, mocker): assert treatment == 'on' assert label == 'some_label' - def test_get_gtreatment_for_split_rollout(self, mocker): + def test_get_treatment_for_split_rollout(self, mocker): """Test rollout condition returns default treatment.""" e = self._build_evaluator_with_mocks(mocker) e._splitter.get_bucket.return_value = 60 From 66f5d295d64fe8e9e72e1efd6477630a0ce7c70e Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Fri, 12 Apr 2019 15:39:55 -0300 Subject: [PATCH 04/38] update clients & input validator to work with new methods *_with_config --- splitio/client/client.py | 67 +++-- splitio/client/input_validator.py | 47 ++- splitio/client/util.py | 
22 ++ tests/client/test_client.py | 133 ++++++++ tests/client/test_input_validator.py | 433 +++++++++++++++++++++++---- 5 files changed, 605 insertions(+), 97 deletions(-) diff --git a/splitio/client/client.py b/splitio/client/client.py index ecee50a3..54aa2307 100644 --- a/splitio/client/client.py +++ b/splitio/client/client.py @@ -4,6 +4,7 @@ import logging import time +import six from splitio.engine.evaluator import Evaluator, CONTROL from splitio.engine.splitters import Splitter from splitio.models.impressions import Impression, Label @@ -79,36 +80,21 @@ def _send_impression_to_listener(self, impression, attributes): ) self._logger.debug('Error', exc_info=True) - def get_treatment(self, key, feature, attributes=None): - """ - Get the treatment for a feature and key, with an optional dictionary of attributes. - - This method never raises an exception. If there's a problem, the appropriate log message - will be generated and the method will return the CONTROL treatment. - - :param key: The key for which to get the treatment - :type key: str - :param feature: The name of the feature for which to get the treatment - :type feature: str - :param attributes: An optional dictionary of attributes - :type attributes: dict - :return: The treatment for the key and feature - :rtype: str - """ + def get_treatment_with_config(self, key, feature, attributes=None): try: if self.destroyed: self._logger.error("Client has already been destroyed - no calls possible") - return CONTROL + return CONTROL, None start = int(round(time.time() * 1000)) - matching_key, bucketing_key = input_validator.validate_key(key, 'get_treatment') + matching_key, bucketing_key = input_validator.validate_key(key) feature = input_validator.validate_feature_name(feature) if (matching_key is None and bucketing_key is None) \ or feature is None \ - or not input_validator.validate_attributes(attributes, 'get_treatment'): - return CONTROL + or not input_validator.validate_attributes(attributes): + return CONTROL, None result = self._evaluator.evaluate_treatment( feature, @@ -129,7 +115,7 @@ def get_treatment(self, key, feature, attributes=None): self._record_stats(impression, start, self._METRIC_GET_TREATMENT) self._send_impression_to_listener(impression, attributes) - return result['treatment'] + return result['treatment'], result['configurations'] except Exception: #pylint: disable=broad-except self._logger.error('Error getting treatment for feature') self._logger.debug('Error: ', exc_info=True) @@ -148,9 +134,28 @@ def get_treatment(self, key, feature, attributes=None): except Exception: # pylint: disable=broad-except self._logger.error('Error reporting impression into get_treatment exception block') self._logger.debug('Error: ', exc_info=True) - return CONTROL + return CONTROL, None - def get_treatments(self, key, features, attributes=None): + def get_treatment(self, key, feature, attributes=None): + """ + Get the treatment for a feature and key, with an optional dictionary of attributes. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. 
+
+        :param key: The key for which to get the treatment
+        :type key: str
+        :param feature: The name of the feature for which to get the treatment
+        :type feature: str
+        :param attributes: An optional dictionary of attributes
+        :type attributes: dict
+        :return: The treatment for the key and feature
+        :rtype: str
+        """
+        treatment, _ = self.get_treatment_with_config(key, feature, attributes)
+        return treatment
+
+    def get_treatments_with_config(self, key, features, attributes=None):
         """
         Evaluate multiple features and return a dictionary with all the feature/treatments.
@@ -172,11 +177,11 @@
             start = int(round(time.time() * 1000))
 
-            matching_key, bucketing_key = input_validator.validate_key(key, 'get_treatments')
+            matching_key, bucketing_key = input_validator.validate_key(key)
             if matching_key is None and bucketing_key is None:
                 return input_validator.generate_control_treatments(features)
-            if input_validator.validate_attributes(attributes, 'get_treatments') is False:
+            if input_validator.validate_attributes(attributes) is False:
                 return input_validator.generate_control_treatments(features)
 
             features = input_validator.validate_features_get_treatments(features)
@@ -204,13 +209,13 @@
                     start)
                 bulk_impressions.append(impression)
-                treatments[feature] = treatment['treatment']
+                treatments[feature] = (treatment['treatment'], treatment['configurations'])
 
             except Exception: #pylint: disable=broad-except
                 self._logger.error('get_treatments: An exception occured when evaluating '
                                    'feature ' + feature + ' returning CONTROL.')
-                treatments[feature] = CONTROL
+                treatments[feature] = (CONTROL, None)
                 self._logger.debug('Error: ', exc_info=True)
                 continue
 
         # Register impressions
@@ -226,6 +233,12 @@
         return treatments
 
+
+    def get_treatments(self, key, features, attributes=None):
+        """Evaluate multiple features and return a dictionary of feature names to treatments."""
+        with_config = self.get_treatments_with_config(key, features, attributes)
+        return {feature: result[0] for (feature, result) in six.iteritems(with_config)}
+
     def _build_impression( #pylint: disable=too-many-arguments
         self,
         matching_key,
diff --git a/splitio/client/input_validator.py b/splitio/client/input_validator.py
index b692c011..2c06f0f3 100644
--- a/splitio/client/input_validator.py
+++ b/splitio/client/input_validator.py
@@ -8,6 +8,7 @@
 import math
 import requests
 from splitio.client.key import Key
+from splitio.client.util import get_calls
 from splitio.engine.evaluator import CONTROL
 # from splitio.api import SdkApi
 # from splitio.exceptions import NetworkingException
@@ -18,6 +19,22 @@
 EVENT_TYPE_PATTERN = r'^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$'
 
+def _get_first_split_sdk_call():
+    """
+    Get the method name of the original call on the SplitClient methods.
+
+    :return: Name of the method called by the user.
+ :rtype: str + """ + unknown_method = 'unknown-method' + try: + calls = get_calls(['Client', 'SplitManager']) + if calls: + return calls[-1] + return unknown_method + except Exception: #pylint: disable=broad-except + return unknown_method + def _check_not_null(value, name, operation): """ Checks if value is null @@ -194,7 +211,7 @@ def _remove_empty_spaces(value, operation): return strip_value -def validate_key(key, operation): +def validate_key(key): """ Validate Key parameter for get_treatment/s, if is invalid at some point the bucketing_key or matching_key it will return None @@ -206,6 +223,7 @@ def validate_key(key, operation): :return: The tuple key :rtype: (matching_key,bucketing_key) """ + operation = _get_first_split_sdk_call() matching_key_result = None bucketing_key_result = None if key is None: @@ -239,11 +257,12 @@ def validate_feature_name(feature_name): :return: feature_name :rtype: str|None """ - if (not _check_not_null(feature_name, 'feature_name', 'get_treatment')) or \ - (not _check_is_string(feature_name, 'feature_name', 'get_treatment')) or \ - (not _check_string_not_empty(feature_name, 'feature_name', 'get_treatment')): + operation = _get_first_split_sdk_call() + if (not _check_not_null(feature_name, 'feature_name', operation)) or \ + (not _check_is_string(feature_name, 'feature_name', operation)) or \ + (not _check_string_not_empty(feature_name, 'feature_name', operation)): return None - return _remove_empty_spaces(feature_name, 'get_treatment') + return _remove_empty_spaces(feature_name, operation) def validate_track_key(key): @@ -344,19 +363,20 @@ def validate_features_get_treatments(features): :return: filtered_features :rtype: list|None """ + operation = _get_first_split_sdk_call() if features is None or not isinstance(features, list): - _LOGGER.error('get_treatments: feature_names must be a non-empty array.') + _LOGGER.error("%s: feature_names must be a non-empty array." % operation) return None if len(features) == 0: - _LOGGER.error('get_treatments: feature_names must be a non-empty array.') + _LOGGER.error("%s: feature_names must be a non-empty array." % operation) return [] - filtered_features = set(_remove_empty_spaces(feature, 'get_treatments') for feature in features + filtered_features = set(_remove_empty_spaces(feature, operation) for feature in features if feature is not None and - _check_is_string(feature, 'feature_name', 'get_treatments') and - _check_string_not_empty(feature, 'feature_name', 'get_treatments') + _check_is_string(feature, 'feature_name', operation) and + _check_string_not_empty(feature, 'feature_name', operation) ) if len(filtered_features) == 0: - _LOGGER.error('get_treatments: feature_names must be a non-empty array.') + _LOGGER.error("%s: feature_names must be a non-empty array." 
% operation)
         return None
     return filtered_features
 
 
@@ -370,10 +390,10 @@ def generate_control_treatments(features):
     :return: dict
     :rtype: dict|None
     """
-    return {feature: CONTROL for feature in validate_features_get_treatments(features)}
+    return {feature: (CONTROL, None) for feature in validate_features_get_treatments(features)}
 
 
-def validate_attributes(attributes, operation):
+def validate_attributes(attributes):
     """
     Checks if attributes is valid
 
@@ -384,6 +404,7 @@ def validate_attributes(attributes, operation):
     :return: bool
     :rtype: True|False
     """
+    operation = _get_first_split_sdk_call()
     if attributes is None:
         return True
     if not type(attributes) is dict:
diff --git a/splitio/client/util.py b/splitio/client/util.py
index d92bdb20..d0465f86 100644
--- a/splitio/client/util.py
+++ b/splitio/client/util.py
@@ -1,5 +1,6 @@
 """General purpose SDK utilities."""
 
+import inspect
 import socket
 from collections import namedtuple
 from splitio.version import __version__
@@ -10,6 +11,8 @@
 )
 
 
+
+
 def _get_ip():
     sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -26,9 +29,28 @@ def _get_ip():
 def _get_hostname(ip_address):
     return 'unknown' if ip_address == 'unknown' else 'ip-' + ip_address.replace('.', '-')
 
+
 def get_metadata(*args, **kwargs):
     """Gather SDK metadata and return a tuple with such info."""
     version = 'python-%s' % __version__
     ip_address = _get_ip()
     hostname = _get_hostname(ip_address)
     return SdkMetadata(version, hostname, ip_address)
+
+
+def get_calls(classes_filter=None):
+    """
+    Inspect the stack and retrieve an ordered list of caller functions.
+
+    :param classes_filter: If not None, only methods belonging to these classes will be returned.
+    :type classes_filter: list(str)
+
+    :return: list of caller function names ordered by most recent first.
+    :rtype: list(str)
+    """
+    return [
+        inspect.getframeinfo(frame[0]).function
+        for frame in inspect.stack()
+        if classes_filter is None
+        or 'self' in frame[0].f_locals and frame[0].f_locals['self'].__class__.__name__ in classes_filter  #pylint: disable=line-too-long
+    ]
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 4486931e..0f0339ad 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -42,6 +42,7 @@ def _get_storage_mock(name):
         client._evaluator = mocker.Mock(spec=Evaluator)
         client._evaluator.evaluate_treatment.return_value = {
             'treatment': 'on',
+            'configurations': None,
             'impression': {
                 'label': 'some_label',
                 'change_number': 123
@@ -72,6 +73,70 @@ def _raise(*_):
         ) in impression_storage.put.mock_calls
         assert len(telemetry_storage.inc_latency.mock_calls) == 2
 
+    def test_get_treatment_with_config(self, mocker):
+        """Test get_treatment_with_config execution paths."""
+        split_storage = mocker.Mock(spec=SplitStorage)
+        segment_storage = mocker.Mock(spec=SegmentStorage)
+        impression_storage = mocker.Mock(spec=ImpressionStorage)
+        event_storage = mocker.Mock(spec=EventStorage)
+        telemetry_storage = mocker.Mock(spec=TelemetryStorage)
+        def _get_storage_mock(name):
+            return {
+                'splits': split_storage,
+                'segments': segment_storage,
+                'impressions': impression_storage,
+                'events': event_storage,
+                'telemetry': telemetry_storage
+            }[name]
+
+        destroyed_property = mocker.PropertyMock()
+        destroyed_property.return_value = False
+
+        factory = mocker.Mock(spec=SplitFactory)
+        factory._get_storage.side_effect = _get_storage_mock
+        type(factory).destroyed = destroyed_property
+
+        mocker.patch('splitio.client.client.time.time', new=lambda: 1)
+        mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x:
5)
+
+        client = Client(factory, True, None)
+        client._evaluator = mocker.Mock(spec=Evaluator)
+        client._evaluator.evaluate_treatment.return_value = {
+            'treatment': 'on',
+            'configurations': '{"some_config": True}',
+            'impression': {
+                'label': 'some_label',
+                'change_number': 123
+            }
+        }
+        client._logger = mocker.Mock()
+        client._send_impression_to_listener = mocker.Mock()
+
+        assert client.get_treatment_with_config(
+            'some_key',
+            'some_feature'
+        ) == ('on', '{"some_config": True}')
+        assert mocker.call(
+            [Impression('some_key', 'some_feature', 'on', 'some_label', 123, None, 1000)]
+        ) in impression_storage.put.mock_calls
+        assert mocker.call('sdk.getTreatment', 5) in telemetry_storage.inc_latency.mock_calls
+        assert client._logger.mock_calls == []
+        assert mocker.call(
+            Impression('some_key', 'some_feature', 'on', 'some_label', 123, None, 1000),
+            None
+        ) in client._send_impression_to_listener.mock_calls
+
+        # Test with exception:
+        split_storage.get_change_number.return_value = -1
+        def _raise(*_):
+            raise Exception('something')
+        client._evaluator.evaluate_treatment.side_effect = _raise
+        assert client.get_treatment_with_config('some_key', 'some_feature') == ('control', None)
+        assert mocker.call(
+            [Impression('some_key', 'some_feature', 'control', 'exception', -1, None, 1000)]
+        ) in impression_storage.put.mock_calls
+        assert len(telemetry_storage.inc_latency.mock_calls) == 2
+
     def test_get_treatments(self, mocker):
         """Test get_treatment execution paths."""
         split_storage = mocker.Mock(spec=SplitStorage)
@@ -102,6 +167,7 @@ def _get_storage_mock(name):
         client._evaluator = mocker.Mock(spec=Evaluator)
         client._evaluator.evaluate_treatment.return_value = {
             'treatment': 'on',
+            'configurations': '{"color": "red"}',
             'impression': {
                 'label': 'some_label',
                 'change_number': 123
@@ -133,6 +199,73 @@ def _raise(*_):
         assert client.get_treatments('key', ['f1', 'f2']) == {'f1': 'control', 'f2': 'control'}
         assert len(telemetry_storage.inc_latency.mock_calls) == 2
 
+    def test_get_treatments_with_config(self, mocker):
+        """Test get_treatments_with_config execution paths."""
+        split_storage = mocker.Mock(spec=SplitStorage)
+        segment_storage = mocker.Mock(spec=SegmentStorage)
+        impression_storage = mocker.Mock(spec=ImpressionStorage)
+        event_storage = mocker.Mock(spec=EventStorage)
+        telemetry_storage = mocker.Mock(spec=TelemetryStorage)
+        def _get_storage_mock(name):
+            return {
+                'splits': split_storage,
+                'segments': segment_storage,
+                'impressions': impression_storage,
+                'events': event_storage,
+                'telemetry': telemetry_storage
+            }[name]
+
+        destroyed_property = mocker.PropertyMock()
+        destroyed_property.return_value = False
+
+        factory = mocker.Mock(spec=SplitFactory)
+        factory._get_storage.side_effect = _get_storage_mock
+        type(factory).destroyed = destroyed_property
+
+        mocker.patch('splitio.client.client.time.time', new=lambda: 1)
+        mocker.patch('splitio.client.client.get_latency_bucket_index', new=lambda x: 5)
+
+        client = Client(factory, True, None)
+        client._evaluator = mocker.Mock(spec=Evaluator)
+        client._evaluator.evaluate_treatment.return_value = {
+            'treatment': 'on',
+            'configurations': '{"color": "red"}',
+            'impression': {
+                'label': 'some_label',
+                'change_number': 123
+            }
+        }
+        client._logger = mocker.Mock()
+        client._send_impression_to_listener = mocker.Mock()
+        assert client.get_treatments_with_config('key', ['f1', 'f2']) == {
+            'f1': ('on', '{"color": "red"}'),
+            'f2': ('on', '{"color": "red"}')
+        }
+
+        impressions_called = impression_storage.put.mock_calls[0][1][0]
+        assert
Impression('key', 'f1', 'on', 'some_label', 123, None, 1000) in impressions_called + assert Impression('key', 'f2', 'on', 'some_label', 123, None, 1000) in impressions_called + assert mocker.call('sdk.getTreatments', 5) in telemetry_storage.inc_latency.mock_calls + assert client._logger.mock_calls == [] + assert mocker.call( + Impression('key', 'f1', 'on', 'some_label', 123, None, 1000), + None + ) in client._send_impression_to_listener.mock_calls + assert mocker.call( + Impression('key', 'f2', 'on', 'some_label', 123, None, 1000), + None + ) in client._send_impression_to_listener.mock_calls + + # Test with exception: + split_storage.get_change_number.return_value = -1 + def _raise(*_): + raise Exception('something') + client._evaluator.evaluate_treatment.side_effect = _raise + assert client.get_treatments_with_config('key', ['f1', 'f2']) == { + 'f1': ('control', None), + 'f2': ('control', None) + } + assert len(telemetry_storage.inc_latency.mock_calls) == 2 def test_destroy(self, mocker): """Test that destroy/destroyed calls are forwarded to the factory.""" diff --git a/tests/client/test_input_validator.py b/tests/client/test_input_validator.py index fe64c2a8..21c04a16 100644 --- a/tests/client/test_input_validator.py +++ b/tests/client/test_input_validator.py @@ -1,4 +1,6 @@ """Unit tests for the input_validator module.""" +#pylint: disable=protected-access,too-many-statements,no-self-use,line-too-long + from __future__ import absolute_import, division, print_function, \ unicode_literals @@ -8,10 +10,7 @@ from splitio.client.key import Key from splitio.storage import SplitStorage, EventStorage, ImpressionStorage, TelemetryStorage, \ SegmentStorage -from splitio.models.splits import Split, SplitView -from splitio.models.grammar.condition import Condition -from splitio.models.grammar.partitions import Partition -from splitio.client import input_validator +from splitio.models.splits import Split class ClientInputValidationTests(object): @@ -59,7 +58,7 @@ def _get_storage_mock(storage): ] client._logger.reset_mock() - key = ''.join('a' for _ in range(0,255)) + key = ''.join('a' for _ in range(0, 255)) assert client.get_treatment(key, 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ mocker.call('get_treatment: key too long - must be 250 characters or less.') @@ -108,72 +107,72 @@ def _get_storage_mock(storage): ] client._logger.reset_mock() - assert client.get_treatment('some_key', True) == CONTROL + assert client.get_treatment('some_key', True) == CONTROL assert client._logger.error.mock_calls == [ mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment('some_key', []) == CONTROL + assert client.get_treatment('some_key', []) == CONTROL assert client._logger.error.mock_calls == [ mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment('some_key', '') == CONTROL + assert client.get_treatment('some_key', '') == CONTROL assert client._logger.error.mock_calls == [ mocker.call('get_treatment: you passed an empty feature_name, feature_name must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment('some_key', 'some_feature') == 'default_treatment' + assert client.get_treatment('some_key', 'some_feature') == 'default_treatment' assert client._logger.error.mock_calls == [] assert client._logger.warning.mock_calls == [] 
client._logger.reset_mock() - assert client.get_treatment(Key(None, 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(None, 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ mocker.call('get_treatment: you passed a null matching_key, matching_key must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment(Key('', 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key('', 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ mocker.call('get_treatment: you passed an empty matching_key, matching_key must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment(Key(float('nan'), 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(float('nan'), 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment(Key(float('inf'), 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(float('inf'), 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment(Key(True, 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(True, 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment(Key([], 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key([], 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') ] client._logger.reset_mock() - assert client.get_treatment(Key(12345, 'bucketing_key'), 'some_feature') == 'default_treatment' + assert client.get_treatment(Key(12345, 'bucketing_key'), 'some_feature') == 'default_treatment' assert client._logger.warning.mock_calls == [ - mocker.call('get_treatment: matching_key 12345 is not of type string, ' 'converting.') + mocker.call('get_treatment: matching_key 12345 is not of type string, ' 'converting.') ] client._logger.reset_mock() - key = ''.join('a' for _ in range(0,255)) + key = ''.join('a' for _ in range(0, 255)) assert client.get_treatment(Key(key, 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ mocker.call('get_treatment: matching_key too long - must be 250 characters or less.') @@ -216,19 +215,234 @@ def _get_storage_mock(storage): ] client._logger.reset_mock() - assert client.get_treatment('mathcing_key', 'some_feature', {'test': 'test'}) =='default_treatment' + assert 
client.get_treatment('mathcing_key', 'some_feature', {'test': 'test'}) == 'default_treatment' assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.get_treatment('mathcing_key', 'some_feature', None) =='default_treatment' + assert client.get_treatment('mathcing_key', 'some_feature', None) == 'default_treatment' assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.get_treatment('mathcing_key', ' some_feature ', None) =='default_treatment' + assert client.get_treatment('mathcing_key', ' some_feature ', None) == 'default_treatment' assert client._logger.warning.mock_calls == [ mocker.call('get_treatment: feature_name \' some_feature \' has extra whitespace, trimming.') ] + def test_get_treatment_with_config(self, mocker): + """Test get_treatment validation.""" + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + storage_mock = mocker.Mock(spec=SplitStorage) + storage_mock.get.return_value = split_mock + + def _get_storage_mock(storage): + return { + 'splits': storage_mock, + 'segments': mocker.Mock(spec=SegmentStorage), + 'impressions': mocker.Mock(spec=ImpressionStorage), + 'events': mocker.Mock(spec=EventStorage), + 'telemetry': mocker.Mock(spec=TelemetryStorage) + }[storage] + factory_mock = mocker.Mock(spec=SplitFactory) + factory_mock._get_storage.side_effect = _get_storage_mock + factory_destroyed = mocker.PropertyMock() + factory_destroyed.return_value = False + type(factory_mock).destroyed = factory_destroyed + + client = Client(factory_mock) + client._logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=client._logger) + + assert client.get_treatment_with_config(None, 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed a null key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('', 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an empty key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + key = ''.join('a' for _ in range(0, 255)) + assert client.get_treatment_with_config(key, 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: key too long - must be 250 characters or less.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(12345, 'some_feature') == ('default_treatment', '{"some": "property"}') + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment_with_config: key 12345 is not of type string, converting.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(float('nan'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(float('inf'), 
'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(True, 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config([], 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('some_key', None) == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed a null feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('some_key', 123) == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('some_key', True) == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('some_key', []) == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('some_key', '') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an empty feature_name, feature_name must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('some_key', 'some_feature') == ('default_treatment', '{"some": "property"}') + assert client._logger.error.mock_calls == [] + assert client._logger.warning.mock_calls == [] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key(None, 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed a null matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key('', 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an empty matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key(float('nan'), 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key(float('inf'), 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid 
matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key(True, 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key([], 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid matching_key, matching_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key(12345, 'bucketing_key'), 'some_feature') == ('default_treatment', '{"some": "property"}') + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment_with_config: matching_key 12345 is not of type string, ' 'converting.') + ] + + client._logger.reset_mock() + key = ''.join('a' for _ in range(0, 255)) + assert client.get_treatment_with_config(Key(key, 'bucketing_key'), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: matching_key too long - must be 250 characters or less.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key('mathcing_key', None), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed a null bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key('mathcing_key', True), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key('mathcing_key', []), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an invalid bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key('mathcing_key', ''), 'some_feature') == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: you passed an empty bucketing_key, bucketing_key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config(Key('mathcing_key', 12345), 'some_feature') == ('default_treatment', '{"some": "property"}') + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment_with_config: bucketing_key 12345 is not of type string, converting.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('mathcing_key', 'some_feature', True) == (CONTROL, None) + assert client._logger.error.mock_calls == [ + mocker.call('get_treatment_with_config: attributes must be of type dictionary.') + ] + + client._logger.reset_mock() + assert client.get_treatment_with_config('mathcing_key', 'some_feature', {'test': 'test'}) == ('default_treatment', '{"some": "property"}') + assert client._logger.error.mock_calls == [] + + client._logger.reset_mock() + assert client.get_treatment_with_config('mathcing_key', 'some_feature', None) == ('default_treatment', '{"some": "property"}') + assert 
client._logger.error.mock_calls == [] + + client._logger.reset_mock() + assert client.get_treatment_with_config('mathcing_key', ' some_feature ', None) == ('default_treatment', '{"some": "property"}') + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatment_with_config: feature_name \' some_feature \' has extra whitespace, trimming.') + ] + def test_track(self, mocker): """Test track method().""" events_storage_mock = mocker.Mock(spec=EventStorage) @@ -244,109 +458,109 @@ def test_track(self, mocker): client._logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=client._logger) - assert client.track(None, "traffic_type", "event_type", 1) == False + assert client.track(None, "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed a null key, key must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("", "traffic_type", "event_type", 1) == False + assert client.track("", "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an empty key, key must be a non-empty string.") ] client._logger.reset_mock() - assert client.track(12345, "traffic_type", "event_type", 1) == True + assert client.track(12345, "traffic_type", "event_type", 1) is True assert client._logger.warning.mock_calls == [ mocker.call("track: key 12345 is not of type string, converting.") ] client._logger.reset_mock() - assert client.track(True, "traffic_type", "event_type", 1) == False + assert client.track(True, "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an invalid key, key must be a non-empty string.") ] client._logger.reset_mock() - assert client.track([], "traffic_type", "event_type", 1) == False + assert client.track([], "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an invalid key, key must be a non-empty string.") ] client._logger.reset_mock() - key = ''.join('a' for _ in range(0,255)) - assert client.track(key, "traffic_type", "event_type", 1) == False + key = ''.join('a' for _ in range(0, 255)) + assert client.track(key, "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: key too long - must be 250 characters or less.") ] client._logger.reset_mock() - assert client.track("some_key", None, "event_type", 1) == False + assert client.track("some_key", None, "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed a null traffic_type, traffic_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", "", "event_type", 1) == False + assert client.track("some_key", "", "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an empty traffic_type, traffic_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", 12345, "event_type", 1) == False + assert client.track("some_key", 12345, "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", True, "event_type", 1) == False + assert client.track("some_key", True, "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: 
you passed an invalid traffic_type, traffic_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", [], "event_type", 1) == False + assert client.track("some_key", [], "event_type", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", "TRAFFIC_type", "event_type", 1) == True + assert client.track("some_key", "TRAFFIC_type", "event_type", 1) is True assert client._logger.warning.mock_calls == [ mocker.call("track: TRAFFIC_type should be all lowercase - converting string to lowercase.") ] - assert client.track("some_key", "traffic_type", None, 1) == False + assert client.track("some_key", "traffic_type", None, 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed a null event_type, event_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "", 1) == False + assert client.track("some_key", "traffic_type", "", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an empty event_type, event_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", True, 1) == False + assert client.track("some_key", "traffic_type", True, 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", [], 1) == False + assert client.track("some_key", "traffic_type", [], 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", 12345, 1) == False + assert client.track("some_key", "traffic_type", 12345, 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "@@", 1) == False + assert client.track("some_key", "traffic_type", "@@", 1) is False assert client._logger.error.mock_calls == [ mocker.call("track: you passed @@, event_type must adhere to the regular " "expression ^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$. 
This means " @@ -356,31 +570,31 @@ def test_track(self, mocker): ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", None) == True + assert client.track("some_key", "traffic_type", "event_type", None) is True assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", 1) == True + assert client.track("some_key", "traffic_type", "event_type", 1) is True assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", 1.23) == True + assert client.track("some_key", "traffic_type", "event_type", 1.23) is True assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", "test") == False + assert client.track("some_key", "traffic_type", "event_type", "test") is False assert client._logger.error.mock_calls == [ mocker.call("track: value must be a number.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", True) == False + assert client.track("some_key", "traffic_type", "event_type", True) is False assert client._logger.error.mock_calls == [ mocker.call("track: value must be a number.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", []) == False + assert client.track("some_key", "traffic_type", "event_type", []) is False assert client._logger.error.mock_calls == [ mocker.call("track: value must be a number.") ] @@ -419,7 +633,7 @@ def test_get_treatments(self, mocker): mocker.call('get_treatments: you passed an empty key, key must be a non-empty string.') ] - key = ''.join('a' for _ in range(0,255)) + key = ''.join('a' for _ in range(0, 255)) client._logger.reset_mock() assert client.get_treatments(key, ['some_feature']) == {'some_feature': CONTROL} assert client._logger.error.mock_calls == [ @@ -488,8 +702,114 @@ def test_get_treatments(self, mocker): mocker.call('get_treatments: feature_name \'some \' has extra whitespace, trimming.') ] + def test_get_treatments_with_config(self, mocker): + """Test getTreatments() method.""" + split_mock = mocker.Mock(spec=Split) + default_treatment_mock = mocker.PropertyMock() + default_treatment_mock.return_value = 'default_treatment' + type(split_mock).default_treatment = default_treatment_mock + conditions_mock = mocker.PropertyMock() + conditions_mock.return_value = [] + type(split_mock).conditions = conditions_mock + + storage_mock = mocker.Mock(spec=SplitStorage) + storage_mock.get.return_value = split_mock + + factory_mock = mocker.Mock(spec=SplitFactory) + factory_mock._get_storage.return_value = storage_mock + factory_destroyed = mocker.PropertyMock() + factory_destroyed.return_value = False + type(factory_mock).destroyed = factory_destroyed + def _configs(treatment): + return '{"some": "property"}' if treatment == 'default_treatment' else None + split_mock.get_configurations_for.side_effect = _configs + + client = Client(factory_mock) + client._logger = mocker.Mock() + mocker.patch('splitio.client.input_validator._LOGGER', new=client._logger) + + assert client.get_treatments_with_config(None, ['some_feature']) == {'some_feature': (CONTROL, None)} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: you passed a null key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config("", ['some_feature']) == 
{'some_feature': (CONTROL, None)} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: you passed an empty key, key must be a non-empty string.') + ] + + key = ''.join('a' for _ in range(0, 255)) + client._logger.reset_mock() + assert client.get_treatments_with_config(key, ['some_feature']) == {'some_feature': (CONTROL, None)} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: key too long - must be 250 characters or less.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config(12345, ['some_feature']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatments_with_config: key 12345 is not of type string, converting.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config(True, ['some_feature']) == {'some_feature': (CONTROL, None)} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config([], ['some_feature']) == {'some_feature': (CONTROL, None)} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: you passed an invalid key, key must be a non-empty string.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', None) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', True) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', 'some_string') == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', []) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', [None, None]) == {} + assert client._logger.error.mock_calls == [ + mocker.call('get_treatments_with_config: feature_names must be a non-empty array.') + ] + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', [True]) == {} + assert mocker.call('get_treatments_with_config: feature_names must be a non-empty array.') in client._logger.error.mock_calls + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', ['', '']) == {} + assert mocker.call('get_treatments_with_config: feature_names must be a non-empty array.') in client._logger.error.mock_calls + + client._logger.reset_mock() + assert client.get_treatments_with_config('some_key', ['some_feature ']) == {'some_feature': ('default_treatment', '{"some": "property"}')} + assert client._logger.warning.mock_calls == [ + mocker.call('get_treatments_with_config: feature_name \'some_feature \' has extra whitespace, trimming.') + ] + -class ManagerInputValidationTests(object): +class ManagerInputValidationTests(object): #pylint: disable=too-few-public-methods """Manager input validation test cases.""" def 
test_split_(self, mocker): @@ -507,25 +827,25 @@ def test_split_(self, mocker): manager._logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=manager._logger) - assert manager.split(None) == None + assert manager.split(None) is None assert manager._logger.error.mock_calls == [ mocker.call("split: you passed a null feature_name, feature_name must be a non-empty string.") ] manager._logger.reset_mock() - assert manager.split("") == None + assert manager.split("") is None assert manager._logger.error.mock_calls == [ mocker.call("split: you passed an empty feature_name, feature_name must be a non-empty string.") ] manager._logger.reset_mock() - assert manager.split(True) == None + assert manager.split(True) is None assert manager._logger.error.mock_calls == [ mocker.call("split: you passed an invalid feature_name, feature_name must be a non-empty string.") ] manager._logger.reset_mock() - assert manager.split([]) == None + assert manager.split([]) is None assert manager._logger.error.mock_calls == [ mocker.call("split: you passed an invalid feature_name, feature_name must be a non-empty string.") ] @@ -536,7 +856,6 @@ def test_split_(self, mocker): assert manager._logger.error.mock_calls == [] - #class TestInputSanitizationFactory(TestCase): # # def setUp(self): From 2fea2db9bbbf5c81ee7ae0e4a043af6fb034efef Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Mon, 15 Apr 2019 17:08:32 -0300 Subject: [PATCH 05/38] more tests & redis --- splitio/client/listener.py | 8 +- splitio/models/grammar/matchers/base.py | 6 +- splitio/storage/__init__.py | 17 +- splitio/storage/adapters/redis.py | 56 ++-- splitio/storage/redis.py | 101 +++--- tests/engine/test_hashfns.py | 5 +- tests/integration/files/split_changes.json | 321 ++++++++++++++++++++ tests/integration/test_redis_integration.py | 215 +++++++++++++ tests/storage/test_redis.py | 104 +++---- 9 files changed, 690 insertions(+), 143 deletions(-) create mode 100644 tests/integration/files/split_changes.json create mode 100644 tests/integration/test_redis_integration.py diff --git a/splitio/client/listener.py b/splitio/client/listener.py index 1a1fd57a..260d1969 100644 --- a/splitio/client/listener.py +++ b/splitio/client/listener.py @@ -2,6 +2,8 @@ import abc +from six import add_metaclass + class ImpressionListenerException(Exception): """Custom Exception for Impression Listener.""" @@ -51,12 +53,10 @@ def log_impression(self, impression, attributes=None): raise ImpressionListenerException('Error in log_impression user\'s' 'method is throwing exceptions') - -class ImpressionListener(object): #pylint: disable=too-few-public-methods +@add_metaclass(abc.ABCMeta) #pylint: disable=too-few-public-methods +class ImpressionListener(object): """Impression listener interface.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def log_impression(self, data): """ diff --git a/splitio/models/grammar/matchers/base.py b/splitio/models/grammar/matchers/base.py index 656c88c3..d7d818d9 100644 --- a/splitio/models/grammar/matchers/base.py +++ b/splitio/models/grammar/matchers/base.py @@ -1,13 +1,15 @@ """Abstract matcher module.""" import abc + +from six import add_metaclass + from splitio.client.key import Key +@add_metaclass(abc.ABCMeta) class Matcher(object): """Matcher abstract class.""" - __metaclass__ = abc.ABCMeta - def __init__(self, raw_matcher): """ Initialize generic data and call matcher-specific parser. 
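
The listener and matcher diffs above, together with the storage interfaces below, move from the Python-2-only __metaclass__ = abc.ABCMeta class attribute to six's add_metaclass decorator. A minimal sketch of why this matters, using a hypothetical Base class rather than any of the SDK's own: under Python 3 the bare attribute is silently ignored, so abstract methods were never actually enforced, while the decorator rewrites the class with the metaclass on both interpreters.

    import abc

    from six import add_metaclass

    @add_metaclass(abc.ABCMeta)
    class Base(object):
        """Hypothetical abstract base, enforced on both Python 2 and 3."""

        @abc.abstractmethod
        def method(self):
            """Subclasses must implement this."""

    # Base()  # raises TypeError on both interpreters: abstract method missing
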
diff --git a/splitio/storage/__init__.py b/splitio/storage/__init__.py index 21204cf8..1d2786a5 100644 --- a/splitio/storage/__init__.py +++ b/splitio/storage/__init__.py @@ -3,11 +3,12 @@ import abc +from six import add_metaclass + +@add_metaclass(abc.ABCMeta) class SplitStorage(object): """Split storage interface implemented as an abstract class.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def get(self, split_name): """ @@ -92,11 +93,10 @@ def get_segment_names(self): return set([name for spl in self.get_all_splits() for name in spl.get_segment_names()]) +@add_metaclass(abc.ABCMeta) class SegmentStorage(object): """Segment storage interface implemented as an abstract class.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def get(self, segment_name): """ @@ -173,11 +173,10 @@ def segment_contains(self, segment_name, key): pass +@add_metaclass(abc.ABCMeta) class ImpressionStorage(object): """Impressions storage interface.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def put(self, impressions): """ @@ -199,11 +198,10 @@ def pop_many(self, count): pass +@add_metaclass(abc.ABCMeta) class EventStorage(object): """Events storage interface.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def put(self, events): """ @@ -225,11 +223,10 @@ def pop_many(self, count): pass +@add_metaclass(abc.ABCMeta) class TelemetryStorage(object): """Telemetry storage interface.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def inc_latency(self, name, bucket): """ diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py index c9345fc6..35d3547f 100644 --- a/splitio/storage/adapters/redis.py +++ b/splitio/storage/adapters/redis.py @@ -3,7 +3,7 @@ unicode_literals from builtins import str -from six import string_types, binary_type +from six import string_types, binary_type, raise_from from splitio.exceptions import SentinelConfigurationException try: @@ -47,7 +47,7 @@ def original_exception(self): return self._original_exception -class RedisAdapter(object): +class RedisAdapter(object): #pylint: disable=too-many-public-methods """ Instance decorator for Redis clients such as StrictRedis. 
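
The RedisAdapter hunks that follow also change how failures are surfaced: instead of passing the caught RedisError into the RedisAdapterException constructor, they chain it with six.raise_from, which records the original error as __cause__ under Python 3 (PEP 3134) and degrades to a plain raise under Python 2. A minimal sketch of the pattern, with a simplified stand-in exception and a hypothetical guarded() helper rather than the SDK's own classes:

    from six import raise_from

    class AdapterError(Exception):
        """Simplified stand-in for RedisAdapterException."""

    def guarded(operation, *args):
        # Run a low-level client call and re-raise any failure as an
        # adapter-level error, keeping the original exception chained.
        try:
            return operation(*args)
        except Exception as exc:  # the SDK catches redis.RedisError here
            raise_from(AdapterError('Failed to execute operation'), exc)
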
@@ -127,11 +127,12 @@ def _remove_prefix(self, k):
     def keys(self, pattern):
         """Mimic original redis function but using user custom prefix."""
         try:
-            return _bytes_to_string(self._remove_prefix(
-                self._decorated.keys(self._add_prefix(pattern))
-            ))
+            return [
+                _bytes_to_string(key)
+                for key in self._remove_prefix(self._decorated.keys(self._add_prefix(pattern)))
+            ]
         except RedisError as exc:
-            raise RedisAdapterException('Failed to execute keys operation', exc)
+            raise_from(RedisAdapterException('Failed to execute keys operation'), exc)
 
     def set(self, name, value, *args, **kwargs):
         """Mimic original redis function but using user custom prefix."""
@@ -156,26 +157,33 @@ def setex(self, name, time, value):
         except RedisError as exc:
             raise RedisAdapterException('Error executing setex operation', exc)
 
-    def delete(self, names):
+    def delete(self, *names):
         """Mimic original redis function but using user custom prefix."""
         try:
-            return self._decorated.delete(self._add_prefix(names))
+            return self._decorated.delete(*self._add_prefix(list(names)))
         except RedisError as exc:
-            raise RedisAdapterException('Error executing delete operation', exc)
+            raise_from(RedisAdapterException('Error executing delete operation'), exc)
 
     def exists(self, name):
         """Mimic original redis function but using user custom prefix."""
         try:
             return self._decorated.exists(self._add_prefix(name))
         except RedisError as exc:
-            raise RedisAdapterException('Error executing exists operation', exc)
+            raise_from(RedisAdapterException('Error executing exists operation'), exc)
+
+    def lrange(self, key, start, end):
+        """Mimic original redis function but using user custom prefix."""
+        try:
+            return self._decorated.lrange(self._add_prefix(key), start, end)
+        except RedisError as exc:
+            raise_from(RedisAdapterException('Error executing lrange operation'), exc)
 
     def mget(self, names):
         """Mimic original redis function but using user custom prefix."""
         try:
             return _bytes_to_string(self._decorated.mget(self._add_prefix(names)))
         except RedisError as exc:
-            raise RedisAdapterException('Error executing mget operation', exc)
+            raise_from(RedisAdapterException('Error executing mget operation'), exc)
 
     def smembers(self, name):
         """Mimic original redis function but using user custom prefix."""
@@ -185,91 +193,91 @@ def smembers(self, name):
                 for item in self._decorated.smembers(self._add_prefix(name))
             ]
         except RedisError as exc:
-            raise RedisAdapterException('Error executing smembers operation', exc)
+            raise_from(RedisAdapterException('Error executing smembers operation'), exc)
 
     def sadd(self, name, *values):
         """Mimic original redis function but using user custom prefix."""
         try:
             return self._decorated.sadd(self._add_prefix(name), *values)
         except RedisError as exc:
-            raise RedisAdapterException('Error executing sadd operation', exc)
+            raise_from(RedisAdapterException('Error executing sadd operation'), exc)
 
     def srem(self, name, *values):
         """Mimic original redis function but using user custom prefix."""
         try:
             return self._decorated.srem(self._add_prefix(name), *values)
         except RedisError as exc:
-            raise RedisAdapterException('Error executing srem operation', exc)
+            raise_from(RedisAdapterException('Error executing srem operation'), exc)
 
     def sismember(self, name, value):
         """Mimic original redis function but using user custom prefix."""
         try:
             return self._decorated.sismember(self._add_prefix(name), value)
         except RedisError as exc:
-            raise RedisAdapterException('Error executing sismember operation', exc)
+            raise_from(RedisAdapterException('Error executing sismember
operation'), exc) def eval(self, script, number_of_keys, *keys): """Mimic original redis function but using user custom prefix.""" try: return self._decorated.eval(script, number_of_keys, *self._add_prefix(list(keys))) except RedisError as exc: - raise RedisAdapterException('Error executing eval operation', exc) + raise_from(RedisAdapterException('Error executing eval operation'), exc) def hset(self, name, key, value): """Mimic original redis function but using user custom prefix.""" try: return self._decorated.hset(self._add_prefix(name), key, value) except RedisError as exc: - raise RedisAdapterException('Error executing hset operation', exc) + raise_from(RedisAdapterException('Error executing hset operation'), exc) def hget(self, name, key): """Mimic original redis function but using user custom prefix.""" try: return _bytes_to_string(self._decorated.hget(self._add_prefix(name), key)) except RedisError as exc: - raise RedisAdapterException('Error executing hget operation', exc) + raise_from(RedisAdapterException('Error executing hget operation'), exc) def incr(self, name, amount=1): """Mimic original redis function but using user custom prefix.""" try: return self._decorated.incr(self._add_prefix(name), amount) except RedisError as exc: - raise RedisAdapterException('Error executing incr operation', exc) + raise_from(RedisAdapterException('Error executing incr operation'), exc) def getset(self, name, value): """Mimic original redis function but using user custom prefix.""" try: return _bytes_to_string(self._decorated.getset(self._add_prefix(name), value)) except RedisError as exc: - raise RedisAdapterException('Error executing getset operation', exc) + raise_from(RedisAdapterException('Error executing getset operation'), exc) def rpush(self, key, *values): """Mimic original redis function but using user custom prefix.""" try: return self._decorated.rpush(self._add_prefix(key), *values) except RedisError as exc: - raise RedisAdapterException('Error executing rpush operation', exc) + raise_from(RedisAdapterException('Error executing rpush operation'), exc) def expire(self, key, value): """Mimic original redis function but using user custom prefix.""" try: return self._decorated.expire(self._add_prefix(key), value) except RedisError as exc: - raise RedisAdapterException('Error executing expire operation', exc) + raise_from(RedisAdapterException('Error executing expire operation'), exc) def rpop(self, key): """Mimic original redis function but using user custom prefix.""" try: return _bytes_to_string(self._decorated.rpop(self._add_prefix(key))) except RedisError as exc: - raise RedisAdapterException('Error executing rpop operation', exc) + raise_from(RedisAdapterException('Error executing rpop operation'), exc) def ttl(self, key): """Mimic original redis function but using user custom prefix.""" try: return self._decorated.ttl(self._add_prefix(key)) except RedisError as exc: - raise RedisAdapterException('Error executing ttl operation', exc) + raise_from(RedisAdapterException('Error executing ttl operation'), exc) def _build_default_client(config): #pylint: disable=too-many-locals diff --git a/splitio/storage/redis.py b/splitio/storage/redis.py index 59ce3d39..d62b1146 100644 --- a/splitio/storage/redis.py +++ b/splitio/storage/redis.py @@ -53,7 +53,8 @@ def get(self, split_name): raw = self._redis.get(self._get_key(split_name)) return splits.from_raw(json.loads(raw)) if raw is not None else None except RedisAdapterException: - self._logger.error('Error fetching split from storage', 
exc_info=True)
+            self._logger.error('Error fetching split from storage')
+            self._logger.debug('Error: ', exc_info=True)
             return None
 
     def put(self, split):
@@ -84,9 +85,11 @@ def get_change_number(self):
         :rtype: int
         """
         try:
-            return self._redis.get(self._SPLIT_TILL_KEY)
+            stored_value = self._redis.get(self._SPLIT_TILL_KEY)
+            return json.loads(stored_value) if stored_value is not None else None
         except RedisAdapterException:
-            self._logger.exception('Error fetching split change number from storage')
+            self._logger.error('Error fetching split change number from storage')
+            self._logger.debug('Error: ', exc_info=True)
             return None
 
     def set_change_number(self, new_change_number):
@@ -109,7 +112,8 @@ def get_split_names(self):
             keys = self._redis.keys(self._get_key('*'))
             return [key.replace(self._get_key(''), '') for key in keys]
         except RedisAdapterException:
-            self._logger.exception('Error fetching change number from redis.')
+            self._logger.error('Error fetching split names from storage')
+            self._logger.debug('Error: ', exc_info=True)
             return []
 
     def get_all_splits(self):
@@ -119,17 +123,19 @@ def get_all_splits(self):
         :return: List of all splits in cache.
         :rtype: list(splitio.models.splits.Split)
         """
+        keys = self._redis.keys(self._get_key('*'))
+        to_return = []
         try:
-            keys = self._redis.keys(self._get_key('*'))
-            return [
-                splits.from_raw(raw_split)
-                for raw_split in self._redis.mget(keys)
-                if raw_split is not None
-            ]
+            raw_splits = self._redis.mget(keys)
+            for raw in raw_splits:
+                try:
+                    to_return.append(splits.from_raw(json.loads(raw)))
+                except ValueError:
+                    self._logger.error('Could not parse split. Skipping')
         except RedisAdapterException:
-            self._logger.exception('Error when fetching all splits from redis.')
-            return []
-
+            self._logger.error('Error fetching all splits from storage')
+            self._logger.debug('Error: ', exc_info=True)
+        return to_return
 
 class RedisSegmentStorage(SegmentStorage):
@@ -184,12 +191,13 @@ def get(self, segment_name):
         """
         try:
             keys = (self._redis.smembers(self._get_key(segment_name)))
-            till = self._redis.get(self._get_till_key(segment_name))
-            if keys is None or till is None:
+            till = self.get_change_number(segment_name)
+            if not keys or till is None:
                 return None
             return segments.Segment(segment_name, keys, till)
         except RedisAdapterException:
-            self._logger.exception('Error fetching segment from redis.')
+            self._logger.error('Error fetching segment from storage')
+            self._logger.debug('Error: ', exc_info=True)
             return None
 
     def update(self, segment_name, to_add, to_remove, change_number=None):
@@ -215,9 +223,11 @@ def get_change_number(self, segment_name):
         :rtype: int
         """
         try:
-            return self._redis.get(self._get_till_key(segment_name))
+            stored_value = self._redis.get(self._get_till_key(segment_name))
+            return json.loads(stored_value) if stored_value is not None else None
         except RedisAdapterException:
-            self._logger.exception('Unable to fetch segment change number from redis.')
+            self._logger.error('Error fetching segment change number from storage')
+            self._logger.debug('Error: ', exc_info=True)
             return None
 
     def set_change_number(self, segment_name, new_change_number):
@@ -255,8 +266,9 @@ def segment_contains(self, segment_name, key):
         try:
             return self._redis.sismember(self._get_key(segment_name), key)
         except RedisAdapterException:
-            self._logger.exception('Unable to test segment members in redis.')
-            return False
+            self._logger.error('Error testing members in segment stored in redis')
+            self._logger.debug('Error: ',
exc_info=True) + return None class RedisImpressionsStorage(ImpressionStorage): @@ -272,7 +284,7 @@ def __init__(self, redis_client, sdk_metadata): :param redis_client: Redis client or compliant interface. :type redis_client: splitio.storage.adapters.redis.RedisAdapter :param sdk_metadata: SDK & Machine information. - :type sdk_metadata: dict + :type sdk_metadata: splitio.client.util.SdkMetadata """ self._redis = redis_client self._sdk_metadata = sdk_metadata @@ -293,9 +305,9 @@ def put(self, impressions): if isinstance(impression, Impression): to_store = { 'm': { # METADATA PORTION - 's': self._sdk_metadata['sdk-language-version'], - 'n': self._sdk_metadata['instance-id'], - 'i': self._sdk_metadata['ip-address'], + 's': self._sdk_metadata.sdk_version, + 'n': self._sdk_metadata.instance_name, + 'i': self._sdk_metadata.instance_ip, }, 'i': { # IMPRESSION PORTION 'k': impression.matching_key, @@ -309,13 +321,14 @@ def put(self, impressions): } bulk_impressions.append(json.dumps(to_store)) try: - inserted = self._redis.rpush(self.IMPRESSIONS_QUEUE_KEY, bulk_impressions) + inserted = self._redis.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) if inserted == len(bulk_impressions): self._logger.debug("SET EXPIRE KEY FOR QUEUE") self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL) return True except RedisAdapterException: - self._logger.exception('Something went wrong when trying to add impression to redis') + self._logger.error('Something went wrong when trying to add impression to redis') + self._logger.error('Error: ', exc_info=True) return False def pop_many(self, count): @@ -340,7 +353,7 @@ def __init__(self, redis_client, sdk_metadata): :param redis_client: Redis client or compliant interface. :type redis_client: splitio.storage.adapters.redis.RedisAdapter :param sdk_metadata: SDK & Machine information. - :type sdk_metadata: dict + :type sdk_metadata: splitio.client.util.SdkMetadata """ self._redis = redis_client self._sdk_metadata = sdk_metadata @@ -367,18 +380,19 @@ def put(self, events): 'timestamp': event.timestamp }, 'm': { - 's': self._sdk_metadata['sdk-language-version'], - 'n': self._sdk_metadata['instance-id'], - 'i': self._sdk_metadata['ip-address'], + 's': self._sdk_metadata.sdk_version, + 'n': self._sdk_metadata.instance_name, + 'i': self._sdk_metadata.instance_ip, } }) for event in events ] try: - self._redis.rpush(key, to_store) + self._redis.rpush(key, *to_store) return True except RedisAdapterException: - self._logger.exception('Something went wrong when trying to add event to redis') + self._logger.error('Something went wrong when trying to add event to redis') + self._logger.debug('Error: ', exc_info=True) return False def pop_many(self, count): @@ -405,7 +419,7 @@ def __init__(self, redis_client, sdk_metadata): :param redis_client: Redis client or compliant interface. :type redis_client: splitio.storage.adapters.redis.RedisAdapter :param sdk_metadata: SDK & Machine information. 
- :type sdk_metadata: dict + :type sdk_metadata: splitio.client.util.SdkMetadata """ self._redis = redis_client self._metadata = sdk_metadata @@ -424,8 +438,8 @@ def _get_latency_key(self, name, bucket): :rtype: str """ return self._LATENCY_KEY_TEMPLATE.format( - sdk=self._metadata['sdk-language-version'], - instance=self._metadata['instance-id'], + sdk=self._metadata.sdk_version, + instance=self._metadata.instance_name, name=name, bucket=bucket ) @@ -441,8 +455,8 @@ def _get_counter_key(self, name): :rtype: str """ return self._COUNTER_KEY_TEMPLATE.format( - sdk=self._metadata['sdk-language-version'], - instance=self._metadata['instance-id'], + sdk=self._metadata.sdk_version, + instance=self._metadata.instance_name, name=name ) @@ -457,8 +471,8 @@ def _get_gauge_key(self, name): :rtype: str """ return self._GAUGE_KEY_TEMPLATE.format( - sdk=self._metadata['sdk-language-version'], - instance=self._metadata['instance-id'], + sdk=self._metadata.sdk_version, + instance=self._metadata.instance_name, name=name, ) @@ -479,7 +493,8 @@ def inc_latency(self, name, bucket): try: self._redis.incr(key) except RedisAdapterException: - self._logger.error("Error recording latency for metric \"%s\"", name) + self._logger.error('Something went wrong when trying to store latency in redis') + self._logger.debug('Error: ', exc_info=True) def inc_counter(self, name): """ @@ -492,7 +507,8 @@ def inc_counter(self, name): try: self._redis.incr(key) except RedisAdapterException: - self._logger.error("Error recording counter for metric \"%s\"", name) + self._logger.error('Something went wrong when trying to increment counter in redis') + self._logger.debug('Error: ', exc_info=True) def put_gauge(self, name, value): """ @@ -507,7 +523,8 @@ def put_gauge(self, name, value): try: self._redis.set(key, value) except RedisAdapterException: - self._logger.error("Error recording gauge for metric \"%s\"", name) + self._logger.error('Something went wrong when trying to set gauge in redis') + self._logger.debug('Error: ', exc_info=True) def pop_counters(self): """ diff --git a/tests/engine/test_hashfns.py b/tests/engine/test_hashfns.py index ce8ffd86..1832d31e 100644 --- a/tests/engine/test_hashfns.py +++ b/tests/engine/test_hashfns.py @@ -3,9 +3,10 @@ import io import json import os -import sys import pytest +import six + from splitio.engine import hashfns, splitters from splitio.models import splits @@ -32,7 +33,7 @@ def test_legacy_hash_ascii_data(self): assert hashfns.legacy.legacy_hash(key, seed) == hashed assert splitter.get_bucket(key, seed, splits.HashAlgorithm.LEGACY) == bucket - @pytest.mark.skipif(sys.version_info > (3, 0), reason='Should skip this on python3.') + @pytest.mark.skipif(six.PY3, reason='Should skip this on python3.') def test_legacy_hash_non_ascii_data(self): """Test legacy hash function against known results.""" splitter = splitters.Splitter() diff --git a/tests/integration/files/split_changes.json b/tests/integration/files/split_changes.json new file mode 100644 index 00000000..8a4fe839 --- /dev/null +++ b/tests/integration/files/split_changes.json @@ -0,0 +1,321 @@ +{ + "splits": [ + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "whitelist_feature", + "seed": -1222652054, + "status": "ACTIVE", + "changeNumber": 1234567, + "killed": false, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "WHITELIST", + "negate": false, + "userDefinedSegmentMatcherData": null, + 
"whitelistMatcherData": { + "whitelist": [ + "whitelisted_user" + ] + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + } + ] + }, + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "all_feature", + "seed": 1699838640, + "status": "ACTIVE", + "changeNumber": 1234567, + "killed": false, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "killed_feature", + "seed": -480091424, + "status": "ACTIVE", + "changeNumber": 1234567, + "killed": true, + "defaultTreatment": "defTreatment", + "configurations": { + "off": "{\"size\":15,\"test\":20}", + "defTreatment": "{\"size\":15,\"defTreatment\":true}" + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "defTreatment", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "sample_feature", + "seed": 1548363147, + "status": "ACTIVE", + "changeNumber": 1234567, + "killed": false, + "defaultTreatment": "off", + "configurations": { + "on": "{\"size\":15,\"test\":20}" + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "employees" + }, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + } + ] + }, + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "human_beigns" + }, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 30 + }, + { + "treatment": "off", + "size": 70 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "dependency_test", + "seed": 1222652054, + "status": "ACTIVE", + "changeNumber": 1234567, + "killed": false, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SPLIT_TREATMENT", + "negate": false, + "userDefinedSegmentMatcherData": null, + "dependencyMatcherData": { + "split": "all_feature", + "treatments": ["on"] + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "regex_test", + "seed": 1222652051, + "status": "ACTIVE", 
+ "changeNumber": 1234567, + "killed": false, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "MATCHES_STRING", + "negate": false, + "userDefinedSegmentMatcherData": null, + "stringMatcherData": "abc[0-9]" + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "boolean_test", + "seed": 1222652052, + "status": "ACTIVE", + "changeNumber": 1234567, + "killed": false, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "EQUAL_TO_BOOLEAN", + "negate": false, + "userDefinedSegmentMatcherData": null, + "booleanMatcherData": true + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + } + ], + "since": -1, + "till": 1457726098069 +} diff --git a/tests/integration/test_redis_integration.py b/tests/integration/test_redis_integration.py new file mode 100644 index 00000000..9ed4ea78 --- /dev/null +++ b/tests/integration/test_redis_integration.py @@ -0,0 +1,215 @@ +"""Redis storage end to end tests.""" +#pylint: disable=no-self-use,protected-access + +import json +import os + +from splitio.client.util import get_metadata +from splitio.models import splits, segments, impressions, events, telemetry +from splitio.storage.redis import RedisSplitStorage, RedisSegmentStorage, RedisImpressionsStorage, \ + RedisEventsStorage, RedisTelemetryStorage +from splitio.storage.adapters.redis import _build_default_client + + +class SplitStorageTests(object): + """Redis Split storage e2e tests.""" + + def test_put_fetch(self): + """Test storing and retrieving splits in redis.""" + adapter = _build_default_client({}) + try: + storage = RedisSplitStorage(adapter) + with open(os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json'), 'r') as flo: + split_changes = json.load(flo) + + split_objects = [splits.from_raw(raw) for raw in split_changes['splits']] + for split_object in split_objects: + raw = split_object.to_json() + adapter.set(RedisSplitStorage._SPLIT_KEY.format(split_name=split_object.name), json.dumps(raw)) + + original_splits = {split.name: split for split in split_objects} + fetched_splits = {name: storage.get(name) for name in original_splits.keys()} + + assert set(original_splits.keys()) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + + 
adapter.set(RedisSplitStorage._SPLIT_TILL_KEY, split_changes['till']) + assert storage.get_change_number() == split_changes['till'] + finally: + to_delete = [ + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.all_feature", + "SPLITIO.split.killed_feature", + "SPLITIO.split.Risk_Max_Deductible", + "SPLITIO.split.whitelist_feature", + "SPLITIO.split.regex_test", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test" + ] + for item in to_delete: + adapter.delete(item) + + def test_get_all(self): + """Test get all names & splits.""" + adapter = _build_default_client({}) + try: + storage = RedisSplitStorage(adapter) + with open(os.path.join(os.path.dirname(__file__), 'files', 'split_changes.json'), 'r') as flo: + split_changes = json.load(flo) + + split_objects = [splits.from_raw(raw) for raw in split_changes['splits']] + for split_object in split_objects: + raw = split_object.to_json() + adapter.set(RedisSplitStorage._SPLIT_KEY.format(split_name=split_object.name), json.dumps(raw)) + + original_splits = {split.name: split for split in split_objects} + fetched_names = storage.get_split_names() + fetched_splits = {split.name: split for split in storage.get_all_splits()} + assert set(fetched_names) == set(fetched_splits.keys()) + + for original_split in original_splits.values(): + fetched_split = fetched_splits[original_split.name] + assert original_split.traffic_type_name == fetched_split.traffic_type_name + assert original_split.seed == fetched_split.seed + assert original_split.algo == fetched_split.algo + assert original_split.status == fetched_split.status + assert original_split.change_number == fetched_split.change_number + assert original_split.killed == fetched_split.killed + assert original_split.default_treatment == fetched_split.default_treatment + for index, original_condition in enumerate(original_split.conditions): + fetched_condition = fetched_split.conditions[index] + assert original_condition.label == fetched_condition.label + assert original_condition.condition_type == fetched_condition.condition_type + assert len(original_condition.matchers) == len(fetched_condition.matchers) + assert len(original_condition.partitions) == len(fetched_condition.partitions) + finally: + adapter.delete( + 'SPLITIO.split.sample_feature', + 'SPLITIO.splits.till', + 'SPLITIO.split.all_feature', + 'SPLITIO.split.killed_feature', + 'SPLITIO.split.Risk_Max_Deductible', + 'SPLITIO.split.whitelist_feature', + 'SPLITIO.split.regex_test', + 'SPLITIO.split.boolean_test', + 'SPLITIO.split.dependency_test' + ) + +class SegmentStorageTests(object): + """Redis Segment storage e2e tests.""" + + def test_put_fetch_contains(self): + """Test storing and retrieving segments in redis.""" + adapter = _build_default_client({}) + try: + storage = RedisSegmentStorage(adapter) + adapter.sadd(storage._get_key('some_segment'), 'key1', 'key2', 'key3', 'key4') + adapter.set(storage._get_till_key('some_segment'), 123) + assert storage.segment_contains('some_segment', 'key0') is False + assert storage.segment_contains('some_segment', 'key1') is True + assert storage.segment_contains('some_segment', 'key2') is True + assert storage.segment_contains('some_segment', 'key3') is True + assert storage.segment_contains('some_segment', 'key4') is True + assert storage.segment_contains('some_segment', 'key5') is False + + fetched = storage.get('some_segment') + assert fetched.keys == set(['key1', 'key2', 'key3', 'key4']) + assert fetched.change_number == 123 + finally: + adapter.delete('SPLITIO.segment.some_segment', 'SPLITIO.segment.some_segment.till') + +
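As the segment test above exercises, a segment lives in redis as a plain set of member keys plus a companion '.till' string key holding its change number. A minimal sketch of that layout using redis-py directly, with the same key names as the test; illustrative only, since the SDK always goes through its adapter:

    import json
    import redis  # assumes redis-py is available and a local redis is running

    client = redis.StrictRedis(decode_responses=True)
    client.sadd('SPLITIO.segment.some_segment', 'key1', 'key2')  # membership set
    client.set('SPLITIO.segment.some_segment.till', 123)         # change number key
    assert client.sismember('SPLITIO.segment.some_segment', 'key1')
    assert json.loads(client.get('SPLITIO.segment.some_segment.till')) == 123
    client.delete('SPLITIO.segment.some_segment', 'SPLITIO.segment.some_segment.till')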
+class ImpressionsStorageTests(object): + """Redis Impressions storage e2e tests.""" + + def test_put_fetch_contains(self): + """Test storing and retrieving impressions in redis.""" + adapter = _build_default_client({}) + try: + metadata = get_metadata() + storage = RedisImpressionsStorage(adapter, metadata) + storage.put([ + impressions.Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654), + impressions.Impression('key2', 'feature1', 'on', 'l1', 123456, 'b1', 321654), + impressions.Impression('key3', 'feature1', 'on', 'l1', 123456, 'b1', 321654) + ]) + + imps = adapter.lrange('SPLITIO.impressions', 0, 2) + assert len(imps) == 3 + finally: + adapter.delete('SPLITIO.impressions') + + +class EventsStorageTests(object): + """Redis Events storage e2e tests.""" + + def test_put_fetch_contains(self): + """Test storing and retrieving events in redis.""" + adapter = _build_default_client({}) + try: + metadata = get_metadata() + storage = RedisEventsStorage(adapter, metadata) + storage.put([ + events.Event('key1', 'user', 'purchase', 3.5, 123456), + events.Event('key2', 'user', 'purchase', 3.5, 123456), + events.Event('key3', 'user', 'purchase', 3.5, 123456) + ]) + + evts = adapter.lrange('SPLITIO.events', 0, 2) + assert len(evts) == 3 + finally: + adapter.delete('SPLITIO.events') + + +class TelemetryStorageTests(object): + """Redis Telemetry storage e2e tests.""" + + def test_put_fetch_contains(self): + """Test storing and retrieving telemetry counters, latencies and gauges in redis.""" + adapter = _build_default_client({}) + metadata = get_metadata() + storage = RedisTelemetryStorage(adapter, metadata) + try: + + storage.inc_counter('counter1') + storage.inc_counter('counter1') + storage.inc_counter('counter2') + assert adapter.get(storage._get_counter_key('counter1')) == '2' + assert adapter.get(storage._get_counter_key('counter2')) == '1' + + storage.inc_latency('latency1', 3) + storage.inc_latency('latency1', 3) + storage.inc_latency('latency2', 6) + assert adapter.get(storage._get_latency_key('latency1', 3)) == '2' + assert adapter.get(storage._get_latency_key('latency2', 6)) == '1' + + storage.put_gauge('gauge1', 3) + storage.put_gauge('gauge2', 1) + assert adapter.get(storage._get_gauge_key('gauge1')) == '3' + assert adapter.get(storage._get_gauge_key('gauge2')) == '1' + + finally: + adapter.delete( + storage._get_counter_key('counter1'), + storage._get_counter_key('counter2'), + storage._get_latency_key('latency1', 3), + storage._get_latency_key('latency2', 6), + storage._get_gauge_key('gauge1'), + storage._get_gauge_key('gauge2') + ) diff --git a/tests/storage/test_redis.py index d84db121..e7dfd50f 100644 --- a/tests/storage/test_redis.py +++ b/tests/storage/test_redis.py @@ -1,9 +1,11 @@ """Redis storage test module.""" #pylint: disable=no-self-use + import json + +from splitio.client.util import get_metadata from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, \ RedisSegmentStorage, RedisSplitStorage, RedisTelemetryStorage - from splitio.models.segments import Segment from splitio.models.impressions import Impression from splitio.models.events import Event @@ -41,7 +43,7 @@ def test_get_changenumber(self, mocker): """Test fetching changenumber.""" adapter = mocker.Mock(spec=RedisAdapter) storage = RedisSplitStorage(adapter) - adapter.get.return_value = -1 + adapter.get.return_value = '-1' assert storage.get_change_number() == -1 assert adapter.get.mock_calls ==
[mocker.call('SPLITIO.splits.till')] @@ -58,7 +60,7 @@ def test_get_all_splits(self, mocker): 'SPLITIO.split.split3' ] def _mget_mock(*_): - return ['{"name": split1}', '{"name": split2}', '{"name": split3}'] + return ['{"name": "split1"}', '{"name": "split2"}', '{"name": "split3"}'] adapter.mget.side_effect = _mget_mock storage.get_all_splits() @@ -69,9 +71,9 @@ def _mget_mock(*_): ] assert len(from_raw.mock_calls) == 3 - assert mocker.call('{"name": split1}') in from_raw.mock_calls - assert mocker.call('{"name": split2}') in from_raw.mock_calls - assert mocker.call('{"name": split3}') in from_raw.mock_calls + assert mocker.call({'name': 'split1'}) in from_raw.mock_calls + assert mocker.call({'name': 'split2'}) in from_raw.mock_calls + assert mocker.call({'name': 'split3'}) in from_raw.mock_calls def test_get_split_names(self, mocker): """Test getching split names.""" @@ -91,8 +93,8 @@ class RedisSegmentStorageTests(object): def test_fetch_segment(self, mocker): """Test fetching a whole segment.""" adapter = mocker.Mock(spec=RedisAdapter) - adapter.smembers.return_value = ["key1", "key2", "key3"] - adapter.get.return_value = 100 + adapter.smembers.return_value = set(["key1", "key2", "key3"]) + adapter.get.return_value = '100' from_raw = mocker.Mock() mocker.patch('splitio.models.segments.from_raw', new=from_raw) @@ -110,7 +112,7 @@ def test_fetch_segment(self, mocker): # Assert that if segment doesn't exist, None is returned adapter.reset_mock() from_raw.reset_mock() - adapter.smembers.return_value = None + adapter.smembers.return_value = set() assert storage.get('some_segment') is None assert adapter.smembers.mock_calls == [mocker.call('SPLITIO.segment.some_segment')] assert adapter.get.mock_calls == [mocker.call('SPLITIO.segment.some_segment.till')] @@ -118,7 +120,7 @@ def test_fetch_segment(self, mocker): def test_fetch_change_number(self, mocker): """Test fetching change number.""" adapter = mocker.Mock(spec=RedisAdapter) - adapter.get.return_value = 100 + adapter.get.return_value = '100' storage = RedisSegmentStorage(adapter) result = storage.get_change_number('some_segment') @@ -142,12 +144,8 @@ class RedisImpressionsStorageTests(object): #pylint: disable=too-few-public-met def test_add_impressions(self, mocker): """Test that adding impressions to storage works.""" adapter = mocker.Mock(spec=RedisAdapter) - sdk_metadata = { - 'sdk-language-version': 'python-1.2.3', - 'instance-id': 'some_instance_id', - 'ip-address': '123.123.123.123' - } - storage = RedisImpressionsStorage(adapter, sdk_metadata) + metadata = get_metadata() + storage = RedisImpressionsStorage(adapter, metadata) impressions = [ Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), @@ -160,9 +158,9 @@ def test_add_impressions(self, mocker): to_validate = [json.dumps({ 'm': { # METADATA PORTION - 's': sdk_metadata['sdk-language-version'], - 'n': sdk_metadata['instance-id'], - 'i': sdk_metadata['ip-address'], + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, }, 'i': { # IMPRESSION PORTION 'k': impression.matching_key, @@ -175,7 +173,7 @@ def test_add_impressions(self, mocker): } }) for impression in impressions] - assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.impressions', to_validate)] + assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.impressions', *to_validate)] # Assert that if an exception is thrown it's caught and False is returned adapter.reset_mock() @@ -191,12 +189,9 @@ class RedisEventsStorageTests(object): #pylint: 
disable=too-few-public-methods def test_add_events(self, mocker): """Test that adding impressions to storage works.""" adapter = mocker.Mock(spec=RedisAdapter) - sdk_metadata = { - 'sdk-language-version': 'python-1.2.3', - 'instance-id': 'some_instance_id', - 'ip-address': '123.123.123.123' - } - storage = RedisEventsStorage(adapter, sdk_metadata) + metadata = get_metadata() + + storage = RedisEventsStorage(adapter, metadata) events = [ Event('key1', 'user', 'purchase', 10, 123456), @@ -208,9 +203,9 @@ def test_add_events(self, mocker): list_of_raw_events = [json.dumps({ 'm': { # METADATA PORTION - 's': sdk_metadata['sdk-language-version'], - 'n': sdk_metadata['instance-id'], - 'i': sdk_metadata['ip-address'], + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, }, 'i': { # IMPRESSION PORTION 'key': event.key, @@ -222,7 +217,7 @@ def test_add_events(self, mocker): }) for event in events] # To deal with python2 & 3 differences in hashing/order when dumping json. - list_of_raw_json_strings_called = adapter.rpush.mock_calls[0][1][1] + list_of_raw_json_strings_called = adapter.rpush.mock_calls[0][1][1:] list_of_events_called = [json.loads(event) for event in list_of_raw_json_strings_called] list_of_events_sent = [json.loads(event) for event in list_of_raw_events] for item in list_of_events_sent: @@ -243,58 +238,49 @@ class RedisTelemetryStorageTests(object): def test_inc_latency(self, mocker): """Test incrementing latency.""" adapter = mocker.Mock(spec=RedisAdapter) - sdk_metadata = { - 'sdk-language-version': 'python-1.2.3', - 'instance-id': 'some_instance_id', - 'ip-address': '123.123.123.123' - } - storage = RedisTelemetryStorage(adapter, sdk_metadata) + metadata = get_metadata() + + storage = RedisTelemetryStorage(adapter, metadata) storage.inc_latency('some_latency', 0) storage.inc_latency('some_latency', 1) storage.inc_latency('some_latency', 5) storage.inc_latency('some_latency', 5) storage.inc_latency('some_latency', 22) assert adapter.incr.mock_calls == [ - mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.0'), - mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.1'), - mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.5'), - mocker.call('SPLITIO/python-1.2.3/some_instance_id/latency.some_latency.bucket.5') + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/latency.some_latency.bucket.0'), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/latency.some_latency.bucket.1'), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/latency.some_latency.bucket.5'), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/latency.some_latency.bucket.5') ] def test_inc_counter(self, mocker): """Test incrementing latency.""" adapter = mocker.Mock(spec=RedisAdapter) - sdk_metadata = { - 'sdk-language-version': 'python-1.2.3', - 'instance-id': 'some_instance_id', - 'ip-address': '123.123.123.123' - } - storage = RedisTelemetryStorage(adapter, sdk_metadata) + metadata = get_metadata() + + storage = RedisTelemetryStorage(adapter, metadata) storage.inc_counter('some_counter_1') storage.inc_counter('some_counter_1') storage.inc_counter('some_counter_1') storage.inc_counter('some_counter_2') storage.inc_counter('some_counter_2') assert adapter.incr.mock_calls == [ - mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_1'), - 
mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_1'), - mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_1'), - mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_2'), - mocker.call('SPLITIO/python-1.2.3/some_instance_id/count.some_counter_2') + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/count.some_counter_1'), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/count.some_counter_1'), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/count.some_counter_1'), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/count.some_counter_2'), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/count.some_counter_2') ] def test_inc_gauge(self, mocker): """Test incrementing latency.""" adapter = mocker.Mock(spec=RedisAdapter) - sdk_metadata = { - 'sdk-language-version': 'python-1.2.3', - 'instance-id': 'some_instance_id', - 'ip-address': '123.123.123.123' - } - storage = RedisTelemetryStorage(adapter, sdk_metadata) + metadata = get_metadata() + + storage = RedisTelemetryStorage(adapter, metadata) storage.put_gauge('gauge1', 123) storage.put_gauge('gauge2', 456) assert adapter.set.mock_calls == [ - mocker.call('SPLITIO/python-1.2.3/some_instance_id/gauge.gauge1', 123), - mocker.call('SPLITIO/python-1.2.3/some_instance_id/gauge.gauge2', 456) + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/gauge.gauge1', 123), + mocker.call('SPLITIO/' + metadata.sdk_version + '/' + metadata.instance_name + '/gauge.gauge2', 456) ] From 7dd4d03dda9faef5cf4e669958e09064cc0d7f1b Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Mon, 15 Apr 2019 17:45:28 -0300 Subject: [PATCH 06/38] reorder tasks --- splitio/client/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splitio/client/factory.py b/splitio/client/factory.py index 4283866c..1d5164bb 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -255,9 +255,9 @@ def _build_in_memory_factory(api_key, config, sdk_url=None, events_url=None): # } # Start tasks that have no dependencies + tasks['splits'].start() tasks['impressions'].start() tasks['events'].start() - tasks['splits'].start() tasks['telemetry'].start() def split_ready_task(): From a37bdeec9a87cd034fc4ef153eaeaba403e5588b Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 16 Apr 2019 12:21:05 -0300 Subject: [PATCH 07/38] add localhost support for dynamic configs --- setup.py | 1 + splitio/api/impressions.py | 6 +- splitio/client/client.py | 33 +++++- splitio/client/factory.py | 1 + splitio/client/localhost.py | 185 ++++++++++++++++++++++----------- tests/client/files/file1.split | 14 +++ tests/client/files/file2.yaml | 18 ++++ tests/client/test_localhost.py | 142 +++++++++++++++++++++++++ 8 files changed, 335 insertions(+), 65 deletions(-) create mode 100644 tests/client/files/file1.split create mode 100644 tests/client/files/file2.yaml create mode 100644 tests/client/test_localhost.py diff --git a/setup.py b/setup.py index 0d497308..2525428a 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ tests_require = ['flake8', 'pytest', 'pytest-mock', 'coverage', 'pytest-cov'] install_requires = [ 'requests>=2.9.1', + 'pyyaml>=5.1', 'future>=0.15.2', 'docopt>=0.6.2', ] diff --git a/splitio/api/impressions.py b/splitio/api/impressions.py index 95c625c0..b608c8de 100644 --- a/splitio/api/impressions.py +++ 
b/splitio/api/impressions.py @@ -40,7 +40,7 @@ def _build_bulk(impressions): """ return [ { - 'testName': group[0], + 'testName': test_name, 'keyImpressions': [ { 'keyName': impression.matching_key, @@ -50,10 +50,10 @@ 'label': impression.label, 'bucketingKey': impression.bucketing_key } - for impression in group[1] + for impression in imps ] } - for group in groupby( + for (test_name, imps) in groupby( sorted(impressions, key=lambda i: i.feature_name), lambda i: i.feature_name )
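The _build_bulk rewrite above relies on the sort-then-group idiom: itertools.groupby only merges adjacent items, so impressions must be sorted by feature name before grouping. A self-contained illustration with toy tuples instead of Impression objects (the names here are made up for the example):

    from itertools import groupby

    imps = [('feature_b', 'key1'), ('feature_a', 'key2'), ('feature_b', 'key3')]
    grouped = {
        test_name: [key for (_, key) in group]
        for (test_name, group) in groupby(sorted(imps), lambda i: i[0])
    }
    # without the sorted() call, the two 'feature_b' entries would come out
    # as two separate groups instead of one
    assert grouped == {'feature_a': ['key2'], 'feature_b': ['key1', 'key3']}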
diff --git a/splitio/client/client.py index 54aa2307..1070024b 100644 --- a/splitio/client/client.py +++ b/splitio/client/client.py @@ -81,6 +81,21 @@ def _send_impression_to_listener(self, impression, attributes): self._logger.debug('Error', exc_info=True) def get_treatment_with_config(self, key, feature, attributes=None): + """ + Get the treatment and config for a feature and key, with optional dictionary of attributes. + + This method never raises an exception. If there's a problem, the appropriate log message + will be generated and the method will return the CONTROL treatment. + + :param key: The key for which to get the treatment + :type key: str + :param feature: The name of the feature for which to get the treatment + :type feature: str + :param attributes: An optional dictionary of attributes + :type attributes: dict + :return: The treatment and config for the key and feature + :rtype: tuple(str, str) + """ try: if self.destroyed: self._logger.error("Client has already been destroyed - no calls possible") @@ -157,7 +172,7 @@ def get_treatment(self, key, feature, attributes=None): def get_treatments_with_config(self, key, features, attributes=None): """ - Evaluate multiple features and return a dictionary with all the feature/treatments. + Evaluate multiple features and return a dict with feature -> (treatment, config). Get the treatments for a list of features considering a key, with an optional dictionary of attributes. This method never raises an exception. If there's a problem, the appropriate @@ -235,7 +250,21 @@ def get_treatments_with_config(self, key, features, attributes=None): def get_treatments(self, key, features, attributes=None): - """TODO""" + """ + Evaluate multiple features and return a dictionary with all the feature/treatments. + + Get the treatments for a list of features considering a key, with an optional dictionary of + attributes. This method never raises an exception. If there's a problem, the appropriate + log message will be generated and the method will return the CONTROL treatment. + :param key: The key for which to get the treatment + :type key: str + :param features: Array of the names of the features for which to get the treatment + :type features: list + :param attributes: An optional dictionary of attributes + :type attributes: dict + :return: Dictionary with the result of all the features provided + :rtype: dict + """ with_config = self.get_treatments_with_config(key, features, attributes) return {feature: result[0] for (feature, result) in six.iteritems(with_config)} diff --git a/splitio/client/factory.py index 4283866c..4c489e43 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -334,6 +334,7 @@ def _build_localhost_factory(config): tasks = {'splits': LocalhostSplitSynchronizationTask( cfg['splitFile'], storages['splits'], + cfg['featuresRefreshRate'], ready_event )} tasks['splits'].start() diff --git a/splitio/client/localhost.py index 9058777e..28bad3e0 100644 --- a/splitio/client/localhost.py +++ b/splitio/client/localhost.py @@ -1,14 +1,19 @@ """Localhost client mocked components.""" +import itertools import logging import re -from splitio.models.splits import from_raw + +from six import raise_from +import yaml + +from splitio.models import splits from splitio.storage import ImpressionStorage, EventStorage, TelemetryStorage from splitio.tasks import BaseSynchronizationTask from splitio.tasks.util import asynctask -_COMMENT_LINE_RE = re.compile('^#.*$') -_DEFINITION_LINE_RE = re.compile('^(?P<feature>[\w_-]+)\s+(?P<treatment>[\w_-]+)$') +_LEGACY_COMMENT_LINE_RE = re.compile(r'^#.*$') +_LEGACY_DEFINITION_LINE_RE = re.compile(r'^(?P<feature>[\w_-]+)\s+(?P<treatment>[\w_-]+)$') _LOGGER = logging.getLogger(__name__) @@ -69,7 +74,7 @@ class LocalhostSplitSynchronizationTask(BaseSynchronizationTask): """Split synchronization task that periodically checks the file and updated the splits.""" - def __init__(self, filename, storage, ready_event): + def __init__(self, filename, storage, period, ready_event): """ Class constructor. @@ -83,7 +88,8 @@ def __init__(self, filename, storage, ready_event): self._filename = filename self._ready_event = ready_event self._storage = storage - self._task = asynctask.AsyncTask(self._update_splits, 5, self._on_start) + self._period = period + self._task = asynctask.AsyncTask(self._update_splits, period, self._on_start) def _on_start(self): """Sync splits and set event if successful.""" @@ -91,14 +97,14 @@ self._ready_event.set() @staticmethod - def _make_all_keys_based_split(split_name, treatment): + def _make_split(split_name, conditions, configs=None): """ - Make a split with a single all_keys matcher. + Make a split from a name, a list of conditions and optional per-treatment configs. :param split_name: Name of the split. :type split_name: str.
""" - return from_raw({ + return splits.from_raw({ 'changeNumber': 123, 'trafficTypeName': 'user', 'name': split_name, @@ -107,30 +113,55 @@ def _make_all_keys_based_split(split_name, treatment): 'seed': 321654, 'status': 'ACTIVE', 'killed': False, - 'defaultTreatment': treatment, + 'defaultTreatment': 'control', 'algo': 2, - 'conditions': [ - { - 'partitions': [ - {'treatment': treatment, 'size': 100} - ], - 'contitionType': 'WHITELIST', - 'label': 'some_other_label', - 'matcherGroup': { - 'matchers': [ - { - 'matcherType': 'ALL_KEYS', - 'negate': False, - } - ], - 'combiner': 'AND' - } - } - ] + 'conditions': conditions, + 'configurations': configs }) + @staticmethod + def _make_all_keys_condition(treatment): + return { + 'partitions': [ + {'treatment': treatment, 'size': 100} + ], + 'conditionType': 'WHITELIST', + 'label': 'some_other_label', + 'matcherGroup': { + 'matchers': [ + { + 'matcherType': 'ALL_KEYS', + 'negate': False, + } + ], + 'combiner': 'AND' + } + } + + @staticmethod + def _make_whitelist_condition(whitelist, treatment): + return { + 'partitions': [ + {'treatment': treatment, 'size': 100} + ], + 'conditionType': 'WHITELIST', + 'label': 'some_other_label', + 'matcherGroup': { + 'matchers': [ + { + 'matcherType': 'WHITELIST', + 'negate': False, + 'whitelistMatcherData': { + 'whitelist': whitelist + } + } + ], + 'combiner': 'AND' + } + } + @classmethod - def _read_splits_from_file(cls, filename): + def _read_splits_from_legacy_file(cls, filename): """ Parse a splits file and return a populated storage. @@ -140,53 +171,89 @@ def _read_splits_from_file(cls, filename): :return: Storage populataed with splits ready to be evaluated. :rtype: InMemorySplitStorage """ - splits = {} + to_return = {} try: with open(filename, 'r') as flo: for line in flo: - if line.strip() == '': - continue - - comment_match = _COMMENT_LINE_RE.match(line) - if comment_match: + if line.strip() == '' or _LEGACY_COMMENT_LINE_RE.match(line): continue - definition_match = _DEFINITION_LINE_RE.match(line) - if definition_match: - splits[definition_match.group('feature')] = cls._make_all_keys_based_split( - definition_match.group('feature'), - definition_match.group('treatment') + definition_match = _LEGACY_DEFINITION_LINE_RE.match(line) + if not definition_match: + _LOGGER.warning( + 'Invalid line on localhost environment split ' + 'definition. Line = %s', + line ) continue - _LOGGER.warning( - 'Invalid line on localhost environment split ' - 'definition. Line = %s', - line - ) - return splits - except IOError as e: - raise ValueError("Error parsing split file") - # TODO: ver raise from! -# raise_from(ValueError( -# 'There was a problem with ' -# 'the splits definition file "{}"'.format(filename)), -# e -# ) + cond = cls._make_all_keys_condition(definition_match.group('treatment')) + splt = cls._make_split(definition_match.group('feature'), [cond]) + to_return[splt.name] = splt + return to_return + except IOError as exc: + raise_from( + ValueError("Error parsing file %s. Make sure it's readable." % filename), + exc + ) + + @classmethod + def _read_splits_from_yaml_file(cls, filename): + """ + Parse a splits file and return a populated storage. + + :param filename: Path of the file containing mocked splits & treatments. + :type filename: str. + + :return: Storage populataed with splits ready to be evaluated. 
+ :rtype: InMemorySplitStorage + """ + try: + with open(filename, 'r') as flo: + parsed = yaml.load(flo.read(), Loader=yaml.FullLoader) + + grouped_by_feature_name = itertools.groupby( + sorted(parsed, key=lambda i: next(iter(i.keys()))), + lambda i: next(iter(i.keys()))) + + to_return = {} + for (split_name, statements) in grouped_by_feature_name: + configs = {} + whitelist = [] + all_keys = [] + for statement in statements: + data = next(iter(statement.values())) # grab the first (and only) value. + if 'keys' in data: + keys = data['keys'] if isinstance(data['keys'], list) else [data['keys']] + whitelist.append(cls._make_whitelist_condition(keys, data['treatment'])) + else: + all_keys.append(cls._make_all_keys_condition(data['treatment'])) + if 'config' in data: + configs[data['treatment']] = data['config'] + to_return[split_name] = cls._make_split(split_name, whitelist + all_keys, configs) + return to_return + + except IOError as exc: + raise_from( + ValueError("Error parsing file %s. Make sure it's readable." % filename), + exc + ) def _update_splits(self): """Update splits in storage.""" _LOGGER.info('Synchronizing splits now.') - splits = self._read_splits_from_file(self._filename) - to_delete = [name for name in self._storage.get_split_names() if name not in splits.keys()] - for split in splits.values(): + if self._filename.split('.')[-1].lower() in ('yaml', 'yml'): + fetched = self._read_splits_from_yaml_file(self._filename) + else: + fetched = self._read_splits_from_legacy_file(self._filename) + to_delete = [name for name in self._storage.get_split_names() if name not in fetched.keys()] + for split in fetched.values(): self._storage.put(split) for split in to_delete: self._storage.remove(split) - def is_running(self): """Return whether the task is running.""" return self._task.running @@ -195,13 +262,11 @@ def start(self): """Start split synchronization.""" self._task.start() - def stop(self, stop_event): + def stop(self, event=None): """ Stop task. :param stop_event: Event top set when the task finishes. :type stop_event: threading.Event. """ - self._task.stop(stop_event) - - + self._task.stop(event) diff --git a/tests/client/files/file1.split b/tests/client/files/file1.split new file mode 100644 index 00000000..10af3644 --- /dev/null +++ b/tests/client/files/file1.split @@ -0,0 +1,14 @@ +events_write_es on +events_routing sqs +impressions_routing sqs +workspaces_v1 on +create_org_with_workspace on +sqs_events_processing on +sqs_impressions_processing on +sqs_events_fetch on +sqs_impressions_fetch off +sqs_impressions_fetch_period 700 +sqs_impressions_fetch_threads 10 +sqs_events_fetch_period 500 +sqs_events_fetch_threads 5 + diff --git a/tests/client/files/file2.yaml b/tests/client/files/file2.yaml new file mode 100644 index 00000000..bc9b7705 --- /dev/null +++ b/tests/client/files/file2.yaml @@ -0,0 +1,18 @@ +- my_feature: + treatment: "on" + keys: "key" + config: "{\"desc\" : \"this applies only to ON treatment\"}" +- other_feature_3: + treatment: "off" +- my_feature: + treatment: "off" + keys: "only_key" + config: "{\"desc\" : \"this applies only to OFF and only for only_key. 
The rest will receive ON\"}" +- other_feature_3: + treatment: "on" + keys: "key_whitelist" +- other_feature: + treatment: "on" + keys: ["key2","key3"] +- other_feature_2: + treatment: "on" diff --git a/tests/client/test_localhost.py b/tests/client/test_localhost.py new file mode 100644 index 00000000..aed7b018 --- /dev/null +++ b/tests/client/test_localhost.py @@ -0,0 +1,142 @@ +"""Localhost mode test module.""" +#pylint: disable=no-self-use,line-too-long,protected-access + +import os +import tempfile + +from splitio.client import localhost +from splitio.models.splits import Split +from splitio.models.grammar.matchers import AllKeysMatcher + + +class LocalHostStoragesTests(object): + """Localhost storages test cases.""" + + def test_dummy_impression_storage(self): + """Test that dummy impression storage never complains.""" + imp_storage = localhost.LocalhostImpressionsStorage() + assert imp_storage.put() is None + assert imp_storage.put('ads') is None + assert imp_storage.put(3) is None + assert imp_storage.put([2]) is None + assert imp_storage.put(object) is None + assert imp_storage.pop_many() is None + assert imp_storage.pop_many('ads') is None + assert imp_storage.pop_many(3) is None + assert imp_storage.pop_many([2]) is None + assert imp_storage.pop_many(object) is None + + def test_dummy_event_storage(self): + """Test that dummy event storage never complains.""" + evt_storage = localhost.LocalhostEventsStorage() + assert evt_storage.put() is None + assert evt_storage.put('ads') is None + assert evt_storage.put(3) is None + assert evt_storage.put([2]) is None + assert evt_storage.put(object) is None + assert evt_storage.pop_many() is None + assert evt_storage.pop_many('ads') is None + assert evt_storage.pop_many(3) is None + assert evt_storage.pop_many([2]) is None + assert evt_storage.pop_many(object) is None + + def test_dummy_telemetry_storage(self): + """Test that dummy telemetry storage never complains.""" + telemetry_storage = localhost.LocalhostTelemetryStorage() + assert telemetry_storage.inc_latency() is None + assert telemetry_storage.inc_latency('ads') is None + assert telemetry_storage.inc_latency(3) is None + assert telemetry_storage.inc_latency([2]) is None + assert telemetry_storage.inc_latency(object) is None + assert telemetry_storage.pop_latencies() is None + assert telemetry_storage.pop_latencies('ads') is None + assert telemetry_storage.pop_latencies(3) is None + assert telemetry_storage.pop_latencies([2]) is None + assert telemetry_storage.pop_latencies(object) is None + assert telemetry_storage.inc_counter() is None + assert telemetry_storage.inc_counter('ads') is None + assert telemetry_storage.inc_counter(3) is None + assert telemetry_storage.inc_counter([2]) is None + assert telemetry_storage.inc_counter(object) is None + assert telemetry_storage.pop_counters() is None + assert telemetry_storage.pop_counters('ads') is None + assert telemetry_storage.pop_counters(3) is None + assert telemetry_storage.pop_counters([2]) is None + assert telemetry_storage.pop_counters(object) is None + assert telemetry_storage.put_gauge() is None + assert telemetry_storage.put_gauge('ads') is None + assert telemetry_storage.put_gauge(3) is None + assert telemetry_storage.put_gauge([2]) is None + assert telemetry_storage.put_gauge(object) is None + assert telemetry_storage.pop_gauges() is None + assert telemetry_storage.pop_gauges('ads') is None + assert telemetry_storage.pop_gauges(3) is None + assert telemetry_storage.pop_gauges([2]) is None + assert 
telemetry_storage.pop_gauges(object) is None + + +class SplitFetchingTaskTests(object): + """Localhost split fetching task test cases.""" + + def test_make_all_keys_condition(self): + """Test all keys-based condition construction.""" + cond = localhost.LocalhostSplitSynchronizationTask._make_all_keys_condition('on') + assert cond['conditionType'] == 'WHITELIST' + assert len(cond['partitions']) == 1 + assert cond['partitions'][0]['treatment'] == 'on' + assert cond['partitions'][0]['size'] == 100 + assert len(cond['matcherGroup']['matchers']) == 1 + assert cond['matcherGroup']['matchers'][0]['matcherType'] == 'ALL_KEYS' + assert cond['matcherGroup']['matchers'][0]['negate'] is False + assert cond['matcherGroup']['combiner'] == 'AND' + + def test_make_whitelist_condition(self): + """Test whitelist-based condition construction.""" + cond = localhost.LocalhostSplitSynchronizationTask._make_whitelist_condition(['key1', 'key2'], 'on') + assert cond['conditionType'] == 'WHITELIST' + assert len(cond['partitions']) == 1 + assert cond['partitions'][0]['treatment'] == 'on' + assert cond['partitions'][0]['size'] == 100 + assert len(cond['matcherGroup']['matchers']) == 1 + assert cond['matcherGroup']['matchers'][0]['matcherType'] == 'WHITELIST' + assert cond['matcherGroup']['matchers'][0]['whitelistMatcherData']['whitelist'] == ['key1', 'key2'] + assert cond['matcherGroup']['matchers'][0]['negate'] is False + assert cond['matcherGroup']['combiner'] == 'AND' + + def test_parse_legacy_file(self): + """Test that parsing a legacy file works.""" + with tempfile.NamedTemporaryFile(mode='w') as temp_flo: + temp_flo.write('split1 on\n') + temp_flo.write('split2 off\n') + temp_flo.flush() + splits = localhost.LocalhostSplitSynchronizationTask._read_splits_from_legacy_file(temp_flo.name) + assert len(splits) == 2 + for split in splits.values(): + assert isinstance(split, Split) + assert splits['split1'].name == 'split1' + assert splits['split2'].name == 'split2' + assert isinstance(splits['split1'].conditions[0].matchers[0], AllKeysMatcher) + assert isinstance(splits['split2'].conditions[0].matchers[0], AllKeysMatcher) + + def test_parse_yaml_file(self): + """Test that parsing a yaml file works.""" + filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml') + splits = localhost.LocalhostSplitSynchronizationTask._read_splits_from_yaml_file(filename) + assert len(splits) == 4 + for split in splits.values(): + assert isinstance(split, Split) + assert splits['my_feature'].name == 'my_feature' + assert splits['other_feature'].name == 'other_feature' + assert splits['other_feature_2'].name == 'other_feature_2' + assert splits['other_feature_3'].name == 'other_feature_3' + + # test that all_keys conditions are pushed to the bottom so that they don't override + # whitelists + condition_types = [ + [cond.matchers[0].__class__.__name__ for cond in split.conditions] + for split in splits.values() + ] + assert all( + 'WhitelistMatcher' not in c[c.index('AllKeysMatcher'):] if 'AllKeysMatcher' in c else True + for c in condition_types + )
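The YAML format exercised above maps each feature to a treatment, an optional whitelist of keys and an optional config string, as file2.yaml shows. A rough sketch of consuming such a file end to end; the 'localhost' api-key convention and the 'splitFile' config key are assumptions based on the factory code in this patch, not a verified public API:

    # Hypothetical usage; see caveats in the note above.
    from splitio import get_factory

    factory = get_factory('localhost', config={'splitFile': 'file2.yaml'})
    client = factory.client()
    # per file2.yaml, 'only_key' is whitelisted into the 'off' treatment of
    # 'my_feature' and carries its own config string
    treatment, config = client.get_treatment_with_config('only_key', 'my_feature')
    assert treatment == 'off'

From 48ab64c1ec745dcfb219a25e8ce81fbe2aefd6e6 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 16 Apr 2019 17:41:08 -0300 Subject: [PATCH 08/38] improve logging --- splitio/api/__init__.py | 9 +-------- splitio/api/client.py | 24 +++++++----------------- splitio/api/events.py | 8 ++++++-- splitio/api/impressions.py | 8 ++++++-- splitio/api/segments.py | 8 ++++++-- splitio/api/splits.py | 8 ++++++-- splitio/api/telemetry.py | 13 +++++++++---- tests/api/test_events.py | 2 +-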
tests/api/test_impressions_api.py | 2 +- tests/api/test_segments_api.py | 2 +- tests/api/test_splits_api.py | 2 +- tests/api/test_telemetry.py | 6 +++--- 12 files changed, 48 insertions(+), 44 deletions(-) diff --git a/splitio/api/__init__.py b/splitio/api/__init__.py index 38a7d568..a9fa4f6f 100644 --- a/splitio/api/__init__.py +++ b/splitio/api/__init__.py @@ -3,19 +3,12 @@ class APIException(Exception): """Exception to raise when an API call fails.""" - def __init__(self, custom_message, status_code=None, original_exception=None): + def __init__(self, custom_message, status_code=None): """Constructor.""" Exception.__init__(self, custom_message) self._status_code = status_code if status_code else -1 - self._custom_message = custom_message - self._original_exception = original_exception @property def status_code(self): """Return HTTP status code.""" return self._status_code - - @property - def custom_message(self): - """Return custom message.""" - return self._custom_message diff --git a/splitio/api/client.py b/splitio/api/client.py index 3a9e7ca8..00fd2ede 100644 --- a/splitio/api/client.py +++ b/splitio/api/client.py @@ -2,6 +2,8 @@ from __future__ import division from collections import namedtuple + +from six import raise_from import requests HttpResponse = namedtuple('HttpResponse', ['status_code', 'body']) @@ -10,7 +12,7 @@ class HttpClientException(Exception): """HTTP Client exception.""" - def __init__(self, custom_message, original_exception=None): + def __init__(self, custom_message): """ Class constructor. @@ -20,18 +22,6 @@ def __init__(self, custom_message, original_exception=None): :type original_exception: Exception. """ Exception.__init__(self, custom_message) - self._custom_message = custom_message - self._original_exception = original_exception - - @property - def custom_message(self): - """Return custom message.""" - return self._custom_message - - @property - def original_exception(self): - """Return original exception.""" - return self._original_exception class HttpClient(object): @@ -101,8 +91,8 @@ def get(self, server, path, apikey, query=None, extra_headers=None): #pylint: d timeout=self._timeout ) return HttpResponse(response.status_code, response.text) - except Exception as exc: - raise HttpClientException('requests library is throwing exceptions', exc) + except Exception as exc: #pylint: disable=broad-except + raise_from(HttpClientException('requests library is throwing exceptions'), exc) def post(self, server, path, apikey, body, query=None, extra_headers=None): #pylint: disable=too-many-arguments """ @@ -137,5 +127,5 @@ def post(self, server, path, apikey, body, query=None, extra_headers=None): #py timeout=self._timeout ) return HttpResponse(response.status_code, response.text) - except Exception as exc: - raise HttpClientException('requests library is throwing exceptions', exc) + except Exception as exc: #pylint: disable=broad-except + raise_from(HttpClientException('requests library is throwing exceptions'), exc) diff --git a/splitio/api/events.py b/splitio/api/events.py index e2af3d00..bcb4a73b 100644 --- a/splitio/api/events.py +++ b/splitio/api/events.py @@ -1,5 +1,8 @@ """Events API module.""" import logging + +from six import raise_from + from splitio.api import APIException from splitio.api.client import HttpClientException @@ -69,5 +72,6 @@ def flush_events(self, events): if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) except HttpClientException as exc: - self._logger.debug('Error flushing events: ', 
exc_info=True) - raise APIException(exc.custom_message, original_exception=exc.original_exception) + self._logger.error('Http client is throwing exceptions') + self._logger.debug('Error: ', exc_info=True) + raise_from(APIException('Events not flushed properly.'), exc) diff --git a/splitio/api/impressions.py b/splitio/api/impressions.py index 95c625c0..5c66e4be 100644 --- a/splitio/api/impressions.py +++ b/splitio/api/impressions.py @@ -2,6 +2,9 @@ import logging from itertools import groupby + +from six import raise_from + from splitio.api import APIException from splitio.api.client import HttpClientException @@ -78,5 +81,6 @@ def flush_impressions(self, impressions): if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) except HttpClientException as exc: - self._logger.debug('Error flushing events: ', exc_info=True) - raise APIException(exc.custom_message, original_exception=exc.original_exception) + self._logger.error('Http client is throwing exceptions') + self._logger.debug('Error: ', exc_info=True) + raise_from(APIException('Impressions not flushed properly.'), exc) diff --git a/splitio/api/segments.py b/splitio/api/segments.py index c6533827..7cce297b 100644 --- a/splitio/api/segments.py +++ b/splitio/api/segments.py @@ -2,6 +2,9 @@ import json import logging + +from six import raise_from + from splitio.api import APIException from splitio.api.client import HttpClientException @@ -47,5 +50,6 @@ def fetch_segment(self, segment_name, change_number): else: raise APIException(response.body, response.status_code) except HttpClientException as exc: - self._logger.debug('Error flushing events: ', exc_info=True) - raise APIException(exc.custom_message, original_exception=exc.original_exception) + self._logger.error('Http client is throwing exceptions') + self._logger.debug('Error: ', exc_info=True) + raise_from(APIException('Segments not fetched properly.'), exc) diff --git a/splitio/api/splits.py b/splitio/api/splits.py index aca15b59..9c3af15a 100644 --- a/splitio/api/splits.py +++ b/splitio/api/splits.py @@ -2,6 +2,9 @@ import logging import json + +from six import raise_from + from splitio.api import APIException from splitio.api.client import HttpClientException @@ -44,5 +47,6 @@ def fetch_splits(self, change_number): else: raise APIException(response.body, response.status_code) except HttpClientException as exc: - self._logger.debug('Error flushing events: ', exc_info=True) - raise APIException(exc.custom_message, original_exception=exc.original_exception) + self._logger.error('Http client is throwing exceptions') + self._logger.debug('Error: ', exc_info=True) + raise_from(APIException('Splits not fetched correctly.'), exc) diff --git a/splitio/api/telemetry.py b/splitio/api/telemetry.py index a6f88700..c333b11d 100644 --- a/splitio/api/telemetry.py +++ b/splitio/api/telemetry.py @@ -60,7 +60,9 @@ def flush_latencies(self, latencies): if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) except HttpClientException as exc: - raise APIException(exc.custom_message, original_exception=exc.original_exception) + self._logger.error('Http client is throwing exceptions') + self._logger.debug('Error: ', exc_info=True) + six.raise_from(APIException('Latencies not flushed correctly.'), exc) @staticmethod def _build_gauges(gauges): @@ -94,7 +96,9 @@ def flush_gauges(self, gauges): if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) except HttpClientException as exc: 
- raise APIException(exc.custom_message, original_exception=exc.original_exception) + self._logger.error('Http client is throwing exceptions') + self._logger.debug('Error: ', exc_info=True) + six.raise_from(APIException('Gauges not flushed correctly.'), exc) @staticmethod def _build_counters(counters): @@ -128,5 +132,6 @@ def flush_counters(self, counters): if not 200 <= response.status_code < 300: raise APIException(response.body, response.status_code) except HttpClientException as exc: - self._logger.debug('Error flushing events: ', exc_info=True) - raise APIException(exc.custom_message, original_exception=exc.original_exception) + self._logger.error('Http client is throwing exceptions') + self._logger.debug('Error: ', exc_info=True) + six.raise_from(APIException('Counters not flushed correctly.'), exc) diff --git a/tests/api/test_events.py b/tests/api/test_events.py index 0947c4ed..cdcc67f5 100644 --- a/tests/api/test_events.py +++ b/tests/api/test_events.py @@ -44,7 +44,7 @@ def test_post_events(self, mocker): httpclient.reset_mock() def raise_exception(*args, **kwargs): - raise client.HttpClientException('some_message', Exception('something')) + raise client.HttpClientException('some_message') httpclient.post.side_effect = raise_exception with pytest.raises(APIException) as exc_info: response = events_api.flush_events([ diff --git a/tests/api/test_impressions_api.py b/tests/api/test_impressions_api.py index 9bf5467e..278030b3 100644 --- a/tests/api/test_impressions_api.py +++ b/tests/api/test_impressions_api.py @@ -51,7 +51,7 @@ def test_post_impressions(self, mocker): httpclient.reset_mock() def raise_exception(*args, **kwargs): - raise client.HttpClientException('some_message', Exception('something')) + raise client.HttpClientException('some_message') httpclient.post.side_effect = raise_exception with pytest.raises(APIException) as exc_info: response = impressions_api.flush_impressions([ diff --git a/tests/api/test_segments_api.py b/tests/api/test_segments_api.py index 0344fd23..851d8359 100644 --- a/tests/api/test_segments_api.py +++ b/tests/api/test_segments_api.py @@ -19,7 +19,7 @@ def test_fetch_segment_changes(self, mocker): httpclient.reset_mock() def raise_exception(*args, **kwargs): - raise client.HttpClientException('some_message', Exception('something')) + raise client.HttpClientException('some_message') httpclient.get.side_effect = raise_exception with pytest.raises(APIException) as exc_info: response = segment_api.fetch_segment('some_segment', 123) diff --git a/tests/api/test_splits_api.py b/tests/api/test_splits_api.py index a3e17206..5e827cf9 100644 --- a/tests/api/test_splits_api.py +++ b/tests/api/test_splits_api.py @@ -19,7 +19,7 @@ def test_fetch_split_changes(self, mocker): httpclient.reset_mock() def raise_exception(*args, **kwargs): - raise client.HttpClientException('some_message', Exception('something')) + raise client.HttpClientException('some_message') httpclient.get.side_effect = raise_exception with pytest.raises(APIException) as exc_info: response = split_api.fetch_splits(123) diff --git a/tests/api/test_telemetry.py b/tests/api/test_telemetry.py index 60c6f54c..abec2d50 100644 --- a/tests/api/test_telemetry.py +++ b/tests/api/test_telemetry.py @@ -38,7 +38,7 @@ def test_post_latencies(self, mocker): httpclient.reset_mock() def raise_exception(*args, **kwargs): - raise client.HttpClientException('some_message', Exception('something')) + raise client.HttpClientException('some_message') httpclient.post.side_effect = raise_exception with 
pytest.raises(APIException) as exc_info: response = telemetry_api.flush_latencies({ @@ -75,7 +75,7 @@ def test_post_counters(self, mocker): httpclient.reset_mock() def raise_exception(*args, **kwargs): - raise client.HttpClientException('some_message', Exception('something')) + raise client.HttpClientException('some_message') httpclient.post.side_effect = raise_exception with pytest.raises(APIException) as exc_info: response = telemetry_api.flush_counters({'counter1': 1, 'counter2': 2}) @@ -110,7 +110,7 @@ def test_post_gauge(self, mocker): httpclient.reset_mock() def raise_exception(*args, **kwargs): - raise client.HttpClientException('some_message', Exception('something')) + raise client.HttpClientException('some_message') httpclient.post.side_effect = raise_exception with pytest.raises(APIException) as exc_info: response = telemetry_api.flush_gauges({'gauge1': 1, 'gauge2': 2}) From 385ee2a0362c2d59382eb122aad118c5c49d68d8 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 16 Apr 2019 18:55:40 -0300 Subject: [PATCH 09/38] docstring cleanup --- .github/pull_request_template.md | 3 +- splitio/__init__.py | 30 ++----------------- splitio/api/client.py | 40 ++++++++++++++++--------- splitio/api/events.py | 6 ++-- splitio/client/factory.py | 12 ++++++-- splitio/models/grammar/matchers/keys.py | 6 ---- splitio/storage/inmemmory.py | 31 +++++++++++++++++-- splitio/storage/uwsgi.py | 24 ++++++++++++--- splitio/tasks/events_sync.py | 7 ++--- splitio/tasks/impressions_sync.py | 7 ++--- splitio/tasks/util/asynctask.py | 4 +-- 11 files changed, 98 insertions(+), 72 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 63bd9a93..002b3add 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -8,7 +8,8 @@ * Bullet 2 ## How to test new changes? -* +* python setup.py test to test everything +* pytest to test a file in particular (requires pytest, pytest-cov & pytest-mock to be installed) ## Extra Notes * Bullet 1 diff --git a/splitio/__init__.py b/splitio/__init__.py index ea266440..b5cea2b5 100644 --- a/splitio/__init__.py +++ b/splitio/__init__.py @@ -2,31 +2,5 @@ unicode_literals from splitio.client.factory import get_factory -#from .factories import get_factory # noqa -#from .key import Key # noqa -#from .version import __version__ # noqa -# -#__all__ = ('api', 'brokers', 'cache', 'clients', 'matchers', 'segments', -# 'settings', 'splits', 'splitters', 'transformers', 'treatments', -# 'version', 'factories', 'manager') -# -# -## Functions defined to maintain compatibility with previous sdk versions. -## ====================================================================== -## -## This functions are not supposed to be used directly, factory method should be -## called instead, but since they were previously exposed, they're re-added here -## as helper function so that if someone was using we don't break their code.
-# -#def get_client(apikey, **kwargs): -# from .clients import Client -# from .brokers import get_self_refreshing_broker -# broker = get_self_refreshing_broker(apikey, **kwargs) -# return Client(broker) -# -# -#def get_redis_client(apikey, **kwargs): -# from .clients import Client -# from .brokers import get_redis_broker -# broker = get_redis_broker(apikey, **kwargs) -# return Client(broker) +from splitio.client.key import Key +from splitio.version import __version__ diff --git a/splitio/api/client.py b/splitio/api/client.py index 00fd2ede..2b37f92f 100644 --- a/splitio/api/client.py +++ b/splitio/api/client.py @@ -12,16 +12,14 @@ class HttpClientException(Exception): """HTTP Client exception.""" - def __init__(self, custom_message): + def __init__(self, message): """ Class constructor. :param message: Information on why this exception happened. :type message: str - :param original_exception: Original exception being caught if any. - :type original_exception: Exception. """ - Exception.__init__(self, custom_message) + Exception.__init__(self, message) class HttpClient(object): @@ -34,12 +32,12 @@ def __init__(self, timeout=None, sdk_url=None, events_url=None): """ Class constructor. + :param timeout: How many milliseconds to wait until the server responds. + :type timeout: int :param sdk_url: Optional alternative sdk URL. :type sdk_url: str :param events_url: Optional alternative events URL. :type events_url: str - :param timeout: How many milliseconds to wait until the server responds. - :type timeout: int """ self._timeout = timeout / 1000 if timeout else None # Convert ms to seconds. self._urls = { @@ -61,24 +59,37 @@ def _build_url(self, server, path): """ return self._urls[server] + path + def _build_basic_headers(self, apikey): + """ + Build basic headers with auth. + + :param apikey: API token used to identify backend calls. + :type apikey: str + """ + return { + 'Content-Type': 'application/json', + 'Authorization': "Bearer %s" % apikey + } + def get(self, server, path, apikey, query=None, extra_headers=None): #pylint: disable=too-many-arguments """ Issue a get request. + :param server: Whether the request is for SDK server or Events server. + :type server: str :param path: path to append to the host url. :type path: str :param apikey: api token. :type apikey: str + :param query: Query string passed as dictionary. + :type query: dict :param extra_headers: key/value pairs of possible extra headers. :type extra_headers: dict :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = { - 'Content-Type': 'application/json', - 'Authorization': "Bearer %s" % apikey - } + headers = self._build_basic_headers(apikey) if extra_headers is not None: headers.update(extra_headers) @@ -98,22 +109,23 @@ def post(self, server, path, apikey, body, query=None, extra_headers=None): #py """ Issue a POST request. + :param server: Whether the request is for SDK server or Events server. + :type server: str :param path: path to append to the host url. :type path: str :param apikey: api token. :type apikey: str :param body: body sent in the request. :type body: str + :param query: Query string passed as dictionary. + :type query: dict + :param extra_headers: key/value pairs of possible extra headers. :type extra_headers: dict :return: Tuple of status_code & response text :rtype: HttpResponse """ - headers = { - 'Content-Type': 'application/json', - 'Authorization': "Bearer %s" % apikey - } + headers = self._build_basic_headers(apikey) if extra_headers is not None: headers.update(extra_headers) diff --git a/splitio/api/events.py b/splitio/api/events.py index bcb4a73b..2c929634 100644 --- a/splitio/api/events.py +++ b/splitio/api/events.py @@ -14,10 +14,12 @@ def __init__(self, http_client, apikey, sdk_metadata): """ Class constructor. - :param client: HTTP Client responsble for issuing calls to the backend. - :type client: HttpClient + :param http_client: HTTP Client responsible for issuing calls to the backend. + :type http_client: HttpClient :param apikey: User apikey token. :type apikey: string + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata """ self._logger = logging.getLogger(self.__class__.__name__) self._client = http_client diff --git a/splitio/client/factory.py b/splitio/client/factory.py index 1d5164bb..889bc5fb 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -74,8 +74,16 @@ def __init__( #pylint: disable=too-many-arguments :param storages: Dictionary of storages for all split models. :type storages: dict - :param tasks: Dictionary of synchronization tasks. + :param labels_enabled: Whether the impressions should store labels or not. + :type labels_enabled: bool + :param apis: Dictionary of apis client wrappers + :type apis: dict + :param tasks: Dictionary of synchronization tasks. :type tasks: dict + :param sdk_ready_flag: Event to set when the sdk is ready. + :type sdk_ready_flag: threading.Event + :param impression_listener: User custom listener to handle impressions locally. + :type impression_listener: splitio.client.listener.ImpressionListener """ self._logger = logging.getLogger(self.__class__.__name__) self._storages = storages @@ -139,7 +147,7 @@ def block_until_ready(self, timeout=None): ready = self._sdk_ready_flag.wait(timeout) if not ready: - raise TimeoutException('Waited %d seconds, and sdk was not ready') + raise TimeoutException('SDK Initialization: time of %d exceeded' % timeout) def destroy(self, destroyed_event=None): """ diff --git a/splitio/models/grammar/matchers/keys.py b/splitio/models/grammar/matchers/keys.py index 95741b72..6fcc2584 100644 --- a/splitio/models/grammar/matchers/keys.py +++ b/splitio/models/grammar/matchers/keys.py @@ -41,12 +41,6 @@ def _add_matcher_specific_properties_to_json(self): return {} - - - - - - class UserDefinedSegmentMatcher(Matcher): """Matcher that returns true when the submitted key belongs to a segment.""" diff --git a/splitio/storage/inmemmory.py b/splitio/storage/inmemmory.py index 655acee0..cfb240cb 100644 --- a/splitio/storage/inmemmory.py +++ b/splitio/storage/inmemmory.py @@ -14,6 +14,7 @@ class InMemorySplitStorage(SplitStorage): def __init__(self): """Constructor.""" + self._logger = logging.getLogger(self.__class__.__name__) self._lock = threading.RLock() self._splits = {} self._change_number = -1 @@ -55,6 +56,7 @@ def remove(self, split_name): self._splits.pop(split_name) return True except KeyError: + self._logger.warning("Tried to delete nonexistent split %s.
Skipping", split_name) return False def get_change_number(self): @@ -102,6 +104,7 @@ class InMemorySegmentStorage(SegmentStorage): def __init__(self): """Constructor.""" + self._logger = logging.getLogger(self.__class__.__name__) self._segments = {} self._change_numbers = {} self._lock = threading.RLock() @@ -116,7 +119,13 @@ def get(self, segment_name): :rtype: str """ with self._lock: - return self._segments.get(segment_name) + fetched = self._segments.get(segment_name) + if fetched is None: + self._logger.warning( + "Tried to retrieve nonexistent segment %s. Skipping", + segment_name + ) + return fetched def put(self, segment): """ @@ -189,7 +198,13 @@ def segment_contains(self, segment_name, key): :rtype: bool """ with self._lock: - return segment_name in self._segments and self._segments[segment_name].contains(key) + if not segment_name in self._segments: + self._logger.warning( + "Tried to query members for nonexistent segment %s. Returning False", + segment_name + ) + return False + return self._segments[segment_name].contains(key) class InMemoryImpressionStorage(ImpressionStorage): @@ -201,6 +216,7 @@ def __init__(self, queue_size): :param eventsQueueSize: How many events to queue before forcing a submission """ + self._logger = logging.getLogger(self.__class__.__name__) self._impressions = queue.Queue(maxsize=queue_size) self._lock = threading.Lock() self._queue_full_hook = None @@ -229,6 +245,10 @@ def put(self, impressions): except queue.Full: if self._queue_full_hook is not None and callable(self._queue_full_hook): self._queue_full_hook() + self._logger.warning( + 'Impressions queue is full, failing to add more impressions. \n' + 'Consider increasing parameter `impressionsQueueSize` in configuration' + ) return False def pop_many(self, count): @@ -259,6 +279,7 @@ def __init__(self, eventsQueueSize): :param eventsQueueSize: How many events to queue before forcing a submission """ + self._logger = logging.getLogger(self.__class__.__name__) self._lock = threading.Lock() self._events = queue.Queue(maxsize=eventsQueueSize) self._queue_full_hook = None @@ -286,6 +307,10 @@ def put(self, events): except queue.Full: if self._queue_full_hook is not None and callable(self._queue_full_hook): self._queue_full_hook() + self._logger.warning( + 'Events queue is full, failing to add more events. \n' + 'Consider increasing parameter `eventsQueueSize` in configuration' + ) return False def pop_many(self, count): @@ -325,7 +350,7 @@ def inc_latency(self, name, bucket): :tyoe value: int """ if not 0 <= bucket <= 21: - self._logger.error('Incorect bucket "%d" for latency "%s". Ignoring.', bucket, name) + self._logger.warning('Incorrect bucket "%d" for latency "%s". Ignoring.', bucket, name) return with self._latencies_lock: diff --git a/splitio/storage/uwsgi.py b/splitio/storage/uwsgi.py index 77c1da67..737049ea 100644 --- a/splitio/storage/uwsgi.py +++ b/splitio/storage/uwsgi.py @@ -30,6 +30,7 @@ def __init__(self, uwsgi_entrypoint): :param uwsgi_entrypoint: UWSGI module. Can be the actual module or a mock.
:type uwsgi_entrypoint: module """ + self._logger = logging.getLogger(self.__class__.__name__) self._uwsgi = uwsgi_entrypoint def get(self, split_name): @@ -45,7 +46,10 @@ def get(self, split_name): self._KEY_TEMPLATE.format(suffix=split_name), _SPLITIO_SPLITS_CACHE_NAMESPACE ) - return splits.from_raw(json.loads(raw)) if raw is not None else None + to_return = splits.from_raw(json.loads(raw)) if raw is not None else None + if not to_return: + self._logger.warning("Trying to retrieve nonexistent split %s. Ignoring.", split_name) + return to_return def put(self, split): """ @@ -102,10 +106,13 @@ def remove(self, split_name): # Split list not found, no need to delete anything pass - return self._uwsgi.cache_del( + result = self._uwsgi.cache_del( self._KEY_TEMPLATE.format(suffix=split_name), _SPLITIO_SPLITS_CACHE_NAMESPACE ) + if result is False: + self._logger.warning("Tried to delete nonexistent split %s. Ignoring.", split_name) + return result def get_change_number(self): """ @@ -168,6 +175,7 @@ def __init__(self, uwsgi_entrypoint): :param uwsgi_entrypoint: UWSGI module. Can be the actual module or a mock. :type uwsgi_entrypoint: module """ + self._logger = logging.getLogger(self.__class__.__name__) self._uwsgi = uwsgi_entrypoint def get(self, segment_name): @@ -192,6 +200,10 @@ def get(self, segment_name): 'till': change_number }) except TypeError: + self._logger.warning( + "Trying to retrieve nonexistent segment %s. Ignoring.", + segment_name + ) return None def update(self, segment_name, to_add, to_remove, change_number=None): @@ -297,6 +309,7 @@ def __init__(self, adapter): :param adapter: UWSGI Adapter/Emulator/Module. :type: object """ + self._logger = logging.getLogger(self.__class__.__name__) self._uwsgi = adapter def put(self, impressions): @@ -334,7 +347,8 @@ def pop_many(self, count): self._IMPRESSIONS_KEY, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE )) except TypeError: - current = [] + return [] + self._uwsgi.cache_update( self._IMPRESSIONS_KEY, json.dumps(current[count:]), @@ -388,6 +402,7 @@ def __init__(self, adapter): :param adapter: UWSGI Adapter/Emulator/Module.
:type: object """ + self._logger = logging.getLogger(self.__class__.__name__) self._uwsgi = adapter def put(self, events): @@ -424,7 +439,8 @@ def pop_many(self, count): self._EVENTS_KEY, _SPLITIO_EVENTS_CACHE_NAMESPACE )) except TypeError: - current = [] + return [] + self._uwsgi.cache_update( self._EVENTS_KEY, json.dumps(current[count:]), diff --git a/splitio/tasks/events_sync.py b/splitio/tasks/events_sync.py index 3dac401a..a938518d 100644 --- a/splitio/tasks/events_sync.py +++ b/splitio/tasks/events_sync.py @@ -69,14 +69,11 @@ def _send_events(self): return try: - status_code = self._events_api.flush_events(to_send) - if status_code >= 300: - self._logger.error("Event reporting failed with status code %d", status_code) - self._add_to_failed_queue(to_send) + self._events_api.flush_events(to_send) except APIException as exc: self._logger.error( 'Exception raised while reporting events: %s -- %d', - exc.custom_message, + exc.message, exc.status_code ) self._add_to_failed_queue(to_send) diff --git a/splitio/tasks/impressions_sync.py b/splitio/tasks/impressions_sync.py index cc3567f2..075f4547 100644 --- a/splitio/tasks/impressions_sync.py +++ b/splitio/tasks/impressions_sync.py @@ -70,14 +70,11 @@ def _send_impressions(self): return try: - status_code = self._impressions_api.flush_impressions(to_send) - if status_code >= 300: - self._logger.error("Impressions reporting failed with status code %s", status_code) - self._add_to_failed_queue(to_send) + self._impressions_api.flush_impressions(to_send) except APIException as exc: self._logger.error( 'Exception raised while reporting impressions: %s -- %d', - exc.custom_message, + exc.message, exc.status_code ) self._add_to_failed_queue(to_send) diff --git a/splitio/tasks/util/asynctask.py b/splitio/tasks/util/asynctask.py index 929202cc..88fecc00 100644 --- a/splitio/tasks/util/asynctask.py +++ b/splitio/tasks/util/asynctask.py @@ -133,9 +133,9 @@ def start(self): try: self._thread.start() - except RuntimeError as exc: + except RuntimeError: _LOGGER.error("Couldn't create new thread for async task") - _LOGGER.exception(exc) + _LOGGER.debug('Error: ', exc_info=True) def stop(self, event=None): """ From ebabec297c1747da853144a17acf636e2eee141b Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 16 Apr 2019 22:24:49 -0300 Subject: [PATCH 10/38] fix input validation log interpolation --- splitio/api/client.py | 3 +- splitio/client/input_validator.py | 222 ++++++++++++++------------- tests/client/test_input_validator.py | 220 +++++++++++++------------- 3 files changed, 227 insertions(+), 218 deletions(-) diff --git a/splitio/api/client.py b/splitio/api/client.py index 2b37f92f..535e23fe 100644 --- a/splitio/api/client.py +++ b/splitio/api/client.py @@ -59,7 +59,8 @@ def _build_url(self, server, path): """ return self._urls[server] + path - def _build_basic_headers(self, apikey): + @staticmethod + def _build_basic_headers(apikey): """ Build basic headers with auth. 
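A note on the logging pattern the next patch adopts: the hunks below replace eagerly `.format()`-ed log messages with logging's lazy %-style arguments, so the template and its arguments travel on the LogRecord and are only interpolated if a handler actually emits the record; filtered-out records cost almost nothing, and formatting problems are handled inside the logging framework instead of propagating from the call site. A minimal sketch of the difference, using only the standard-library logging module (logger name and message are illustrative, not taken from the SDK):

    import logging

    _LOGGER = logging.getLogger('interpolation_example')  # illustrative name
    operation, name = 'get_treatment', 'key'

    # Eager: the string is built before logging is even consulted.
    _LOGGER.error('{}: you passed a null {}.'.format(operation, name))

    # Lazy: logging stores the template and args on the LogRecord and
    # interpolates them only when the record is actually handled.
    _LOGGER.error('%s: you passed a null %s.', operation, name)

The lazy form also explains the test churn further down: the expected mock_calls now assert the template and its arguments separately, e.g. mocker.call('%s: you passed a null %s.', 'get_treatment', 'key'), instead of one pre-rendered string.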
diff --git a/splitio/client/input_validator.py b/splitio/client/input_validator.py index b692c011..b6d9d6fc 100644 --- a/splitio/client/input_validator.py +++ b/splitio/client/input_validator.py @@ -1,16 +1,17 @@ +"""Input validation module.""" from __future__ import absolute_import, division, print_function, \ unicode_literals from numbers import Number import logging -import six import re import math -import requests + +import six + +from splitio.api import APIException from splitio.client.key import Key from splitio.engine.evaluator import CONTROL -# from splitio.api import SdkApi -# from splitio.exceptions import NetworkingException _LOGGER = logging.getLogger(__name__) @@ -20,7 +21,7 @@ def _check_not_null(value, name, operation): """ - Checks if value is null + Check if value is null. :param key: value to be checked :type key: str @@ -32,15 +33,15 @@ :rtype: True|False """ if value is None: - _LOGGER.error('{}: you passed a null {}, {} must be a non-empty string.' - .format(operation, name, name)) + _LOGGER.error('%s: you passed a null %s, %s must be a non-empty string.', + operation, name, name) return False return True def _check_is_string(value, name, operation): """ - Checks if value is not string + Check if value is not a string. :param key: value to be checked :type key: str @@ -52,15 +53,17 @@ :rtype: True|False """ if isinstance(value, six.string_types) is False: - _LOGGER.error('{}: you passed an invalid {}, {} must be a non-empty string.'.format( - operation, name, name)) + _LOGGER.error( + '%s: you passed an invalid %s, %s must be a non-empty string.', + operation, name, name + ) return False return True def _check_string_not_empty(value, name, operation): """ - Checks if value is an empty string + Check if value is an empty string. :param key: value to be checked :type key: str @@ -72,15 +75,15 @@ :rtype: True|False """ if value.strip() == "": - _LOGGER.error('{}: you passed an empty {}, {} must be a non-empty string.' - .format(operation, name, name)) + _LOGGER.error('%s: you passed an empty %s, %s must be a non-empty string.', + operation, name, name) return False return True def _check_string_matches(value, operation, pattern): """ - Checks if value is adhere to a regular expression passed + Check if value adheres to the regular expression passed. :param key: value to be checked :type key: str @@ -92,19 +95,21 @@ :rtype: True|False """ if not re.match(pattern, value): - _LOGGER.error('{}: you passed {}, event_type must '.format(operation, value) + - 'adhere to the regular expression {}. '.format(pattern) + - 'This means an event name must be alphanumeric, cannot be more ' + - 'than 80 characters long, and can only include a dash, underscore, ' + - 'period, or colon as separators of alphanumeric characters.' - ) + _LOGGER.error( + '%s: you passed %s, event_type must ' + + 'adhere to the regular expression %s. ' + + 'This means an event name must be alphanumeric, cannot be more ' + + 'than 80 characters long, and can only include a dash, underscore, ' + + 'period, or colon as separators of alphanumeric characters.', + operation, value, pattern + ) return False return True def _check_can_convert(value, name, operation): """ - Checks if is a valid convertion. + Check if a valid conversion to string is possible. :param key: value to be checked :type key: bool|number|array| @@ -121,17 +126,17 @@ # check whether if isnan and isinf are really necessary if isinstance(value, bool) or (not isinstance(value, Number)) or math.isnan(value) \ or math.isinf(value): - _LOGGER.error('{}: you passed an invalid {}, {} must be a non-empty string.' - .format(operation, name, name)) + _LOGGER.error('%s: you passed an invalid %s, %s must be a non-empty string.', + operation, name, name) return None - _LOGGER.warning('{}: {} {} is not of type string, converting.' - .format(operation, name, value)) + _LOGGER.warning('%s: %s %s is not of type string, converting.', + operation, name, value) return str(value) def _check_valid_length(value, name, operation): """ - Checks value's length + Check value's length. :param key: value to be checked :type key: str @@ -143,16 +148,15 @@ :rtype: True|False """ if len(value) > MAX_LENGTH: - _LOGGER.error('{}: {} too long - must be {} characters or less.' - .format(operation, name, MAX_LENGTH)) + _LOGGER.error('%s: %s too long - must be %s characters or less.', + operation, name, MAX_LENGTH) return False return True def _check_valid_object_key(key, name, operation): """ - Checks if object key is valid for get_treatment/s when is - sent as Key Object + Check if object key is valid for get_treatment/s when it is sent as a Key object. :param key: key to be checked :type key: str @@ -164,21 +168,22 @@ :rtype: str|None """ if key is None: - _LOGGER.error('{}: you passed a null {}, '.format(operation, name) - + '{} must be a non-empty string.'.format(name)) + _LOGGER.error( + '%s: you passed a null %s, %s must be a non-empty string.', + operation, name, name) return None if isinstance(key, six.string_types): if not _check_string_not_empty(key, name, operation): return None - keyStr = _check_can_convert(key, name, operation) - if keyStr is None or not _check_valid_length(keyStr, name, operation): + key_str = _check_can_convert(key, name, operation) + if key_str is None or not _check_valid_length(key_str, name, operation): return None - return keyStr + return key_str def _remove_empty_spaces(value, operation): """ - Checks if an string has whitespaces + Check if a string has extra whitespace. :param value: value to be checked :type value: str @@ -189,15 +194,15 @@ """ strip_value = value.strip() if value != strip_value: - _LOGGER.warning("{}: feature_name '{}' has extra whitespace,".format(operation, value) - + " trimming.") + _LOGGER.warning("%s: feature_name '%s' has extra whitespace, trimming.", operation, value) return strip_value def validate_key(key, operation): """ - Validate Key parameter for get_treatment/s, if is invalid at some point - the bucketing_key or matching_key it will return None + Validate Key parameter for get_treatment/s. + + If the matching or bucketing key is invalid, None will be returned. :param key: user key :type key: mixed @@ -209,8 +214,7 @@ matching_key_result = None bucketing_key_result = None if key is None: - _LOGGER.error('{}: you passed a null key, key must be a non-empty string.'
- .format(operation)) + _LOGGER.error('%s: you passed a null key, key must be a non-empty string.', operation) return None, None if isinstance(key, Key): @@ -222,17 +226,17 @@ if bucketing_key_result is None: return None, None else: - keyStr = _check_can_convert(key, 'key', operation) - if keyStr is not None and \ - _check_string_not_empty(keyStr, 'key', operation) and \ - _check_valid_length(keyStr, 'key', operation): - matching_key_result = keyStr + key_str = _check_can_convert(key, 'key', operation) + if key_str is not None and \ + _check_string_not_empty(key_str, 'key', operation) and \ + _check_valid_length(key_str, 'key', operation): + matching_key_result = key_str return matching_key_result, bucketing_key_result def validate_feature_name(feature_name): """ - Checks if feature_name is valid for get_treatment + Check if feature_name is valid for get_treatment. :param feature_name: feature_name to be checked :type feature_name: str @@ -248,7 +252,7 @@ def validate_track_key(key): """ - Checks if key is valid for track + Check if key is valid for track. :param key: key to be checked :type key: str @@ -257,17 +261,17 @@ """ if not _check_not_null(key, 'key', 'track'): return None - keyStr = _check_can_convert(key, 'key', 'track') - if keyStr is None or \ - (not _check_string_not_empty(keyStr, 'key', 'track')) or \ - (not _check_valid_length(keyStr, 'key', 'track')): + key_str = _check_can_convert(key, 'key', 'track') + if key_str is None or \ + (not _check_string_not_empty(key_str, 'key', 'track')) or \ + (not _check_valid_length(key_str, 'key', 'track')): return None - return keyStr + return key_str def validate_traffic_type(traffic_type): """ - Checks if traffic_type is valid for track + Check if traffic_type is valid for track. :param traffic_type: traffic_type to be checked :type traffic_type: str @@ -279,15 +283,15 @@ (not _check_string_not_empty(traffic_type, 'traffic_type', 'track')): return None if not traffic_type.islower(): - _LOGGER.warning('track: {} should be all lowercase - converting string to lowercase.' - .format(traffic_type)) + _LOGGER.warning('track: %s should be all lowercase - converting string to lowercase.', + traffic_type) traffic_type = traffic_type.lower() return traffic_type def validate_event_type(event_type): """ - Checks if event_type is valid for track + Check if event_type is valid for track. :param event_type: event_type to be checked :type event_type: str @@ -304,7 +308,7 @@ def validate_value(value): """ - Checks if value is valid for track + Check if value is valid for track. :param value: value to be checked :type value: number @@ -321,7 +325,7 @@ def validate_manager_feature_name(feature_name): """ - Checks if feature_name is valid for track + Check if feature_name is valid for the manager's split call. :param feature_name: feature_name to be checked :type feature_name: str @@ -335,9 +339,9 @@ return feature_name -def validate_features_get_treatments(features): +def validate_features_get_treatments(features): #pylint: disable=invalid-name """ - Checks if features is valid for get_treatments + Check if features is valid for get_treatments.
:param features: array of features :type features: list @@ -347,15 +351,16 @@ def validate_features_get_treatments(features): if features is None or not isinstance(features, list): _LOGGER.error('get_treatments: feature_names must be a non-empty array.') return None - if len(features) == 0: + if not features: _LOGGER.error('get_treatments: feature_names must be a non-empty array.') return [] - filtered_features = set(_remove_empty_spaces(feature, 'get_treatments') for feature in features - if feature is not None and - _check_is_string(feature, 'feature_name', 'get_treatments') and - _check_string_not_empty(feature, 'feature_name', 'get_treatments') - ) - if len(filtered_features) == 0: + filtered_features = set( + _remove_empty_spaces(feature, 'get_treatments') for feature in features + if feature is not None and + _check_is_string(feature, 'feature_name', 'get_treatments') and + _check_string_not_empty(feature, 'feature_name', 'get_treatments') + ) + if not filtered_features: _LOGGER.error('get_treatments: feature_names must be a non-empty array.') return None return filtered_features @@ -363,7 +368,7 @@ def validate_features_get_treatments(features): def generate_control_treatments(features): """ - Generates valid features to control + Generate valid features to control. :param features: array of features :type features: list @@ -375,7 +380,7 @@ def generate_control_treatments(features): def validate_attributes(attributes, operation): """ - Checks if attributes is valid + Check if attributes is valid. :param attributes: dict :type attributes: dict @@ -386,43 +391,50 @@ def validate_attributes(attributes, operation): """ if attributes is None: return True - if not type(attributes) is dict: - _LOGGER.error('{}: attributes must be of type dictionary.' - .format(operation)) + if not isinstance(attributes, dict): + _LOGGER.error('%s: attributes must be of type dictionary.', operation) return False return True -# TODO: Fix this! -# def _valid_apikey_type(api_key, sdk_api_base_url): -# sdk_api = SdkApi( -# api_key, -# sdk_api_base_url=sdk_api_base_url, -# ) -# _SEGMENT_CHANGES_URL_TEMPLATE = '{base_url}/segmentChanges/{segment_name}/' -# url = _SEGMENT_CHANGES_URL_TEMPLATE.format(base_url=sdk_api_base_url, -# segment_name='___TEST___') -# params = { -# 'since': -1 -# } -# headers = sdk_api._build_headers() -# try: -# response = requests.get(url, params=params, headers=headers, timeout=sdk_api._timeout) -# if response.status_code == requests.codes.forbidden: -# return False -# return True -# except requests.exceptions.RequestException: -# raise NetworkingException() - - -def validate_factory_instantiation(apikey, config, sdk_api_base_url): - """ - Checks if is a valid instantiation of split client +class _ApiLogFilter(logging.Filter): # pylint: disable=too-few-public-methods + def filter(self, record): + return record.name not in ('SegmentsAPI', 'HttpClient') + + +def _valid_apikey_type(segment_api): + """ + Try to guess if the apikey is of browser type and let the user know. + + :param segment_api: Segments API client. 
+ :type segment_api: splitio.api.segments.SegmentsAPI + """ + api_messages_filter = _ApiLogFilter() + try: + segment_api._client._logger.addFilter(api_messages_filter) #pylint: disable=protected-access + segment_api._logger.addFilter(api_messages_filter) #pylint: disable=protected-access + segment_api.fetch_segment('__SOME_INVALID_SEGMENT__', -1) + except APIException as exc: + if exc.status_code == 403: + return False + finally: + segment_api._client._logger.removeFilter(api_messages_filter) #pylint: disable=protected-access + segment_api._logger.removeFilter(api_messages_filter) #pylint: disable=protected-access + + # True doesn't mean that the APIKEY is right, only that it's not of type "browser" + return True + + +def validate_factory_instantiation(apikey, segment_api): + """ + Check if the factory is being instantiated with the appropriate arguments. :param apikey: str :type apikey: str :param config: dict :type config: dict + :param segment_api: Segment API client + :type segment_api: splitio.api.segments.SegmentsAPI :return: bool :rtype: True|False """ @@ -432,17 +444,13 @@ def validate_factory_instantiation(apikey, config, sdk_api_base_url): (not _check_is_string(apikey, 'apikey', 'factory_instantiation')) or \ (not _check_string_not_empty(apikey, 'apikey', 'factory_instantiation')): return False - if 'ready' not in config or isinstance(config.get('ready'), bool) or \ - not isinstance(config.get('ready'), Number): - _LOGGER.error('no ready parameter has been set - incorrect control treatments ' - + 'could be logged') - return False try: - if not _valid_apikey_type(apikey, sdk_api_base_url): + if not _valid_apikey_type(segment_api): _LOGGER.error('factory instantiation: you passed a browser type ' + 'api_key, please grab an api key from the Split ' + 'console that is of type sdk') return False return True - except NetworkingException: - _LOGGER.error("Error occured when tried to connect with Split servers") + except Exception: #pylint: disable=broad-except + _LOGGER.error("Something went wrong when trying to check apikey type.") + _LOGGER.debug("Error: ", exc_info=True) diff --git a/tests/client/test_input_validator.py b/tests/client/test_input_validator.py index fe64c2a8..219a00f6 100644 --- a/tests/client/test_input_validator.py +++ b/tests/client/test_input_validator.py @@ -1,4 +1,6 @@ """Unit tests for the input_validator module.""" +#pylint: disable=line-too-long,protected-access,no-self-use,too-many-statements + from __future__ import absolute_import, division, print_function, \ unicode_literals @@ -8,10 +10,7 @@ from splitio.client.key import Key from splitio.storage import SplitStorage, EventStorage, ImpressionStorage, TelemetryStorage, \ SegmentStorage -from splitio.models.splits import Split, SplitView -from splitio.models.grammar.condition import Condition -from splitio.models.grammar.partitions import Partition -from splitio.client import input_validator +from splitio.models.splits import Split class ClientInputValidationTests(object): @@ -49,184 +48,184 @@ def _get_storage_mock(storage): assert client.get_treatment(None, 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed a null key, key must be a non-empty string.') + mocker.call('%s: you passed a null key, key must be a non-empty string.', 'get_treatment') ] client._logger.reset_mock() assert client.get_treatment('', 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an empty key, key must be a
non-empty string.') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') ] client._logger.reset_mock() - key = ''.join('a' for _ in range(0,255)) + key = ''.join('a' for _ in range(0, 255)) assert client.get_treatment(key, 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: key too long - must be 250 characters or less.') + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatment', 'key', 250) ] client._logger.reset_mock() assert client.get_treatment(12345, 'some_feature') == 'default_treatment' assert client._logger.warning.mock_calls == [ - mocker.call('get_treatment: key 12345 is not of type string, converting.') + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment', 'key', 12345) ] client._logger.reset_mock() assert client.get_treatment(float('nan'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid key, key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') ] client._logger.reset_mock() assert client.get_treatment(float('inf'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid key, key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') ] client._logger.reset_mock() assert client.get_treatment(True, 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid key, key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') ] client._logger.reset_mock() assert client.get_treatment([], 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid key, key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'key', 'key') ] client._logger.reset_mock() assert client.get_treatment('some_key', None) == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed a null feature_name, feature_name must be a non-empty string.') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') ] client._logger.reset_mock() assert client.get_treatment('some_key', 123) == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') ] client._logger.reset_mock() - assert client.get_treatment('some_key', True) == CONTROL + assert client.get_treatment('some_key', True) == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') ] client._logger.reset_mock() - assert client.get_treatment('some_key', []) == CONTROL + assert client.get_treatment('some_key', []) == CONTROL assert client._logger.error.mock_calls == [ - 
mocker.call('get_treatment: you passed an invalid feature_name, feature_name must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') ] client._logger.reset_mock() - assert client.get_treatment('some_key', '') == CONTROL + assert client.get_treatment('some_key', '') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an empty feature_name, feature_name must be a non-empty string.') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'feature_name', 'feature_name') ] client._logger.reset_mock() - assert client.get_treatment('some_key', 'some_feature') == 'default_treatment' + assert client.get_treatment('some_key', 'some_feature') == 'default_treatment' assert client._logger.error.mock_calls == [] assert client._logger.warning.mock_calls == [] client._logger.reset_mock() - assert client.get_treatment(Key(None, 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(None, 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed a null matching_key, matching_key must be a non-empty string.') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') ] client._logger.reset_mock() - assert client.get_treatment(Key('', 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key('', 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an empty matching_key, matching_key must be a non-empty string.') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') ] client._logger.reset_mock() - assert client.get_treatment(Key(float('nan'), 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(float('nan'), 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') ] client._logger.reset_mock() - assert client.get_treatment(Key(float('inf'), 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(float('inf'), 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') ] client._logger.reset_mock() - assert client.get_treatment(Key(True, 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key(True, 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') ] client._logger.reset_mock() - assert client.get_treatment(Key([], 'bucketing_key'), 'some_feature') == CONTROL + assert client.get_treatment(Key([], 'bucketing_key'), 'some_feature') == CONTROL assert 
client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid matching_key, matching_key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'matching_key', 'matching_key') ] client._logger.reset_mock() - assert client.get_treatment(Key(12345, 'bucketing_key'), 'some_feature') == 'default_treatment' + assert client.get_treatment(Key(12345, 'bucketing_key'), 'some_feature') == 'default_treatment' assert client._logger.warning.mock_calls == [ - mocker.call('get_treatment: matching_key 12345 is not of type string, ' 'converting.') + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment', 'matching_key', 12345) ] client._logger.reset_mock() - key = ''.join('a' for _ in range(0,255)) + key = ''.join('a' for _ in range(0, 255)) assert client.get_treatment(Key(key, 'bucketing_key'), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: matching_key too long - must be 250 characters or less.') + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatment', 'matching_key', 250) ] client._logger.reset_mock() assert client.get_treatment(Key('mathcing_key', None), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed a null bucketing_key, bucketing_key must be a non-empty string.') + mocker.call('%s: you passed a null %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') ] client._logger.reset_mock() assert client.get_treatment(Key('mathcing_key', True), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid bucketing_key, bucketing_key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') ] client._logger.reset_mock() assert client.get_treatment(Key('mathcing_key', []), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an invalid bucketing_key, bucketing_key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') ] client._logger.reset_mock() assert client.get_treatment(Key('mathcing_key', ''), 'some_feature') == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: you passed an empty bucketing_key, bucketing_key must be a non-empty string.') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatment', 'bucketing_key', 'bucketing_key') ] client._logger.reset_mock() assert client.get_treatment(Key('mathcing_key', 12345), 'some_feature') == 'default_treatment' assert client._logger.warning.mock_calls == [ - mocker.call('get_treatment: bucketing_key 12345 is not of type string, converting.') + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatment', 'bucketing_key', 12345) ] client._logger.reset_mock() assert client.get_treatment('mathcing_key', 'some_feature', True) == CONTROL assert client._logger.error.mock_calls == [ - mocker.call('get_treatment: attributes must be of type dictionary.') + mocker.call('%s: attributes must be of type dictionary.', 'get_treatment') ] client._logger.reset_mock() - assert client.get_treatment('mathcing_key', 'some_feature', {'test': 'test'}) =='default_treatment' + assert 
client.get_treatment('mathcing_key', 'some_feature', {'test': 'test'}) == 'default_treatment' assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.get_treatment('mathcing_key', 'some_feature', None) =='default_treatment' + assert client.get_treatment('mathcing_key', 'some_feature', None) == 'default_treatment' assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.get_treatment('mathcing_key', ' some_feature ', None) =='default_treatment' + assert client.get_treatment('mathcing_key', ' some_feature ', None) == 'default_treatment' assert client._logger.warning.mock_calls == [ - mocker.call('get_treatment: feature_name \' some_feature \' has extra whitespace, trimming.') + mocker.call('%s: feature_name \'%s\' has extra whitespace, trimming.', 'get_treatment', ' some_feature ') ] def test_track(self, mocker): @@ -244,143 +243,144 @@ def test_track(self, mocker): client._logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=client._logger) - assert client.track(None, "traffic_type", "event_type", 1) == False + assert client.track(None, "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed a null key, key must be a non-empty string.") + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'track', 'key', 'key') ] client._logger.reset_mock() - assert client.track("", "traffic_type", "event_type", 1) == False + assert client.track("", "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an empty key, key must be a non-empty string.") + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'track', 'key', 'key') ] client._logger.reset_mock() - assert client.track(12345, "traffic_type", "event_type", 1) == True + assert client.track(12345, "traffic_type", "event_type", 1) is True assert client._logger.warning.mock_calls == [ - mocker.call("track: key 12345 is not of type string, converting.") + mocker.call("%s: %s %s is not of type string, converting.", 'track', 'key', 12345) ] client._logger.reset_mock() - assert client.track(True, "traffic_type", "event_type", 1) == False + assert client.track(True, "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid key, key must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'key', 'key') ] client._logger.reset_mock() - assert client.track([], "traffic_type", "event_type", 1) == False + assert client.track([], "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid key, key must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'key', 'key') ] client._logger.reset_mock() - key = ''.join('a' for _ in range(0,255)) - assert client.track(key, "traffic_type", "event_type", 1) == False + key = ''.join('a' for _ in range(0, 255)) + assert client.track(key, "traffic_type", "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: key too long - must be 250 characters or less.") + mocker.call("%s: %s too long - must be %s characters or less.", 'track', 'key', 250) ] client._logger.reset_mock() - assert client.track("some_key", None, "event_type", 1) == False + assert client.track("some_key", None, 
"event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed a null traffic_type, traffic_type must be a non-empty string.") + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') ] client._logger.reset_mock() - assert client.track("some_key", "", "event_type", 1) == False + assert client.track("some_key", "", "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an empty traffic_type, traffic_type must be a non-empty string.") + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') ] client._logger.reset_mock() - assert client.track("some_key", 12345, "event_type", 1) == False + assert client.track("some_key", 12345, "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') ] client._logger.reset_mock() - assert client.track("some_key", True, "event_type", 1) == False + assert client.track("some_key", True, "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') ] client._logger.reset_mock() - assert client.track("some_key", [], "event_type", 1) == False + assert client.track("some_key", [], "event_type", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid traffic_type, traffic_type must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'traffic_type', 'traffic_type') ] client._logger.reset_mock() - assert client.track("some_key", "TRAFFIC_type", "event_type", 1) == True + assert client.track("some_key", "TRAFFIC_type", "event_type", 1) is True assert client._logger.warning.mock_calls == [ - mocker.call("track: TRAFFIC_type should be all lowercase - converting string to lowercase.") + mocker.call("track: %s should be all lowercase - converting string to lowercase.", 'TRAFFIC_type') ] - assert client.track("some_key", "traffic_type", None, 1) == False + assert client.track("some_key", "traffic_type", None, 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed a null event_type, event_type must be a non-empty string.") + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "", 1) == False + assert client.track("some_key", "traffic_type", "", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an empty event_type, event_type must be a non-empty string.") + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", True, 1) == False + assert client.track("some_key", "traffic_type", True, 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s 
must be a non-empty string.", 'track', 'event_type', 'event_type') ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", [], 1) == False + assert client.track("some_key", "traffic_type", [], 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", 12345, 1) == False + assert client.track("some_key", "traffic_type", 12345, 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed an invalid event_type, event_type must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'track', 'event_type', 'event_type') ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "@@", 1) == False + assert client.track("some_key", "traffic_type", "@@", 1) is False assert client._logger.error.mock_calls == [ - mocker.call("track: you passed @@, event_type must adhere to the regular " - "expression ^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$. This means " + mocker.call("%s: you passed %s, event_type must adhere to the regular " + "expression %s. This means " "an event name must be alphanumeric, cannot be more than 80 " "characters long, and can only include a dash, underscore, " - "period, or colon as separators of alphanumeric characters.") + "period, or colon as separators of alphanumeric characters.", + 'track', '@@', '^[a-zA-Z0-9][-_.:a-zA-Z0-9]{0,79}$') ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", None) == True + assert client.track("some_key", "traffic_type", "event_type", None) is True assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", 1) == True + assert client.track("some_key", "traffic_type", "event_type", 1) is True assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", 1.23) == True + assert client.track("some_key", "traffic_type", "event_type", 1.23) is True assert client._logger.error.mock_calls == [] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", "test") == False + assert client.track("some_key", "traffic_type", "event_type", "test") is False assert client._logger.error.mock_calls == [ mocker.call("track: value must be a number.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", True) == False + assert client.track("some_key", "traffic_type", "event_type", True) is False assert client._logger.error.mock_calls == [ mocker.call("track: value must be a number.") ] client._logger.reset_mock() - assert client.track("some_key", "traffic_type", "event_type", []) == False + assert client.track("some_key", "traffic_type", "event_type", []) is False assert client._logger.error.mock_calls == [ mocker.call("track: value must be a number.") ] @@ -410,38 +410,38 @@ def test_get_treatments(self, mocker): assert client.get_treatments(None, ['some_feature']) == {'some_feature': CONTROL} assert client._logger.error.mock_calls == [ - mocker.call('get_treatments: you passed a null key, key must be a non-empty string.') + mocker.call('%s: you passed a null key, key must be a non-empty string.', 
'get_treatments') ] client._logger.reset_mock() assert client.get_treatments("", ['some_feature']) == {'some_feature': CONTROL} assert client._logger.error.mock_calls == [ - mocker.call('get_treatments: you passed an empty key, key must be a non-empty string.') + mocker.call('%s: you passed an empty %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') ] - key = ''.join('a' for _ in range(0,255)) + key = ''.join('a' for _ in range(0, 255)) client._logger.reset_mock() assert client.get_treatments(key, ['some_feature']) == {'some_feature': CONTROL} assert client._logger.error.mock_calls == [ - mocker.call('get_treatments: key too long - must be 250 characters or less.') + mocker.call('%s: %s too long - must be %s characters or less.', 'get_treatments', 'key', 250) ] client._logger.reset_mock() assert client.get_treatments(12345, ['some_feature']) == {'some_feature': 'default_treatment'} assert client._logger.warning.mock_calls == [ - mocker.call('get_treatments: key 12345 is not of type string, converting.') + mocker.call('%s: %s %s is not of type string, converting.', 'get_treatments', 'key', 12345) ] client._logger.reset_mock() assert client.get_treatments(True, ['some_feature']) == {'some_feature': CONTROL} assert client._logger.error.mock_calls == [ - mocker.call('get_treatments: you passed an invalid key, key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') ] client._logger.reset_mock() assert client.get_treatments([], ['some_feature']) == {'some_feature': CONTROL} assert client._logger.error.mock_calls == [ - mocker.call('get_treatments: you passed an invalid key, key must be a non-empty string.') + mocker.call('%s: you passed an invalid %s, %s must be a non-empty string.', 'get_treatments', 'key', 'key') ] client._logger.reset_mock() @@ -485,11 +485,11 @@ def test_get_treatments(self, mocker): client._logger.reset_mock() assert client.get_treatments('some_key', ['some ']) == {'some': 'default_treatment'} assert client._logger.warning.mock_calls == [ - mocker.call('get_treatments: feature_name \'some \' has extra whitespace, trimming.') + mocker.call('%s: feature_name \'%s\' has extra whitespace, trimming.', 'get_treatments', 'some ') ] -class ManagerInputValidationTests(object): +class ManagerInputValidationTests(object): #pylint: disable=too-few-public-methods """Manager input validation test cases.""" def test_split_(self, mocker): @@ -507,27 +507,27 @@ def test_split_(self, mocker): manager._logger = mocker.Mock() mocker.patch('splitio.client.input_validator._LOGGER', new=manager._logger) - assert manager.split(None) == None + assert manager.split(None) is None assert manager._logger.error.mock_calls == [ - mocker.call("split: you passed a null feature_name, feature_name must be a non-empty string.") + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') ] manager._logger.reset_mock() - assert manager.split("") == None + assert manager.split("") is None assert manager._logger.error.mock_calls == [ - mocker.call("split: you passed an empty feature_name, feature_name must be a non-empty string.") + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') ] manager._logger.reset_mock() - assert manager.split(True) == None + assert manager.split(True) is None assert manager._logger.error.mock_calls == [ - mocker.call("split: you passed an invalid feature_name, 
feature_name must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') ] manager._logger.reset_mock() - assert manager.split([]) == None + assert manager.split([]) is None assert manager._logger.error.mock_calls == [ - mocker.call("split: you passed an invalid feature_name, feature_name must be a non-empty string.") + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'split', 'feature_name', 'feature_name') ] manager._logger.reset_mock() From f2ed31ac9715e6dfd9430fed1fdd149092bd8909 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 16 Apr 2019 22:49:55 -0300 Subject: [PATCH 11/38] input validation for factory instantiation --- splitio/client/factory.py | 7 +++ splitio/client/input_validator.py | 20 ++---- tests/client/test_input_validator.py | 94 +++++++++++----------------- 3 files changed, 48 insertions(+), 73 deletions(-) diff --git a/splitio/client/factory.py b/splitio/client/factory.py index 889bc5fb..ace96bd8 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -9,6 +9,7 @@ import six from splitio.client.client import Client +from splitio.client import input_validator from splitio.client.manager import SplitManager from splitio.client.config import DEFAULT_CONFIG from splitio.client import util @@ -195,6 +196,9 @@ def destroyed(self): def _build_in_memory_factory(api_key, config, sdk_url=None, events_url=None): #pylint: disable=too-many-locals """Build and return a split factory tailored to the supplied config.""" + if not input_validator.validate_factory_instantiation(api_key): + return None + cfg = DEFAULT_CONFIG.copy() cfg.update(config) http_client = HttpClient( @@ -212,6 +216,9 @@ def _build_in_memory_factory(api_key, config, sdk_url=None, events_url=None): # 'telemetry': TelemetryAPI(http_client, api_key, sdk_metadata) } + if not input_validator.validate_apikey_type(apis['segments']): + return None + storages = { 'splits': InMemorySplitStorage(), 'segments': InMemorySegmentStorage(), diff --git a/splitio/client/input_validator.py b/splitio/client/input_validator.py index b6d9d6fc..2365353b 100644 --- a/splitio/client/input_validator.py +++ b/splitio/client/input_validator.py @@ -402,7 +402,7 @@ def filter(self, record): return record.name not in ('SegmentsAPI', 'HttpClient') -def _valid_apikey_type(segment_api): +def validate_apikey_type(segment_api): """ Try to guess if the apikey is of browser type and let the user know. 
@@ -411,21 +411,22 @@ def _valid_apikey_type(segment_api):
"""
api_messages_filter = _ApiLogFilter()
try:
- segment_api._client._logger.addFilter(api_messages_filter) #pylint: disable=protected-access
segment_api._logger.addFilter(api_messages_filter) #pylint: disable=protected-access
segment_api.fetch_segment('__SOME_INVALID_SEGMENT__', -1)
except APIException as exc:
if exc.status_code == 403:
+ _LOGGER.error('factory instantiation: you passed a browser type '
+ + 'api_key, please grab an api key from the Split '
+ + 'console that is of type sdk')
return False
finally:
- segment_api._client._logger.removeFilter(api_messages_filter) #pylint: disable=protected-access
segment_api._logger.removeFilter(api_messages_filter) #pylint: disable=protected-access
# True doesn't mean that the APIKEY is right, only that it's not of type "browser"
return True
-def validate_factory_instantiation(apikey, segment_api):
+def validate_factory_instantiation(apikey):
"""
Check if the factory is being instantiated with the appropriate arguments.
@@ -444,13 +445,4 @@
(not _check_is_string(apikey, 'apikey', 'factory_instantiation')) or \
(not _check_string_not_empty(apikey, 'apikey', 'factory_instantiation')):
return False
- try:
- if not _valid_apikey_type(segment_api):
- _LOGGER.error('factory instantiation: you passed a browser type '
- + 'api_key, please grab an api key from the Split '
- + 'console that is of type sdk')
- return False
- return True
- except Exception: #pylint: disable=broad-except
- _LOGGER.error("Something went wrong when trying to check apikey type.")
- _LOGGER.debug("Error: ", exc_info=True)
+ return True
diff --git a/tests/client/test_input_validator.py b/tests/client/test_input_validator.py
index 219a00f6..8efc8933 100644
--- a/tests/client/test_input_validator.py
+++ b/tests/client/test_input_validator.py
@@ -4,7 +4,9 @@
from __future__ import absolute_import, division, print_function, \
unicode_literals
-from splitio.client.factory import SplitFactory
+import logging
+
+from splitio.client.factory import SplitFactory, get_factory
from splitio.client.client import CONTROL, Client
from splitio.client.manager import SplitManager
from splitio.client.key import Key
@@ -537,61 +539,35 @@ def test_split_(self, mocker):
-#class TestInputSanitizationFactory(TestCase):
-#
-# def setUp(self):
-# input_validator._LOGGER.error = mock.MagicMock()
-# self.logger_error = input_validator._LOGGER.error
-#
-# def test_factory_with_null_apikey(self):
-# self.assertEqual(None, get_factory(None))
-# self.logger_error \
-# .assert_called_once_with("factory_instantiation: you passed a null apikey, apikey" +
-# " must be a non-empty string.")
-#
-# def test_factory_with_empty_apikey(self):
-# self.assertEqual(None, get_factory(''))
-# self.logger_error \
-# .assert_called_once_with("factory_instantiation: you passed an empty apikey, apikey" +
-# " must be a non-empty string.")
-#
-# def test_factory_with_invalid_apikey(self):
-# self.assertEqual(None, get_factory(True))
-# self.logger_error \
-# .assert_called_once_with("factory_instantiation: you passed an invalid apikey, apikey" +
-# " must be a non-empty string.")
-#
-# def test_factory_with_invalid_apikey_redis(self):
-# config = {
-# 'redisDb': 0,
-# 'redisHost': 'localhost'
-# }
-# self.assertNotEqual(None, get_factory(True, config=config))
-# self.logger_error.assert_not_called()
-#
-# def test_factory_with_invalid_config(self):
-# config = {
-# 'some': 0
-# }
-# self.assertEqual(None,
get_factory("apikey", config=config)) -# self.logger_error \ -# .assert_called_once_with('no ready parameter has been set - incorrect control ' -# + 'treatments could be logged') -# -# def test_factory_with_invalid_null_ready(self): -# config = { -# 'ready': None -# } -# self.assertEqual(None, get_factory("apikey", config=config)) -# self.logger_error \ -# .assert_called_once_with('no ready parameter has been set - incorrect control ' -# + 'treatments could be logged') -# -# def test_factory_with_invalid_ready(self): -# config = { -# 'ready': True -# } -# self.assertEqual(None, get_factory("apikey", config=config)) -# self.logger_error \ -# .assert_called_once_with('no ready parameter has been set - incorrect control ' -# + 'treatments could be logged') +class FactoryInputValidationTests(object): #pylint: disable=too-few-public-methods + """Factory instantiation input validation test cases.""" + + def test_input_validation_factory(self, mocker): + """Test the input validators for factory instantiation.""" + logger = mocker.Mock(spec=logging.Logger) + mocker.patch('splitio.client.input_validator._LOGGER', new=logger) + + assert get_factory(None) is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed a null %s, %s must be a non-empty string.", 'factory_instantiation', 'apikey', 'apikey') + ] + + logger.reset_mock() + assert get_factory('') is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed an empty %s, %s must be a non-empty string.", 'factory_instantiation', 'apikey', 'apikey') + ] + + logger.reset_mock() + assert get_factory(True) is None + assert logger.error.mock_calls == [ + mocker.call("%s: you passed an invalid %s, %s must be a non-empty string.", 'factory_instantiation', 'apikey', 'apikey') + ] + + logger.reset_mock() + assert get_factory(True, config={'uwsgiCache': True}) is not None + assert logger.error.mock_calls == [] + + logger.reset_mock() + assert get_factory(True, config={'redisHost': 'some-host'}) is not None + assert logger.error.mock_calls == [] From d0e58cb5f11726212563e94e6447fa2df66f7f00 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Wed, 17 Apr 2019 09:52:20 -0300 Subject: [PATCH 12/38] factory ready methods --- splitio/client/client.py | 5 ++++ splitio/client/factory.py | 56 ++++++++++++++++++++++-------------- tests/client/test_factory.py | 19 ++++++++++-- 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/splitio/client/client.py b/splitio/client/client.py index ecee50a3..3928caee 100644 --- a/splitio/client/client.py +++ b/splitio/client/client.py @@ -55,6 +55,11 @@ def destroy(self): """ self._factory.destroy() + @property + def ready(self): + """Return whether the SDK initialization has finished.""" + return self._factory.ready + @property def destroyed(self): """Return whether the factory holding this client has been destroyed.""" diff --git a/splitio/client/factory.py b/splitio/client/factory.py index ace96bd8..0f6d67e1 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -91,16 +91,19 @@ def __init__( #pylint: disable=too-many-arguments self._labels_enabled = labels_enabled self._apis = apis if apis else {} self._tasks = tasks if tasks else {} - self._status = Status.NOT_INITIALIZED self._sdk_ready_flag = sdk_ready_flag self._impression_listener = impression_listener - # If we have a ready flag, add a listener that updates the status - # to READY once the flag is set. 
+ # If we have a ready flag, it means we have sync tasks that need to finish
+ # before the SDK client becomes ready.
if self._sdk_ready_flag is not None:
+ self._status = Status.NOT_INITIALIZED
+ # add a listener that updates the status to READY once the flag is set.
ready_updater = threading.Thread(target=self._update_status_when_ready)
ready_updater.setDaemon(True)
ready_updater.start()
+ else:
+ self._status = Status.READY
def _update_status_when_ready(self):
"""Wait until the sdk is ready and update the status."""
@@ -150,6 +153,16 @@ def block_until_ready(self, timeout=None):
if not ready:
raise TimeoutException('SDK Initialization: time of %d exceeded' % timeout)
+ @property
+ def ready(self):
+ """
+ Return whether the factory is ready.
+
+ :return: True if the factory is ready. False otherwise.
+ :rtype: bool
+ """
+ return self._status == Status.READY
+
def destroy(self, destroyed_event=None):
"""
Destroy the factory and render clients unusable.
@@ -164,24 +177,25 @@ def destroy(self, destroyed_event=None):
self._logger.info('Factory already destroyed.')
return
- if destroyed_event is not None:
- stop_events = {name: threading.Event() for name in self._tasks.keys()}
- for name, task in six.iteritems(self._tasks):
- task.stop(stop_events[name])
-
- def _wait_for_tasks_to_stop():
- for event in stop_events.values():
- event.wait()
- destroyed_event.set()
-
- wait_thread = threading.Thread(target=_wait_for_tasks_to_stop)
- wait_thread.setDaemon(True)
- wait_thread.start()
- else:
- for task in self._tasks.values():
- task.stop()
-
- self._status = Status.DESTROYED
+ try:
+ if destroyed_event is not None:
+ stop_events = {name: threading.Event() for name in self._tasks.keys()}
+ for name, task in six.iteritems(self._tasks):
+ task.stop(stop_events[name])
+
+ def _wait_for_tasks_to_stop():
+ for event in stop_events.values():
+ event.wait()
+ destroyed_event.set()
+
+ wait_thread = threading.Thread(target=_wait_for_tasks_to_stop)
+ wait_thread.setDaemon(True)
+ wait_thread.start()
+ else:
+ for task in self._tasks.values():
+ task.stop()
+ finally:
+ self._status = Status.DESTROYED
@property
def destroyed(self):
diff --git a/tests/client/test_factory.py b/tests/client/test_factory.py
index 2030d084..bb5a1339 100644
--- a/tests/client/test_factory.py
+++ b/tests/client/test_factory.py
@@ -83,6 +83,9 @@ def _segment_task_init_mock(self, api, storage, split_storage, period, event):
assert factory._tasks['telemetry']._storage == factory._storages['telemetry']
assert factory._tasks['telemetry']._api == factory._apis['telemetry']
assert factory._labels_enabled is True
+ factory.block_until_ready()
+ time.sleep(1) # give a chance for the bg thread to set the ready status
+ assert factory.ready
def test_redis_client_creation(self, mocker):
"""Test that a client with redis storage is created correctly."""
@@ -157,6 +160,9 @@ def test_redis_client_creation(self, mocker):
)]
assert factory._labels_enabled is False
assert factory._impression_listener == 123
+ factory.block_until_ready()
+ time.sleep(1) # give a chance for the bg thread to set the ready status
+ assert factory.ready
def test_uwsgi_client_creation(self):
@@ -171,6 +177,9 @@ def test_uwsgi_client_creation(self):
assert factory._tasks == {}
assert factory._labels_enabled is True
assert factory._impression_listener == 123
+ factory.block_until_ready()
+ time.sleep(1) # give a chance for the bg thread to set the ready status
+ assert factory.ready
def test_destroy(self, mocker):
"""Test that tasks are shutdown and data is flushed when
destroy is called."""
@@ -218,7 +227,9 @@ def _event_task_init_mock(self, api, storage, refresh_rate, bulk_size):
# Start factory and make assertions
factory = get_factory('some_api_key')
-
+ factory.block_until_ready()
+ time.sleep(1) # give a chance for the bg thread to set the ready status
+ assert factory.ready
assert factory.destroyed is False
factory.destroy()
@@ -284,9 +295,12 @@ def _telemetry_task_init_mock(self, api, storage, refresh_rate):
# Start factory and make assertions
factory = get_factory('some_api_key')
- assert factory.destroyed is False
+ factory.block_until_ready()
+ time.sleep(1) # give a chance for the bg thread to set the ready status
+ assert factory.ready
+
event = threading.Event()
factory.destroy(event)
@@ -322,3 +336,4 @@ def _telemetry_task_init_mock(self, api, storage, refresh_rate):
# a chance to run and set the main event.
assert event.is_set()
+ assert factory.destroyed
From a959773665f7b73b08405c31ae6fe00bb5f75e1d Mon Sep 17 00:00:00 2001
From: Martin Redolatti
Date: Wed, 17 Apr 2019 15:19:32 -0300
Subject: [PATCH 13/38] fix yaml/legacy selection and test it
---
splitio/api/client.py | 4 +--
splitio/client/localhost.py | 2 +-
tests/client/test_localhost.py | 56 +++++++++++++++++++++++++++++++++-
3 files changed, 58 insertions(+), 4 deletions(-)
diff --git a/splitio/api/client.py b/splitio/api/client.py
index 535e23fe..b322660b 100644
--- a/splitio/api/client.py
+++ b/splitio/api/client.py
@@ -25,8 +25,8 @@ def __init__(self, message):
class HttpClient(object):
"""HttpClient wrapper."""
- SDK_URL = 'https://split.io/api'
- EVENTS_URL = 'https://split.io/api'
+ SDK_URL = 'https://sdk.split.io/api'
+ EVENTS_URL = 'https://events.split.io/api'
def __init__(self, timeout=None, sdk_url=None, events_url=None):
"""
diff --git a/splitio/client/localhost.py b/splitio/client/localhost.py
index 28bad3e0..7ee2ea03 100644
--- a/splitio/client/localhost.py
+++ b/splitio/client/localhost.py
@@ -243,7 +243,7 @@ def _read_splits_from_yaml_file(cls, filename):
def _update_splits(self):
"""Update splits in storage."""
_LOGGER.info('Synchronizing splits now.')
- if self._filename.split('.')[-1].lower() in ('yaml', 'yml'):
+ if self._filename.lower().endswith(('.yaml', '.yml')):
fetched = self._read_splits_from_yaml_file(self._filename)
else:
fetched = self._read_splits_from_legacy_file(self._filename)
diff --git a/tests/client/test_localhost.py b/tests/client/test_localhost.py
index aed7b018..5da30500 100644
--- a/tests/client/test_localhost.py
+++ b/tests/client/test_localhost.py
@@ -7,7 +7,7 @@
from splitio.client import localhost
from splitio.models.splits import Split
from splitio.models.grammar.matchers import AllKeysMatcher
-
+from splitio.storage import SplitStorage
class LocalHostStoragesTests(object):
"""Localhost storages test cases."""
@@ -140,3 +140,57 @@ def test_parse_yaml_file(self):
'WhitelistMatcher' not in c[c.index('AllKeysMatcher'):] if 'AllKeysMatcher' in c else True
for c in condition_types
)
+
+ def test_update_splits(self, mocker):
+ """Test update splits."""
+ parse_legacy = mocker.Mock()
+ parse_legacy.return_value = {}
+ parse_yaml = mocker.Mock()
+ parse_yaml.return_value = {}
+ storage_mock = mocker.Mock(spec=SplitStorage)
+ storage_mock.get_split_names.return_value = []
+
+ parse_legacy.reset_mock()
+ parse_yaml.reset_mock()
+ task = localhost.LocalhostSplitSynchronizationTask('something', storage_mock, 0, None)
+ task._read_splits_from_legacy_file = parse_legacy
+ task._read_splits_from_yaml_file = parse_yaml
+ task._update_splits()
+ assert
parse_legacy.mock_calls == [mocker.call('something')] + assert parse_yaml.mock_calls == [] + + parse_legacy.reset_mock() + parse_yaml.reset_mock() + task = localhost.LocalhostSplitSynchronizationTask('something.yaml', storage_mock, 0, None) + task._read_splits_from_legacy_file = parse_legacy + task._read_splits_from_yaml_file = parse_yaml + task._update_splits() + assert parse_legacy.mock_calls == [] + assert parse_yaml.mock_calls == [mocker.call('something.yaml')] + + parse_legacy.reset_mock() + parse_yaml.reset_mock() + task = localhost.LocalhostSplitSynchronizationTask('something.yml', storage_mock, 0, None) + task._read_splits_from_legacy_file = parse_legacy + task._read_splits_from_yaml_file = parse_yaml + task._update_splits() + assert parse_legacy.mock_calls == [] + assert parse_yaml.mock_calls == [mocker.call('something.yml')] + + parse_legacy.reset_mock() + parse_yaml.reset_mock() + task = localhost.LocalhostSplitSynchronizationTask('something.YAML', storage_mock, 0, None) + task._read_splits_from_legacy_file = parse_legacy + task._read_splits_from_yaml_file = parse_yaml + task._update_splits() + assert parse_legacy.mock_calls == [] + assert parse_yaml.mock_calls == [mocker.call('something.YAML')] + + parse_legacy.reset_mock() + parse_yaml.reset_mock() + task = localhost.LocalhostSplitSynchronizationTask('yaml', storage_mock, 0, None) + task._read_splits_from_legacy_file = parse_legacy + task._read_splits_from_yaml_file = parse_yaml + task._update_splits() + assert parse_legacy.mock_calls == [mocker.call('yaml')] + assert parse_yaml.mock_calls == [] From 1daeb1d060a623a95fcca7a25383e6fb22226d1f Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Wed, 17 Apr 2019 16:04:16 -0300 Subject: [PATCH 14/38] add e2e tests for localhost yaml file --- tests/client/test_localhost.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/client/test_localhost.py b/tests/client/test_localhost.py index 5da30500..573523dc 100644 --- a/tests/client/test_localhost.py +++ b/tests/client/test_localhost.py @@ -4,6 +4,7 @@ import os import tempfile +from splitio.client.factory import get_factory from splitio.client import localhost from splitio.models.splits import Split from splitio.models.grammar.matchers import AllKeysMatcher @@ -194,3 +195,21 @@ def test_update_splits(self, mocker): task._update_splits() assert parse_legacy.mock_calls == [mocker.call('yaml')] assert parse_yaml.mock_calls == [] + + def test_localhost_e2e(self): + """Instantiate a client with a YAML file and issue get_treatment() calls.""" + filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml') + factory = get_factory('localhost', config={'splitFile': filename}) + client = factory.client() + assert client.get_treatment_with_config('key', 'my_feature') == ('on', '{"desc" : "this applies only to ON treatment"}') + assert client.get_treatment_with_config('only_key', 'my_feature') == ( + 'off', '{"desc" : "this applies only to OFF and only for only_key. 
The rest will receive ON"}' + ) + assert client.get_treatment_with_config('another_key', 'my_feature') == ('control', None) + assert client.get_treatment_with_config('key2', 'other_feature') == ('on', None) + assert client.get_treatment_with_config('key3', 'other_feature') == ('on', None) + assert client.get_treatment_with_config('some_key', 'other_feature_2') == ('on', None) + assert client.get_treatment_with_config('key_whitelist', 'other_feature_3') == ('on', None) + assert client.get_treatment_with_config('any_other_key', 'other_feature_3') == ('off', None) + + From a57e7ecd6181147ab9f9f5dd88adc3213177a9f3 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Wed, 17 Apr 2019 16:29:29 -0300 Subject: [PATCH 15/38] add `configs` in testview --- splitio/models/splits.py | 5 +++-- tests/client/test_localhost.py | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/splitio/models/splits.py b/splitio/models/splits.py index 8a652fab..aaf1eb85 100644 --- a/splitio/models/splits.py +++ b/splitio/models/splits.py @@ -10,7 +10,7 @@ SplitView = namedtuple( 'SplitView', - ['name', 'traffic_type', 'killed', 'treatments', 'change_number'] + ['name', 'traffic_type', 'killed', 'treatments', 'change_number', 'configs'] ) @@ -191,7 +191,8 @@ def to_split_view(self): self.traffic_type_name, self.killed, list(set(part.treatment for cond in self.conditions for part in cond.partitions)), - self.change_number + self.change_number, + self._configurations if self._configurations is not None else {} ) @python_2_unicode_compatible diff --git a/tests/client/test_localhost.py b/tests/client/test_localhost.py index 573523dc..ea10c783 100644 --- a/tests/client/test_localhost.py +++ b/tests/client/test_localhost.py @@ -212,4 +212,13 @@ def test_localhost_e2e(self): assert client.get_treatment_with_config('key_whitelist', 'other_feature_3') == ('on', None) assert client.get_treatment_with_config('any_other_key', 'other_feature_3') == ('off', None) + manager = factory.manager() + assert manager.split('my_feature').configs == { + 'on': '{"desc" : "this applies only to ON treatment"}', + 'off': '{"desc" : "this applies only to OFF and only for only_key. 
The rest will receive ON"}' + } + assert manager.split('other_feature').configs == {} + assert manager.split('other_feature_2').configs == {} + assert manager.split('other_feature_3').configs == {} + From e1fc0a23e4bea93d93cdfc003dfe6b0da2a94def Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Mon, 22 Apr 2019 16:19:18 -0300 Subject: [PATCH 16/38] more e2e tests --- splitio/client/client.py | 22 +- splitio/client/util.py | 2 - splitio/storage/inmemmory.py | 3 +- tests/client/test_client.py | 9 +- .../files/segmentEmployeesChanges.json | 10 + .../files/segmentHumanBeignsChanges.json | 10 + tests/integration/files/splitChanges.json | 321 +++++++++++ tests/integration/test_client_e2e.py | 536 ++++++++++++++++++ 8 files changed, 896 insertions(+), 17 deletions(-) create mode 100644 tests/integration/files/segmentEmployeesChanges.json create mode 100644 tests/integration/files/segmentHumanBeignsChanges.json create mode 100644 tests/integration/files/splitChanges.json create mode 100644 tests/integration/test_client_e2e.py diff --git a/splitio/client/client.py b/splitio/client/client.py index c14c4362..3ad6c716 100644 --- a/splitio/client/client.py +++ b/splitio/client/client.py @@ -236,20 +236,18 @@ def get_treatments_with_config(self, key, features, attributes=None): 'feature ' + feature + ' returning CONTROL.') treatments[feature] = CONTROL, None self._logger.debug('Error: ', exc_info=True) - import traceback - traceback.print_exc() continue - # Register impressions - try: - if bulk_impressions: - self._record_stats(bulk_impressions, start, self._METRIC_GET_TREATMENTS) - for impression in bulk_impressions: - self._send_impression_to_listener(impression, attributes) - except Exception: #pylint: disable=broad-except - self._logger.error('get_treatments: An exception when trying to store ' - 'impressions.') - self._logger.debug('Error: ', exc_info=True) + # Register impressions + try: + if bulk_impressions: + self._record_stats(bulk_impressions, start, self._METRIC_GET_TREATMENTS) + for impression in bulk_impressions: + self._send_impression_to_listener(impression, attributes) + except Exception: #pylint: disable=broad-except + self._logger.error('get_treatments: An exception when trying to store ' + 'impressions.') + self._logger.debug('Error: ', exc_info=True) return treatments diff --git a/splitio/client/util.py b/splitio/client/util.py index d0465f86..af51297f 100644 --- a/splitio/client/util.py +++ b/splitio/client/util.py @@ -11,8 +11,6 @@ ) - - def _get_ip(): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: diff --git a/splitio/storage/inmemmory.py b/splitio/storage/inmemmory.py index cfb240cb..e8e0eccb 100644 --- a/splitio/storage/inmemmory.py +++ b/splitio/storage/inmemmory.py @@ -241,6 +241,7 @@ def put(self, impressions): with self._lock: for impression in impressions: self._impressions.put(impression, False) + print self._impressions.qsize() return True except queue.Full: if self._queue_full_hook is not None and callable(self._queue_full_hook): @@ -308,7 +309,7 @@ def put(self, events): if self._queue_full_hook is not None and callable(self._queue_full_hook): self._queue_full_hook() self._logger.warning( - 'Impressions queue is full, failing to add more events. \n' + 'Events queue is full, failing to add more events. 
\n'
'Consider increasing parameter `eventsQueueSize` in configuration'
)
return False
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 0f0339ad..8e040856 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -1,6 +1,8 @@
"""SDK main client test module."""
#pylint: disable=no-self-use,protected-access
+import json
+import os
from splitio.client.client import Client
from splitio.client.factory import SplitFactory
from splitio.engine.evaluator import Evaluator
@@ -8,6 +10,9 @@
from splitio.models.events import Event
from splitio.storage import EventStorage, ImpressionStorage, SegmentStorage, SplitStorage, \
TelemetryStorage
+from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \
+ InMemoryImpressionStorage, InMemoryTelemetryStorage, InMemoryEventStorage
+from splitio.models import splits, segments
class ClientTests(object): #pylint: disable=too-few-public-methods
"""Split client test cases."""
@@ -197,7 +202,7 @@ def _raise(*_):
raise Exception('something')
client._evaluator.evaluate_treatment.side_effect = _raise
assert client.get_treatments('key', ['f1', 'f2']) == {'f1': 'control', 'f2': 'control'}
- assert len(telemetry_storage.inc_latency.mock_calls) == 2
+ assert len(telemetry_storage.inc_latency.mock_calls) == 1
def test_get_treatments_with_config(self, mocker):
"""Test get_treatment execution paths."""
@@ -265,7 +270,7 @@ def _raise(*_):
'f1': ('control', None),
'f2': ('control', None)
}
- assert len(telemetry_storage.inc_latency.mock_calls) == 2
+ assert len(telemetry_storage.inc_latency.mock_calls) == 1
def test_destroy(self, mocker):
"""Test that destroy/destroyed calls are forwarded to the factory."""
diff --git a/tests/integration/files/segmentEmployeesChanges.json b/tests/integration/files/segmentEmployeesChanges.json
new file mode 100644
index 00000000..de3affe1
--- /dev/null
+++ b/tests/integration/files/segmentEmployeesChanges.json
@@ -0,0 +1,10 @@
+{
+ "name": "employees",
+ "added": [
+ "employee_3",
+ "employee_1"
+ ],
+ "removed": [],
+ "since": -1,
+ "till": 1457474612832
+} \ No newline at end of file
diff --git a/tests/integration/files/segmentHumanBeignsChanges.json b/tests/integration/files/segmentHumanBeignsChanges.json
new file mode 100644
index 00000000..c17b3ec3
--- /dev/null
+++ b/tests/integration/files/segmentHumanBeignsChanges.json
@@ -0,0 +1,10 @@
+{
+ "name": "human_beigns",
+ "added": [
+ "user1",
+ "user3"
+ ],
+ "removed": [],
+ "since": -1,
+ "till": 1457102183278
+} \ No newline at end of file
diff --git a/tests/integration/files/splitChanges.json b/tests/integration/files/splitChanges.json
new file mode 100644
index 00000000..d5401c93
--- /dev/null
+++ b/tests/integration/files/splitChanges.json
@@ -0,0 +1,321 @@
+{
+ "splits": [
+ {
+ "orgId": null,
+ "environment": null,
+ "trafficTypeId": null,
+ "trafficTypeName": null,
+ "name": "whitelist_feature",
+ "seed": -1222652054,
+ "status": "ACTIVE",
+ "killed": false,
+ "changeNumber": 123,
+ "defaultTreatment": "off",
+ "conditions": [
+ {
+ "matcherGroup": {
+ "combiner": "AND",
+ "matchers": [
+ {
+ "matcherType": "WHITELIST",
+ "negate": false,
+ "userDefinedSegmentMatcherData": null,
+ "whitelistMatcherData": {
+ "whitelist": [
+ "whitelisted_user"
+ ]
+ }
+ }
+ ]
+ },
+ "partitions": [
+ {
+ "treatment": "on",
+ "size": 100
+ }
+ ]
+ },
+ {
+ "matcherGroup": {
+ "combiner": "AND",
+ "matchers": [
+ {
+ "matcherType": "ALL_KEYS",
+ "negate": false,
+ "userDefinedSegmentMatcherData": null,
"whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "all_feature", + "seed": 1699838640, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "killed_feature", + "seed": -480091424, + "status": "ACTIVE", + "killed": true, + "changeNumber": 123, + "defaultTreatment": "defTreatment", + "configurations": { + "off": "{\"size\":15,\"test\":20}", + "defTreatment": "{\"size\":15,\"defTreatment\":true}" + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "ALL_KEYS", + "negate": false, + "userDefinedSegmentMatcherData": null, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "defTreatment", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "sample_feature", + "seed": 1548363147, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "configurations": { + "on": "{\"size\":15,\"test\":20}" + }, + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "employees" + }, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + } + ] + }, + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SEGMENT", + "negate": false, + "userDefinedSegmentMatcherData": { + "segmentName": "human_beigns" + }, + "whitelistMatcherData": null + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 30 + }, + { + "treatment": "off", + "size": 70 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "dependency_test", + "seed": 1222652054, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "IN_SPLIT_TREATMENT", + "negate": false, + "userDefinedSegmentMatcherData": null, + "dependencyMatcherData": { + "split": "all_feature", + "treatments": ["on"] + } + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 0 + }, + { + "treatment": "off", + "size": 100 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "regex_test", + "seed": 1222652051, + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "MATCHES_STRING", + "negate": false, + "userDefinedSegmentMatcherData": null, + "stringMatcherData": "abc[0-9]" + } + ] + }, + "partitions": [ + { + 
"treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + }, + { + "orgId": null, + "environment": null, + "trafficTypeId": null, + "trafficTypeName": null, + "name": "boolean_test", + "status": "ACTIVE", + "killed": false, + "changeNumber": 123, + "seed": 12321809, + "defaultTreatment": "off", + "conditions": [ + { + "matcherGroup": { + "combiner": "AND", + "matchers": [ + { + "matcherType": "EQUAL_TO_BOOLEAN", + "negate": false, + "userDefinedSegmentMatcherData": null, + "booleanMatcherData": true + } + ] + }, + "partitions": [ + { + "treatment": "on", + "size": 100 + }, + { + "treatment": "off", + "size": 0 + } + ] + } + ] + } + ], + "since": -1, + "till": 1457726098069 +} diff --git a/tests/integration/test_client_e2e.py b/tests/integration/test_client_e2e.py new file mode 100644 index 00000000..0aa5b505 --- /dev/null +++ b/tests/integration/test_client_e2e.py @@ -0,0 +1,536 @@ +"""Client integration tests.""" +#pylint: disable=protected-access,line-too-long,no-self-use +import json +import os + +from redis import StrictRedis + +from splitio.client.factory import SplitFactory +from splitio.client.util import SdkMetadata +from splitio.storage.inmemmory import InMemoryEventStorage, InMemoryImpressionStorage, \ + InMemorySegmentStorage, InMemorySplitStorage, InMemoryTelemetryStorage +from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, \ + RedisSplitStorage, RedisSegmentStorage, RedisTelemetryStorage +from splitio.models import splits, segments + +class InMemoryIntegrationTests(object): + """Inmemory storage-based integration tests.""" + + def setup_method(self): + """Prepare storages with test data.""" + split_storage = InMemorySplitStorage() + segment_storage = InMemorySegmentStorage() + + split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json') + with open(split_fn, 'r') as flo: + data = json.loads(flo.read()) + for split in data['splits']: + split_storage.put(splits.from_raw(split)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + segment_storage.put(segments.from_raw(data)) + + segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json') + with open(segment_fn, 'r') as flo: + data = json.loads(flo.read()) + segment_storage.put(segments.from_raw(data)) + + self.factory = SplitFactory({ #pylint:disable=attribute-defined-outside-init + 'splits': split_storage, + 'segments': segment_storage, + 'impressions': InMemoryImpressionStorage(5000), + 'events': InMemoryEventStorage(5000), + 'telemetry': InMemoryTelemetryStorage() + }, True) + + def _validate_last_impressions(self, client, *to_validate): + """Validate the last N impressions are present disregarding the order.""" + imp_storage = client._factory._get_storage('impressions') + impressions = imp_storage.pop_many(len(to_validate)) + as_tup_set = set((i.feature_name, i.matching_key, i.treatment) for i in impressions) + assert as_tup_set == set(to_validate) + + def test_get_treatment(self): + """Test client.get_treatment().""" + client = self.factory.client() + + assert client.get_treatment('user1', 'sample_feature') == 'on' + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + assert client.get_treatment('invalidKey', 'sample_feature') == 'off' + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + assert client.get_treatment('invalidKey', 
'invalid_feature') == 'control' + self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. No matter what the key, must return default treatment + assert client.get_treatment('invalidKey', 'killed_feature') == 'defTreatment' + self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + assert client.get_treatment('invalidKey', 'all_feature') == 'on' + self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # testing WHITELIST matcher + assert client.get_treatment('whitelisted_user', 'whitelist_feature') == 'on' + self._validate_last_impressions(client, ('whitelist_feature', 'whitelisted_user', 'on')) + assert client.get_treatment('unwhitelisted_user', 'whitelist_feature') == 'off' + self._validate_last_impressions(client, ('whitelist_feature', 'unwhitelisted_user', 'off')) + + # testing INVALID matcher + assert client.get_treatment('some_user_key', 'invalid_matcher_feature') == 'control' + self._validate_last_impressions(client, ('invalid_matcher_feature', 'some_user_key', 'control')) + + # testing Dependency matcher + assert client.get_treatment('somekey', 'dependency_test') == 'off' + self._validate_last_impressions(client, ('dependency_test', 'somekey', 'off')) + + # testing boolean matcher + assert client.get_treatment('True', 'boolean_test') == 'on' + self._validate_last_impressions(client, ('boolean_test', 'True', 'on')) + + # testing regex matcher + assert client.get_treatment('abc4', 'regex_test') == 'on' + self._validate_last_impressions(client, ('regex_test', 'abc4', 'on')) + + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + client = self.factory.client() + + result = client.get_treatment_with_config('user1', 'sample_feature') + assert result == ('on', '{"size":15,"test":20}') + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatment_with_config('invalidKey', 'sample_feature') + assert result == ('off', None) + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatment_with_config('invalidKey', 'invalid_feature') + assert result == ('control', None) + self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. 
No matter what the key, must return default treatment + result = client.get_treatment_with_config('invalidKey', 'killed_feature') + assert ('defTreatment', '{"size":15,"defTreatment":true}') == result + self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatment_with_config('invalidKey', 'all_feature') + assert result == ('on', None) + self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + def test_get_treatments(self): + """Test client.get_treatments().""" + client = self.factory.client() + + result = client.get_treatments('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'on' + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatments('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'off' + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatments('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == 'control' + self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments('invalidKey', ['killed_feature']) + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments('invalidKey', ['all_feature']) + assert len(result) == 1 + assert result['all_feature'] == 'on' + self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # testing multiple splitNames + result = client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + self._validate_last_impressions( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('invalid_feature', 'invalidKey', 'control'), + ('sample_feature', 'invalidKey', 'off') + ) + + def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + client = self.factory.client() + + result = client.get_treatments_with_config('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('on', '{"size":15,"test":20}') + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatments_with_config('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('off', None) + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatments_with_config('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == ('control', None) + self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. 
No matter what the key, must return default treatment
+ result = client.get_treatments_with_config('invalidKey', ['killed_feature'])
+ assert len(result) == 1
+ assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}')
+ self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment'))
+
+ # testing ALL matcher
+ result = client.get_treatments_with_config('invalidKey', ['all_feature'])
+ assert len(result) == 1
+ assert result['all_feature'] == ('on', None)
+ self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on'))
+
+ # testing multiple splitNames
+ result = client.get_treatments_with_config('invalidKey', [
+ 'all_feature',
+ 'killed_feature',
+ 'invalid_feature',
+ 'sample_feature'
+ ])
+ assert len(result) == 4
+ assert result['all_feature'] == ('on', None)
+ assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}')
+ assert result['invalid_feature'] == ('control', None)
+ assert result['sample_feature'] == ('off', None)
+ self._validate_last_impressions(
+ client,
+ ('all_feature', 'invalidKey', 'on'),
+ ('killed_feature', 'invalidKey', 'defTreatment'),
+ ('invalid_feature', 'invalidKey', 'control'),
+ ('sample_feature', 'invalidKey', 'off'),
+ )
+
+ def test_manager_methods(self):
+ """Test manager.split/splits."""
+ manager = self.factory.manager()
+ result = manager.split('all_feature')
+ assert result.name == 'all_feature'
+ assert result.traffic_type is None
+ assert result.killed is False
+ assert len(result.treatments) == 2
+ assert result.change_number == 123
+ assert result.configs == {}
+
+ result = manager.split('killed_feature')
+ assert result.name == 'killed_feature'
+ assert result.traffic_type is None
+ assert result.killed is True
+ assert len(result.treatments) == 2
+ assert result.change_number == 123
+ assert result.configs['defTreatment'] == '{"size":15,"defTreatment":true}'
+ assert result.configs['off'] == '{"size":15,"test":20}'
+
+ result = manager.split('sample_feature')
+ assert result.name == 'sample_feature'
+ assert result.traffic_type is None
+ assert result.killed is False
+ assert len(result.treatments) == 2
+ assert result.change_number == 123
+ assert result.configs['on'] == '{"size":15,"test":20}'
+
+ assert len(manager.split_names()) == 7
+ assert len(manager.splits()) == 7
+
+
+class RedisIntegrationTests(object):
+ """Redis storage-based integration tests."""
+
+ def setup_method(self):
+ """Prepare storages with test data."""
+ metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name')
+ redis_client = StrictRedis()
+ split_storage = RedisSplitStorage(redis_client)
+ segment_storage = RedisSegmentStorage(redis_client)
+
+ split_fn = os.path.join(os.path.dirname(__file__), 'files', 'splitChanges.json')
+ with open(split_fn, 'r') as flo:
+ data = json.loads(flo.read())
+ for split in data['splits']:
+ redis_client.set(split_storage._get_key(split['name']), json.dumps(split))
+ redis_client.set(split_storage._SPLIT_TILL_KEY, data['till'])
+
+ segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentEmployeesChanges.json')
+ with open(segment_fn, 'r') as flo:
+ data = json.loads(flo.read())
+ redis_client.sadd(segment_storage._get_key(data['name']), *data['added'])
+ redis_client.set(segment_storage._get_till_key(data['name']), data['till'])
+
+ segment_fn = os.path.join(os.path.dirname(__file__), 'files', 'segmentHumanBeignsChanges.json')
+ with open(segment_fn, 'r') as flo:
+ data = json.loads(flo.read())
+
redis_client.sadd(segment_storage._get_key(data['name']), *data['added']) + redis_client.set(segment_storage._get_till_key(data['name']), data['till']) + + self.factory = SplitFactory({ #pylint:disable=attribute-defined-outside-init + 'splits': split_storage, + 'segments': segment_storage, + 'impressions': RedisImpressionsStorage(redis_client, metadata), + 'events': RedisEventsStorage(redis_client, metadata), + 'telemetry': RedisTelemetryStorage(redis_client, metadata) + }, True) + + def _validate_last_impressions(self, client, *to_validate): + """Validate the last N impressions are present disregarding the order.""" + imp_storage = client._factory._get_storage('impressions') + redis_client = imp_storage._redis + impressions_raw = [ + json.loads(redis_client.lpop(imp_storage.IMPRESSIONS_QUEUE_KEY)) + for _ in to_validate + ] + as_tup_set = set( + (i['i']['f'], i['i']['k'], i['i']['t']) + for i in impressions_raw + ) + + print set(to_validate) + print as_tup_set + assert as_tup_set == set(to_validate) + + def test_get_treatment(self): + """Test client.get_treatment().""" + client = self.factory.client() + + assert client.get_treatment('user1', 'sample_feature') == 'on' + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + assert client.get_treatment('invalidKey', 'sample_feature') == 'off' + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + assert client.get_treatment('invalidKey', 'invalid_feature') == 'control' + self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. No matter what the key, must return default treatment + assert client.get_treatment('invalidKey', 'killed_feature') == 'defTreatment' + self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + assert client.get_treatment('invalidKey', 'all_feature') == 'on' + self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # testing WHITELIST matcher + assert client.get_treatment('whitelisted_user', 'whitelist_feature') == 'on' + self._validate_last_impressions(client, ('whitelist_feature', 'whitelisted_user', 'on')) + assert client.get_treatment('unwhitelisted_user', 'whitelist_feature') == 'off' + self._validate_last_impressions(client, ('whitelist_feature', 'unwhitelisted_user', 'off')) + + # testing INVALID matcher + assert client.get_treatment('some_user_key', 'invalid_matcher_feature') == 'control' + self._validate_last_impressions(client, ('invalid_matcher_feature', 'some_user_key', 'control')) + + # testing Dependency matcher + assert client.get_treatment('somekey', 'dependency_test') == 'off' + self._validate_last_impressions(client, ('dependency_test', 'somekey', 'off')) + + # testing boolean matcher + assert client.get_treatment('True', 'boolean_test') == 'on' + self._validate_last_impressions(client, ('boolean_test', 'True', 'on')) + + # testing regex matcher + assert client.get_treatment('abc4', 'regex_test') == 'on' + self._validate_last_impressions(client, ('regex_test', 'abc4', 'on')) + + def test_get_treatment_with_config(self): + """Test client.get_treatment_with_config().""" + client = self.factory.client() + + result = client.get_treatment_with_config('user1', 'sample_feature') + assert result == ('on', '{"size":15,"test":20}') + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatment_with_config('invalidKey', 'sample_feature') + assert result == ('off', 
None) + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatment_with_config('invalidKey', 'invalid_feature') + assert result == ('control', None) + self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatment_with_config('invalidKey', 'killed_feature') + assert ('defTreatment', '{"size":15,"defTreatment":true}') == result + self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatment_with_config('invalidKey', 'all_feature') + assert result == ('on', None) + self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + def test_get_treatments(self): + """Test client.get_treatments().""" + client = self.factory.client() + + result = client.get_treatments('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'on' + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatments('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == 'off' + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatments('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == 'control' + self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments('invalidKey', ['killed_feature']) + assert len(result) == 1 + assert result['killed_feature'] == 'defTreatment' + self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments('invalidKey', ['all_feature']) + assert len(result) == 1 + assert result['all_feature'] == 'on' + self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # testing multiple splitNames + result = client.get_treatments('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == 'on' + assert result['killed_feature'] == 'defTreatment' + assert result['invalid_feature'] == 'control' + assert result['sample_feature'] == 'off' + self._validate_last_impressions( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('invalid_feature', 'invalidKey', 'control'), + ('sample_feature', 'invalidKey', 'off') + ) + + def test_get_treatments_with_config(self): + """Test client.get_treatments_with_config().""" + client = self.factory.client() + + result = client.get_treatments_with_config('user1', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('on', '{"size":15,"test":20}') + self._validate_last_impressions(client, ('sample_feature', 'user1', 'on')) + + result = client.get_treatments_with_config('invalidKey', ['sample_feature']) + assert len(result) == 1 + assert result['sample_feature'] == ('off', None) + self._validate_last_impressions(client, ('sample_feature', 'invalidKey', 'off')) + + result = client.get_treatments_with_config('invalidKey', ['invalid_feature']) + assert len(result) == 1 + assert result['invalid_feature'] == ('control', None) + 
self._validate_last_impressions(client, ('invalid_feature', 'invalidKey', 'control')) + + # testing a killed feature. No matter what the key, must return default treatment + result = client.get_treatments_with_config('invalidKey', ['killed_feature']) + assert len(result) == 1 + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + self._validate_last_impressions(client, ('killed_feature', 'invalidKey', 'defTreatment')) + + # testing ALL matcher + result = client.get_treatments_with_config('invalidKey', ['all_feature']) + assert len(result) == 1 + assert result['all_feature'] == ('on', None) + self._validate_last_impressions(client, ('all_feature', 'invalidKey', 'on')) + + # testing multiple splitNames + result = client.get_treatments_with_config('invalidKey', [ + 'all_feature', + 'killed_feature', + 'invalid_feature', + 'sample_feature' + ]) + assert len(result) == 4 + assert result['all_feature'] == ('on', None) + assert result['killed_feature'] == ('defTreatment', '{"size":15,"defTreatment":true}') + assert result['invalid_feature'] == ('control', None) + assert result['sample_feature'] == ('off', None) + self._validate_last_impressions( + client, + ('all_feature', 'invalidKey', 'on'), + ('killed_feature', 'invalidKey', 'defTreatment'), + ('invalid_feature', 'invalidKey', 'control'), + ('sample_feature', 'invalidKey', 'off'), + ) + + def test_manager_methods(self): + """Test manager.split/splits.""" + manager = self.factory.manager() + result = manager.split('all_feature') + assert result.name == 'all_feature' + assert result.traffic_type is None + assert result.killed is False + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs == {} + + result = manager.split('killed_feature') + assert result.name == 'killed_feature' + assert result.traffic_type is None + assert result.killed is True + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs['defTreatment'] == '{"size":15,"defTreatment":true}' + assert result.configs['off'] == '{"size":15,"test":20}' + + result = manager.split('sample_feature') + assert result.name == 'sample_feature' + assert result.traffic_type is None + assert result.killed is False + assert len(result.treatments) == 2 + assert result.change_number == 123 + assert result.configs['on'] == '{"size":15,"test":20}' + + assert len(manager.split_names()) == 7 + assert len(manager.splits()) == 7 + + def teardown_method(self): + """Clear redis cache.""" + keys_to_delete = [ + "SPLITIO/python-1.2.3/some_ip/latency.sdk.getTreatment.bucket.0", + "SPLITIO.segment.human_beigns", + "SPLITIO.segment.employees.till", + "SPLITIO.split.sample_feature", + "SPLITIO.splits.till", + "SPLITIO.split.killed_feature", + "SPLITIO.split.all_feature", + "SPLITIO.split.whitelist_feature", + "SPLITIO.segment.employees", + "SPLITIO/python-1.2.3/some_ip/latency.sdk.getTreatments.bucket.0", + "SPLITIO.split.regex_test", + "SPLITIO.segment.human_beigns.till", + "SPLITIO.split.boolean_test", + "SPLITIO.split.dependency_test" + ] + + redis_client = StrictRedis() + for key in keys_to_delete: + redis_client.delete(key) + From 56e6bcad853a2a4b3d7ff581012d1ccf31ca1533 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Mon, 22 Apr 2019 16:21:56 -0300 Subject: [PATCH 17/38] remove unnecessary print --- splitio/storage/inmemmory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/splitio/storage/inmemmory.py b/splitio/storage/inmemmory.py index e8e0eccb..8a6012ab 100644 --- 
a/splitio/storage/inmemmory.py +++ b/splitio/storage/inmemmory.py @@ -241,7 +241,6 @@ def put(self, impressions): with self._lock: for impression in impressions: self._impressions.put(impression, False) - print self._impressions.qsize() return True except queue.Full: if self._queue_full_hook is not None and callable(self._queue_full_hook): From 0a45113762cc182cbff2a422934557b151afcafa Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Mon, 22 Apr 2019 16:57:36 -0300 Subject: [PATCH 18/38] update e2e tests --- tests/client/test_localhost.py | 28 ------------------------- tests/integration/files/file2.yaml | 18 ++++++++++++++++ tests/integration/test_client_e2e.py | 31 +++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 29 deletions(-) create mode 100644 tests/integration/files/file2.yaml diff --git a/tests/client/test_localhost.py b/tests/client/test_localhost.py index ea10c783..5da30500 100644 --- a/tests/client/test_localhost.py +++ b/tests/client/test_localhost.py @@ -4,7 +4,6 @@ import os import tempfile -from splitio.client.factory import get_factory from splitio.client import localhost from splitio.models.splits import Split from splitio.models.grammar.matchers import AllKeysMatcher @@ -195,30 +194,3 @@ def test_update_splits(self, mocker): task._update_splits() assert parse_legacy.mock_calls == [mocker.call('yaml')] assert parse_yaml.mock_calls == [] - - def test_localhost_e2e(self): - """Instantiate a client with a YAML file and issue get_treatment() calls.""" - filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml') - factory = get_factory('localhost', config={'splitFile': filename}) - client = factory.client() - assert client.get_treatment_with_config('key', 'my_feature') == ('on', '{"desc" : "this applies only to ON treatment"}') - assert client.get_treatment_with_config('only_key', 'my_feature') == ( - 'off', '{"desc" : "this applies only to OFF and only for only_key. The rest will receive ON"}' - ) - assert client.get_treatment_with_config('another_key', 'my_feature') == ('control', None) - assert client.get_treatment_with_config('key2', 'other_feature') == ('on', None) - assert client.get_treatment_with_config('key3', 'other_feature') == ('on', None) - assert client.get_treatment_with_config('some_key', 'other_feature_2') == ('on', None) - assert client.get_treatment_with_config('key_whitelist', 'other_feature_3') == ('on', None) - assert client.get_treatment_with_config('any_other_key', 'other_feature_3') == ('off', None) - - manager = factory.manager() - assert manager.split('my_feature').configs == { - 'on': '{"desc" : "this applies only to ON treatment"}', - 'off': '{"desc" : "this applies only to OFF and only for only_key. The rest will receive ON"}' - } - assert manager.split('other_feature').configs == {} - assert manager.split('other_feature_2').configs == {} - assert manager.split('other_feature_3').configs == {} - - diff --git a/tests/integration/files/file2.yaml b/tests/integration/files/file2.yaml new file mode 100644 index 00000000..bc9b7705 --- /dev/null +++ b/tests/integration/files/file2.yaml @@ -0,0 +1,18 @@ +- my_feature: + treatment: "on" + keys: "key" + config: "{\"desc\" : \"this applies only to ON treatment\"}" +- other_feature_3: + treatment: "off" +- my_feature: + treatment: "off" + keys: "only_key" + config: "{\"desc\" : \"this applies only to OFF and only for only_key. 
The rest will receive ON\"}" +- other_feature_3: + treatment: "on" + keys: "key_whitelist" +- other_feature: + treatment: "on" + keys: ["key2","key3"] +- other_feature_2: + treatment: "on" diff --git a/tests/integration/test_client_e2e.py b/tests/integration/test_client_e2e.py index 0aa5b505..048e31c0 100644 --- a/tests/integration/test_client_e2e.py +++ b/tests/integration/test_client_e2e.py @@ -5,7 +5,7 @@ from redis import StrictRedis -from splitio.client.factory import SplitFactory +from splitio.client.factory import get_factory, SplitFactory from splitio.client.util import SdkMetadata from splitio.storage.inmemmory import InMemoryEventStorage, InMemoryImpressionStorage, \ InMemorySegmentStorage, InMemorySplitStorage, InMemoryTelemetryStorage @@ -13,6 +13,7 @@ RedisSplitStorage, RedisSegmentStorage, RedisTelemetryStorage from splitio.models import splits, segments + class InMemoryIntegrationTests(object): """Inmemory storage-based integration tests.""" @@ -534,3 +535,31 @@ def teardown_method(self): for key in keys_to_delete: redis_client.delete(key) + +class LocalhostIntegrationTests(object): + """Client & Manager integration tests.""" + + def test_localhost_e2e(self): + """Instantiate a client with a YAML file and issue get_treatment() calls.""" + filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml') + factory = get_factory('localhost', config={'splitFile': filename}) + client = factory.client() + assert client.get_treatment_with_config('key', 'my_feature') == ('on', '{"desc" : "this applies only to ON treatment"}') + assert client.get_treatment_with_config('only_key', 'my_feature') == ( + 'off', '{"desc" : "this applies only to OFF and only for only_key. The rest will receive ON"}' + ) + assert client.get_treatment_with_config('another_key', 'my_feature') == ('control', None) + assert client.get_treatment_with_config('key2', 'other_feature') == ('on', None) + assert client.get_treatment_with_config('key3', 'other_feature') == ('on', None) + assert client.get_treatment_with_config('some_key', 'other_feature_2') == ('on', None) + assert client.get_treatment_with_config('key_whitelist', 'other_feature_3') == ('on', None) + assert client.get_treatment_with_config('any_other_key', 'other_feature_3') == ('off', None) + + manager = factory.manager() + assert manager.split('my_feature').configs == { + 'on': '{"desc" : "this applies only to ON treatment"}', + 'off': '{"desc" : "this applies only to OFF and only for only_key. 
The rest will receive ON"}' + } + assert manager.split('other_feature').configs == {} + assert manager.split('other_feature_2').configs == {} + assert manager.split('other_feature_3').configs == {} From 527b0f35535f91acf2b9920301bcd6ddc543879a Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Mon, 22 Apr 2019 22:57:04 -0300 Subject: [PATCH 19/38] fix tests in py3 --- splitio/client/config.py | 2 ++ splitio/client/factory.py | 2 +- splitio/client/util.py | 33 ++++++++++++++------- splitio/storage/adapters/redis.py | 7 +++++ splitio/storage/redis.py | 3 +- tests/client/files/file1.split | 16 ++-------- tests/client/test_localhost.py | 21 ++++++------- tests/client/test_utils.py | 32 ++++++++++++++++++++ tests/integration/test_client_e2e.py | 7 ++--- tests/integration/test_redis_integration.py | 10 +++---- tests/storage/test_redis.py | 10 +++---- 11 files changed, 90 insertions(+), 53 deletions(-) create mode 100644 tests/client/test_utils.py diff --git a/splitio/client/config.py b/splitio/client/config.py index d91b4f39..e2bd9ded 100644 --- a/splitio/client/config.py +++ b/splitio/client/config.py @@ -39,5 +39,7 @@ 'redisSslCertReqs': None, 'redisSslCaCerts': None, 'redisMaxConnections': None, + 'machineName': None, + 'machineIp': None, 'splitFile': os.path.join(os.path.expanduser('~'), '.split') } diff --git a/splitio/client/factory.py b/splitio/client/factory.py index ef229828..7aeb89b6 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -312,7 +312,7 @@ def _build_redis_factory(config): """Build and return a split factory with redis-based storage.""" cfg = DEFAULT_CONFIG.copy() cfg.update(config) - sdk_metadata = util.get_metadata() + sdk_metadata = util.get_metadata({}) redis_adapter = redis.build(config) storages = { 'splits': RedisSplitStorage(redis_adapter), diff --git a/splitio/client/util.py b/splitio/client/util.py index af51297f..acf6c32d 100644 --- a/splitio/client/util.py +++ b/splitio/client/util.py @@ -28,11 +28,21 @@ def _get_hostname(ip_address): return 'unknown' if ip_address == 'unknown' else 'ip-' + ip_address.replace('.', '-') -def get_metadata(*args, **kwargs): - """Gather SDK metadata and return a tuple with such info.""" +def get_metadata(config): + """ + Gather SDK metadata and return a tuple with such info. + + :param config: User supplied config augmented with defaults. + :type config: dict + + :return: SDK Metadata information. + :rtype: SdkMetadata + """ version = 'python-%s' % __version__ - ip_address = _get_ip() - hostname = _get_hostname(ip_address) + ip_from_config = config.get('machineIp') + machine_from_config = config.get('machineName') + ip_address = ip_from_config if ip_from_config is not None else _get_ip() + hostname = machine_from_config if machine_from_config is not None else _get_hostname(ip_address) return SdkMetadata(version, hostname, ip_address) @@ -46,9 +56,12 @@ def get_calls(classes_filter=None): :return: list of callers ordered by most recent first. 
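The override logic above is worth pinning down: explicit `machineIp`/`machineName` config values short-circuit IP and hostname detection entirely, and only when they are absent (or left at their None defaults) does the SDK probe the environment. A minimal sketch of the expected behaviour (values illustrative only; the test added later in this patch exercises exactly this):

    from splitio.client import util

    # Explicit values win; _get_ip()/_get_hostname() are never called.
    meta = util.get_metadata({'machineIp': '10.0.0.1', 'machineName': 'web-01'})
    assert meta.instance_ip == '10.0.0.1'
    assert meta.instance_name == 'web-01'

    # With no overrides, detection kicks in.
    meta = util.get_metadata({})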
:rtype: list(tuple(str, str)) """ - return [ - inspect.getframeinfo(frame[0]).function - for frame in inspect.stack() - if classes_filter is None - or 'self' in frame[0].f_locals and frame[0].f_locals['self'].__class__.__name__ in classes_filter #pylint: disable=line-too-long - ] + try: + return [ + inspect.getframeinfo(frame[0]).function + for frame in inspect.stack() + if classes_filter is None + or 'self' in frame[0].f_locals and frame[0].f_locals['self'].__class__.__name__ in classes_filter #pylint: disable=line-too-long + ] + except Exception: #pylint: disable=broad-except + return [] diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py index 35d3547f..2b0de584 100644 --- a/splitio/storage/adapters/redis.py +++ b/splitio/storage/adapters/redis.py @@ -279,6 +279,13 @@ def ttl(self, key): except RedisError as exc: raise_from(RedisAdapterException('Error executing ttl operation'), exc) + def lpop(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return self._decorated.lpop(self._add_prefix(key)) + except RedisError as exc: + raise_from(RedisAdapterException('Error executing lpop operation'), exc) + def _build_default_client(config): #pylint: disable=too-many-locals """ diff --git a/splitio/storage/redis.py b/splitio/storage/redis.py index d62b1146..39384ee6 100644 --- a/splitio/storage/redis.py +++ b/splitio/storage/redis.py @@ -110,6 +110,7 @@ def get_split_names(self): """ try: keys = self._redis.keys(self._get_key('*')) + for key in keys: return [key.replace(self._get_key(''), '') for key in keys] except RedisAdapterException: self._logger.error('Error fetching split names from storage') @@ -129,7 +130,6 @@ def get_all_splits(self): raw_splits = self._redis.mget(keys) for raw in raw_splits: try: - print(raw, type(raw)) to_return.append(splits.from_raw(json.loads(raw))) except ValueError: self._logger.error('Could not parse split. 
Skipping') @@ -224,7 +224,6 @@ def get_change_number(self, segment_name): """ try: stored_value = self._redis.get(self._get_till_key(segment_name)) - print('aaa', stored_value) return json.loads(stored_value) if stored_value is not None else None except RedisAdapterException: self._logger.error('Error fetching segment change number from storage') diff --git a/tests/client/files/file1.split b/tests/client/files/file1.split index 10af3644..064c9d79 100644 --- a/tests/client/files/file1.split +++ b/tests/client/files/file1.split @@ -1,14 +1,2 @@ -events_write_es on -events_routing sqs -impressions_routing sqs -workspaces_v1 on -create_org_with_workspace on -sqs_events_processing on -sqs_impressions_processing on -sqs_events_fetch on -sqs_impressions_fetch off -sqs_impressions_fetch_period 700 -sqs_impressions_fetch_threads 10 -sqs_events_fetch_period 500 -sqs_events_fetch_threads 5 - +split1 on +split2 off diff --git a/tests/client/test_localhost.py b/tests/client/test_localhost.py index 5da30500..7c1a42bf 100644 --- a/tests/client/test_localhost.py +++ b/tests/client/test_localhost.py @@ -105,18 +105,15 @@ def test_make_whitelist_condition(self): def test_parse_legacy_file(self): """Test that parsing a legacy file works.""" - with tempfile.NamedTemporaryFile() as temp_flo: - temp_flo.write('split1 on\n') - temp_flo.write('split2 off\n') - temp_flo.flush() - splits = localhost.LocalhostSplitSynchronizationTask._read_splits_from_legacy_file(temp_flo.name) - assert len(splits) == 2 - for split in splits.values(): - assert isinstance(split, Split) - assert splits['split1'].name == 'split1' - assert splits['split2'].name == 'split2' - assert isinstance(splits['split1'].conditions[0].matchers[0], AllKeysMatcher) - assert isinstance(splits['split2'].conditions[0].matchers[0], AllKeysMatcher) + filename = os.path.join(os.path.dirname(__file__), 'files', 'file1.split') + splits = localhost.LocalhostSplitSynchronizationTask._read_splits_from_legacy_file(filename) + assert len(splits) == 2 + for split in splits.values(): + assert isinstance(split, Split) + assert splits['split1'].name == 'split1' + assert splits['split2'].name == 'split2' + assert isinstance(splits['split1'].conditions[0].matchers[0], AllKeysMatcher) + assert isinstance(splits['split2'].conditions[0].matchers[0], AllKeysMatcher) def test_parse_yaml_file(self): """Test that parsing a yaml file works.""" diff --git a/tests/client/test_utils.py b/tests/client/test_utils.py new file mode 100644 index 00000000..7504484f --- /dev/null +++ b/tests/client/test_utils.py @@ -0,0 +1,32 @@ +"""Split client utilities test module.""" +#pylint: disable=no-self-use,too-few-public-methods + +from splitio.client import util, config +from splitio.version import __version__ + +class ClientUtilsTests(object): + """Client utilities test cases.""" + + def test_get_metadata(self, mocker): + """Test the get_metadata function.""" + get_ip_mock = mocker.Mock() + get_host_mock = mocker.Mock() + mocker.patch('splitio.client.util._get_ip', new=get_ip_mock) + mocker.patch('splitio.client.util._get_hostname', new=get_host_mock) + + meta = util.get_metadata({'machineIp': 'some_ip', 'machineName': 'some_machine_name'}) + assert get_ip_mock.mock_calls == [] + assert get_host_mock.mock_calls == [] + assert meta.instance_ip == 'some_ip' + assert meta.instance_name == 'some_machine_name' + assert meta.sdk_version == 'python-' + __version__ + + meta = util.get_metadata(config.DEFAULT_CONFIG) + assert get_ip_mock.mock_calls == [mocker.call()] + assert
get_host_mock.mock_calls == [mocker.call(mocker.ANY)] + + get_ip_mock.reset_mock() + get_host_mock.reset_mock() + meta = util.get_metadata({}) + assert get_ip_mock.mock_calls == [mocker.call()] + assert get_host_mock.mock_calls == [mocker.call(mocker.ANY)] diff --git a/tests/integration/test_client_e2e.py b/tests/integration/test_client_e2e.py index 048e31c0..eb8dc1c6 100644 --- a/tests/integration/test_client_e2e.py +++ b/tests/integration/test_client_e2e.py @@ -11,6 +11,7 @@ InMemorySegmentStorage, InMemorySplitStorage, InMemoryTelemetryStorage from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, \ RedisSplitStorage, RedisSegmentStorage, RedisTelemetryStorage +from splitio.storage.adapters.redis import RedisAdapter from splitio.models import splits, segments @@ -262,7 +263,7 @@ class RedisIntegrationTests(object): def setup_method(self): """Prepare storages with test data.""" metadata = SdkMetadata('python-1.2.3', 'some_ip', 'some_name') - redis_client = StrictRedis() + redis_client = RedisAdapter(StrictRedis()) split_storage = RedisSplitStorage(redis_client) segment_storage = RedisSegmentStorage(redis_client) @@ -306,8 +307,6 @@ def _validate_last_impressions(self, client, *to_validate): for i in impressions_raw ) - print set(to_validate) - print as_tup_set assert as_tup_set == set(to_validate) def test_get_treatment(self): @@ -531,7 +530,7 @@ def teardown_method(self): "SPLITIO.split.dependency_test" ] - redis_client = StrictRedis() + redis_client = RedisAdapter(StrictRedis()) for key in keys_to_delete: redis_client.delete(key) diff --git a/tests/integration/test_redis_integration.py b/tests/integration/test_redis_integration.py index 9ed4ea78..1bdfea38 100644 --- a/tests/integration/test_redis_integration.py +++ b/tests/integration/test_redis_integration.py @@ -1,11 +1,11 @@ """Redis storage end to end tests.""" -#pylint: disable=no-self-use,protected-access +#pylint: disable=no-self-use,protected-access,line-too-long,too-few-public-methods import json import os from splitio.client.util import get_metadata -from splitio.models import splits, segments, impressions, events, telemetry +from splitio.models import splits, impressions, events from splitio.storage.redis import RedisSplitStorage, RedisSegmentStorage, RedisImpressionsStorage, \ RedisEventsStorage, RedisTelemetryStorage from splitio.storage.adapters.redis import _build_default_client @@ -142,7 +142,7 @@ def test_put_fetch_contains(self): """Test storing and retrieving splits in redis.""" adapter = _build_default_client({}) try: - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisImpressionsStorage(adapter, metadata) storage.put([ impressions.Impression('key1', 'feature1', 'on', 'l1', 123456, 'b1', 321654), @@ -163,7 +163,7 @@ def test_put_fetch_contains(self): """Test storing and retrieving splits in redis.""" adapter = _build_default_client({}) try: - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisEventsStorage(adapter, metadata) storage.put([ events.Event('key1', 'user', 'purchase', 3.5, 123456), @@ -183,7 +183,7 @@ class TelemetryStorageTests(object): def test_put_fetch_contains(self): """Test storing and retrieving splits in redis.""" adapter = _build_default_client({}) - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisTelemetryStorage(adapter, metadata) try: diff --git a/tests/storage/test_redis.py b/tests/storage/test_redis.py index e7dfd50f..3b11485d 100644 --- a/tests/storage/test_redis.py +++ b/tests/storage/test_redis.py @@ 
-144,7 +144,7 @@ class RedisImpressionsStorageTests(object): #pylint: disable=too-few-public-met def test_add_impressions(self, mocker): """Test that adding impressions to storage works.""" adapter = mocker.Mock(spec=RedisAdapter) - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisImpressionsStorage(adapter, metadata) impressions = [ @@ -189,7 +189,7 @@ class RedisEventsStorageTests(object): #pylint: disable=too-few-public-methods def test_add_events(self, mocker): """Test that adding impressions to storage works.""" adapter = mocker.Mock(spec=RedisAdapter) - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisEventsStorage(adapter, metadata) @@ -238,7 +238,7 @@ class RedisTelemetryStorageTests(object): def test_inc_latency(self, mocker): """Test incrementing latency.""" adapter = mocker.Mock(spec=RedisAdapter) - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisTelemetryStorage(adapter, metadata) storage.inc_latency('some_latency', 0) @@ -256,7 +256,7 @@ def test_inc_latency(self, mocker): def test_inc_counter(self, mocker): """Test incrementing latency.""" adapter = mocker.Mock(spec=RedisAdapter) - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisTelemetryStorage(adapter, metadata) storage.inc_counter('some_counter_1') @@ -275,7 +275,7 @@ def test_inc_counter(self, mocker): def test_inc_gauge(self, mocker): """Test incrementing latency.""" adapter = mocker.Mock(spec=RedisAdapter) - metadata = get_metadata() + metadata = get_metadata({}) storage = RedisTelemetryStorage(adapter, metadata) storage.put_gauge('gauge1', 123) From cd33c26ab2ddb165ecf8b104a938ebcd33d1913b Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Mon, 22 Apr 2019 22:59:03 -0300 Subject: [PATCH 20/38] forward config to metadata builder --- splitio/client/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splitio/client/factory.py b/splitio/client/factory.py index 7aeb89b6..72dfa468 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -312,7 +312,7 @@ def _build_redis_factory(config): """Build and return a split factory with redis-based storage.""" cfg = DEFAULT_CONFIG.copy() cfg.update(config) - sdk_metadata = util.get_metadata({}) + sdk_metadata = util.get_metadata(config) redis_adapter = redis.build(config) storages = { 'splits': RedisSplitStorage(redis_adapter), From 72bd9b1f01c8bf8a3cc682e05f60bfc2f0708f8c Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 10:08:54 -0300 Subject: [PATCH 21/38] fix impression listener. 
move to utils.raise_from --- splitio/api/client.py | 2 +- splitio/api/events.py | 2 +- splitio/api/impressions.py | 2 +- splitio/api/segments.py | 2 +- splitio/api/splits.py | 2 +- splitio/api/telemetry.py | 9 ++++--- splitio/client/factory.py | 31 ++++++++++++++++++++--- splitio/client/listener.py | 11 +++++--- splitio/client/localhost.py | 2 +- splitio/factories.py | 2 ++ splitio/models/grammar/matchers/sets.py | 8 +++--- splitio/models/grammar/matchers/string.py | 6 ++--- splitio/storage/adapters/redis.py | 3 ++- splitio/storage/redis.py | 1 - splitio/tasks/uwsgi_wrappers.py | 23 ++++++++++------- tests/client/test_factory.py | 7 ++--- tests/client/test_input_validator.py | 2 +- 17 files changed, 77 insertions(+), 38 deletions(-) create mode 100644 splitio/factories.py diff --git a/splitio/api/client.py b/splitio/api/client.py index b322660b..fd3bb9b8 100644 --- a/splitio/api/client.py +++ b/splitio/api/client.py @@ -3,7 +3,7 @@ from collections import namedtuple -from six import raise_from +from future.utils import raise_from import requests HttpResponse = namedtuple('HttpResponse', ['status_code', 'body']) diff --git a/splitio/api/events.py b/splitio/api/events.py index 2c929634..59817d31 100644 --- a/splitio/api/events.py +++ b/splitio/api/events.py @@ -1,7 +1,7 @@ """Events API module.""" import logging -from six import raise_from +from future.utils import raise_from from splitio.api import APIException from splitio.api.client import HttpClientException diff --git a/splitio/api/impressions.py b/splitio/api/impressions.py index e6b86213..2c8ee54e 100644 --- a/splitio/api/impressions.py +++ b/splitio/api/impressions.py @@ -3,7 +3,7 @@ import logging from itertools import groupby -from six import raise_from +from future.utils import raise_from from splitio.api import APIException from splitio.api.client import HttpClientException diff --git a/splitio/api/segments.py b/splitio/api/segments.py index 7cce297b..82f4a65a 100644 --- a/splitio/api/segments.py +++ b/splitio/api/segments.py @@ -3,7 +3,7 @@ import json import logging -from six import raise_from +from future.utils import raise_from from splitio.api import APIException from splitio.api.client import HttpClientException diff --git a/splitio/api/splits.py b/splitio/api/splits.py index 9c3af15a..53ee0ae9 100644 --- a/splitio/api/splits.py +++ b/splitio/api/splits.py @@ -3,7 +3,7 @@ import logging import json -from six import raise_from +from future.utils import raise_from from splitio.api import APIException from splitio.api.client import HttpClientException diff --git a/splitio/api/telemetry.py b/splitio/api/telemetry.py index c333b11d..d9fa107f 100644 --- a/splitio/api/telemetry.py +++ b/splitio/api/telemetry.py @@ -1,6 +1,9 @@ """Telemetry API Module.""" import logging + import six +from future.utils import raise_from + from splitio.api import APIException from splitio.api.client import HttpClientException @@ -62,7 +65,7 @@ def flush_latencies(self, latencies): except HttpClientException as exc: self._logger.error('Http client is throwing exceptions') self._logger.debug('Error: ', exc_info=True) - six.raise_from(APIException('Latencies not flushed correctly.'), exc) + raise_from(APIException('Latencies not flushed correctly.'), exc) @staticmethod def _build_gauges(gauges): @@ -98,7 +101,7 @@ def flush_gauges(self, gauges): except HttpClientException as exc: self._logger.error('Http client is throwing exceptions') self._logger.debug('Error: ', exc_info=True) - six.raise_from(APIException('Gauges not flushed correctly.'), exc) + 
raise_from(APIException('Gauges not flushed correctly.'), exc) @staticmethod def _build_counters(counters): @@ -134,4 +137,4 @@ def flush_counters(self, counters): except HttpClientException as exc: self._logger.error('Http client is throwing exceptions') self._logger.debug('Error: ', exc_info=True) - six.raise_from(APIException('Counters not flushed correctly.'), exc) + raise_from(APIException('Counters not flushed correctly.'), exc) diff --git a/splitio/client/factory.py b/splitio/client/factory.py index 72dfa468..fd4a4c60 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -13,6 +13,7 @@ from splitio.client.manager import SplitManager from splitio.client.config import DEFAULT_CONFIG from splitio.client import util +from splitio.client.listener import ImpressionListenerWrapper #Storage from splitio.storage.inmemmory import InMemorySplitStorage, InMemorySegmentStorage, \ @@ -208,6 +209,20 @@ def destroyed(self): return self._status == Status.DESTROYED +def _wrap_impression_listener(listener, metadata): + """ + Wrap the impression listener if any. + + :param listener: User supplied impression listener or None + :type listener: splitio.client.listener.ImpressionListener | None + :param metadata: SDK Metadata + :type metadata: splitio.client.util.SdkMetadata + """ + if listener is not None: + return ImpressionListenerWrapper(listener, metadata) + return None + + def _build_in_memory_factory(api_key, config, sdk_url=None, events_url=None): #pylint: disable=too-many-locals """Build and return a split factory tailored to the supplied config.""" if not input_validator.validate_factory_instantiation(api_key): @@ -305,7 +320,14 @@ def segment_ready_task(): segment_completion_thread = threading.Thread(target=segment_ready_task) segment_completion_thread.setDaemon(True) segment_completion_thread.start() - return SplitFactory(storages, cfg['labelsEnabled'], apis, tasks, sdk_ready_flag) + return SplitFactory( + storages, + cfg['labelsEnabled'], + apis, + tasks, + sdk_ready_flag, + impression_listener=_wrap_impression_listener(cfg['impressionListener'], sdk_metadata) + ) def _build_redis_factory(config): @@ -324,7 +346,7 @@ def _build_redis_factory(config): return SplitFactory( storages, cfg['labelsEnabled'], - impression_listener=cfg['impressionListener'] + impression_listener=_wrap_impression_listener(cfg['impressionListener'], sdk_metadata) ) @@ -332,6 +354,7 @@ def _build_uwsgi_factory(config): """Build and return a split factory with redis-based storage.""" cfg = DEFAULT_CONFIG.copy() cfg.update(config) + sdk_metadata = util.get_metadata(cfg) uwsgi_adapter = get_uwsgi() storages = { 'splits': UWSGISplitStorage(uwsgi_adapter), @@ -343,7 +366,7 @@ def _build_uwsgi_factory(config): return SplitFactory( storages, cfg['labelsEnabled'], - impression_listener=cfg['impressionListener'] + impression_listener=_wrap_impression_listener(cfg['impressionListener'], sdk_metadata) ) @@ -380,7 +403,7 @@ def get_factory(api_key, **kwargs): if 'redisHost' in config: return _build_redis_factory(config) - if 'uwsgiCache' in config: + if 'uwsgiClient' in config: return _build_uwsgi_factory(config) return _build_in_memory_factory( diff --git a/splitio/client/listener.py b/splitio/client/listener.py index 260d1969..abe78d1a 100644 --- a/splitio/client/listener.py +++ b/splitio/client/listener.py @@ -3,6 +3,7 @@ import abc from six import add_metaclass +from future.utils import raise_from class ImpressionListenerException(Exception): @@ -49,9 +50,13 @@ def log_impression(self, impression, 
attributes=None): data['instance-id'] = self._metadata.instance_name try: self.impression_listener.log_impression(data) - except Exception: - raise ImpressionListenerException('Error in log_impression user\'s' - 'method is throwing exceptions') + except Exception as exc: #pylint: disable=broad-except + import traceback + traceback.print_exc() + raise_from( + ImpressionListenerException('Error in log_impression user\'s method is throwing exceptions'), + exc + ) @add_metaclass(abc.ABCMeta) #pylint: disable=too-few-public-methods class ImpressionListener(object): diff --git a/splitio/client/localhost.py b/splitio/client/localhost.py index 7ee2ea03..4e702223 100644 --- a/splitio/client/localhost.py +++ b/splitio/client/localhost.py @@ -4,7 +4,7 @@ import logging import re -from six import raise_from +from future.utils import raise_from import yaml from splitio.models import splits diff --git a/splitio/factories.py b/splitio/factories.py new file mode 100644 index 00000000..2616581f --- /dev/null +++ b/splitio/factories.py @@ -0,0 +1,2 @@ +"""Backwards compatibility module.""" +from splitio.client.factory import get_factory diff --git a/splitio/models/grammar/matchers/sets.py b/splitio/models/grammar/matchers/sets.py index 4d10c4bd..7c8dfa77 100644 --- a/splitio/models/grammar/matchers/sets.py +++ b/splitio/models/grammar/matchers/sets.py @@ -45,7 +45,7 @@ def _add_matcher_specific_properties_to_json(self): """Return ContainsAllOfSet specific properties.""" return { 'whitelistMatcherData': { - 'whitelist': self._whitelist + 'whitelist': list(self._whitelist) } } @@ -95,7 +95,7 @@ def _add_matcher_specific_properties_to_json(self): """Return ContainsAnyOfSet specific properties.""" return { 'whitelistMatcherData': { - 'whitelist': self._whitelist + 'whitelist': list(self._whitelist) } } @@ -145,7 +145,7 @@ def _add_matcher_specific_properties_to_json(self): """Return EqualToSet specific properties.""" return { 'whitelistMatcherData': { - 'whitelist': self._whitelist + 'whitelist': list(self._whitelist) } } @@ -196,7 +196,7 @@ def _add_matcher_specific_properties_to_json(self): """Return PartOfSet specific properties.""" return { 'whitelistMatcherData': { - 'whitelist': self._whitelist + 'whitelist': list(self._whitelist) } } diff --git a/splitio/models/grammar/matchers/string.py b/splitio/models/grammar/matchers/string.py index 9688aacc..bb75b02e 100644 --- a/splitio/models/grammar/matchers/string.py +++ b/splitio/models/grammar/matchers/string.py @@ -126,7 +126,7 @@ def _add_matcher_specific_properties_to_json(self): """Return StartsWith specific properties.""" return { 'whitelistMatcherData': { - 'whitelist': self._whitelist + 'whitelist': list(self._whitelist) } } @@ -174,7 +174,7 @@ def _add_matcher_specific_properties_to_json(self): """Return EndsWith specific properties.""" return { 'whitelistMatcherData': { - 'whitelist': self._whitelist + 'whitelist': list(self._whitelist) } } @@ -222,7 +222,7 @@ def _add_matcher_specific_properties_to_json(self): """Return ContainsString specific properties.""" return { 'whitelistMatcherData': { - 'whitelist': self._whitelist + 'whitelist': list(self._whitelist) } } diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py index 2b0de584..fc31490b 100644 --- a/splitio/storage/adapters/redis.py +++ b/splitio/storage/adapters/redis.py @@ -3,7 +3,8 @@ unicode_literals from builtins import str -from six import string_types, binary_type, raise_from +from six import string_types, binary_type +from future.utils import raise_from
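The `list(...)` conversions in the matcher hunks above exist because the whitelist is held as a set/frozenset internally (which is what makes the conversion necessary), and Python's json module refuses to serialize those types. A quick illustration of the failure mode and the fix:

    import json

    whitelist = frozenset(['key1', 'key2'])
    # json.dumps(whitelist) raises TypeError: frozenset is not JSON serializable
    payload = {'whitelistMatcherData': {'whitelist': list(whitelist)}}
    json.dumps(payload)  # fine once the set is converted to a plain list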
from splitio.exceptions import SentinelConfigurationException try: diff --git a/splitio/storage/redis.py b/splitio/storage/redis.py index 39384ee6..330df511 100644 --- a/splitio/storage/redis.py +++ b/splitio/storage/redis.py @@ -110,7 +110,6 @@ def get_split_names(self): """ try: keys = self._redis.keys(self._get_key('*')) - for key in keys: return [key.replace(self._get_key(''), '') for key in keys] except RedisAdapterException: self._logger.error('Error fetching split names from storage') diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py index 90c40746..c385a580 100644 --- a/splitio/tasks/uwsgi_wrappers.py +++ b/splitio/tasks/uwsgi_wrappers.py @@ -33,7 +33,7 @@ def _get_config(user_config): :return: Calculated configuration. :rtype: dict """ - sdk_config = DEFAULT_CONFIG + sdk_config = DEFAULT_CONFIG.copy() sdk_config.update(user_config) return sdk_config @@ -50,9 +50,9 @@ def uwsgi_update_splits(user_config): seconds = config['featuresRefreshRate'] split_sync_task = SplitSynchronizationTask( SplitsAPI( - HttpClient(config.get('sdk_url'), config.get('events_url')), config['apikey'] + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'] ), - UWSGISplitStorage(get_uwsgi), + UWSGISplitStorage(get_uwsgi()), None, # Time not needed since the task will be triggered manually. None # Ready flag not needed since it will never be set and consumed. ) @@ -62,6 +62,7 @@ time.sleep(seconds) except Exception: #pylint: disable=broad-except _LOGGER.error('Error updating splits') + _LOGGER.debug('Error: ', exc_info=True) def uwsgi_update_segments(user_config): @@ -76,7 +77,7 @@ seconds = config['segmentsRefreshRate'] segment_sync_task = SegmentSynchronizationTask( SegmentsAPI( - HttpClient(config.get('sdk_url'), config.get('events_url')), config['apikey'] + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'] ), UWSGISegmentStorage(get_uwsgi()), None, # Split storage not needed, segments provided manually, @@ -91,6 +92,7 @@ time.sleep(seconds) except Exception: #pylint: disable=broad-except _LOGGER.error('Error updating segments') + _LOGGER.debug('Error: ', exc_info=True) def uwsgi_report_impressions(user_config): @@ -107,13 +109,13 @@ storage = UWSGIImpressionStorage(get_uwsgi()) impressions_sync_task = ImpressionsSyncTask( ImpressionsAPI( - HttpClient(config.get('sdk_url'), config.get('events_url')), + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'], metadata ), storage, None, # Period not needed. Task is being triggered manually.
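Every HttpClient instantiation in these wrappers gains a leading 1500 argument; judging purely from the call sites, the refactored client takes the request timeout as its first constructor parameter (presumably milliseconds, though that is an assumption here, not something the patch states). A sketch mirroring the real usage, where unset URLs fall back to the client's defaults:

    from splitio.api.client import HttpClient

    # First positional argument: timeout. None for the sdk/events URLs
    # means "use the default endpoints", as config.get(...) does above.
    client = HttpClient(1500, None, None)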
- 5000 # TODO: Parametrize + config['eventsPushRate'] ) while True: task._send_events() #pylint: disable=protected-access @@ -157,6 +160,7 @@ def uwsgi_report_events(user_config): time.sleep(1) except Exception: #pylint: disable=broad-except _LOGGER.error('Error posting metrics') + _LOGGER.debug('Error: ', exc_info=True) def uwsgi_report_telemetry(user_config): """ @@ -172,7 +176,7 @@ def uwsgi_report_telemetry(user_config): storage = UWSGITelemetryStorage(get_uwsgi()) task = TelemetrySynchronizationTask( TelemetryAPI( - HttpClient(config.get('sdk_url'), config.get('events_url')), + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'], metadata ), @@ -184,3 +188,4 @@ def uwsgi_report_telemetry(user_config): time.sleep(seconds) except Exception: #pylint: disable=broad-except _LOGGER.error('Error posting metrics') + _LOGGER.debug('Error: ', exc_info=True) diff --git a/tests/client/test_factory.py b/tests/client/test_factory.py index bb5a1339..1888ad52 100644 --- a/tests/client/test_factory.py +++ b/tests/client/test_factory.py @@ -3,6 +3,7 @@ import time import threading +from splitio.client.listener import ImpressionListenerWrapper from splitio.client.factory import get_factory from splitio.client.config import DEFAULT_CONFIG from splitio.storage import redis, inmemmory, uwsgi @@ -159,7 +160,7 @@ def test_redis_client_creation(self, mocker): max_connections=999 )] assert factory._labels_enabled is False - assert factory._impression_listener == 123 + assert isinstance(factory._impression_listener, ImpressionListenerWrapper) factory.block_until_ready() time.sleep(1) # give a chance for the bg thread to set the ready status assert factory.ready @@ -167,7 +168,7 @@ def test_redis_client_creation(self, mocker): def test_uwsgi_client_creation(self): """Test that a client with redis storage is created correctly.""" - factory = get_factory('some_api_key', config={'uwsgiCache': True, 'impressionListener': 123}) + factory = get_factory('some_api_key', config={'uwsgiClient': True}) assert isinstance(factory._get_storage('splits'), uwsgi.UWSGISplitStorage) assert isinstance(factory._get_storage('segments'), uwsgi.UWSGISegmentStorage) assert isinstance(factory._get_storage('impressions'), uwsgi.UWSGIImpressionStorage) @@ -176,7 +177,7 @@ def test_uwsgi_client_creation(self): assert factory._apis == {} assert factory._tasks == {} assert factory._labels_enabled is True - assert factory._impression_listener == 123 + assert factory._impression_listener is None factory.block_until_ready() time.sleep(1) # give a chance for the bg thread to set the ready status assert factory.ready diff --git a/tests/client/test_input_validator.py b/tests/client/test_input_validator.py index 0a065db1..91b468f8 100644 --- a/tests/client/test_input_validator.py +++ b/tests/client/test_input_validator.py @@ -885,7 +885,7 @@ def test_input_validation_factory(self, mocker): ] logger.reset_mock() - assert get_factory(True, config={'uwsgiCache': True}) is not None + assert get_factory(True, config={'uwsgiClient': True}) is not None assert logger.error.mock_calls == [] logger.reset_mock() From 0f1033d0966452fad67ca3382f6e1a034803f4a1 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 10:24:48 -0300 Subject: [PATCH 22/38] reorganize exceptions --- splitio/exceptions.py | 13 ++----------- splitio/storage/adapters/redis.py | 5 ++++- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/splitio/exceptions.py b/splitio/exceptions.py index 0af817de..0c633d33 100644 --- 
a/splitio/exceptions.py +++ b/splitio/exceptions.py @@ -1,14 +1,5 @@ """This module contains everything related to split.io exceptions""" from __future__ import absolute_import, division, print_function, unicode_literals - -class TimeoutException(Exception): - pass - - -class NetworkingException(Exception): - pass - - -class SentinelConfigurationException(Exception): - pass +from splitio.client.factory import TimeoutException +from splitio.storage.adapters.redis import SentinelConfigurationException diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py index fc31490b..d1b54171 100644 --- a/splitio/storage/adapters/redis.py +++ b/splitio/storage/adapters/redis.py @@ -5,7 +5,6 @@ from builtins import str from six import string_types, binary_type from future.utils import raise_from -from splitio.exceptions import SentinelConfigurationException try: from redis import StrictRedis @@ -48,6 +47,10 @@ def original_exception(self): return self._original_exception +class SentinelConfigurationException(Exception): + pass + + class RedisAdapter(object): #pylint: disable=too-many-public-methods """ Instance decorator for Redis clients such as StrictRedis. From 3a8d4ac98015d4c3e438789e77447c8f6bcfeb2a Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 11:29:17 -0300 Subject: [PATCH 23/38] convert mget result to list of strings --- splitio/storage/adapters/redis.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py index d1b54171..8fa29eb6 100644 --- a/splitio/storage/adapters/redis.py +++ b/splitio/storage/adapters/redis.py @@ -185,7 +185,10 @@ def lrange(self, key, start, end): def mget(self, names): """Mimic original redis function but using user custom prefix.""" try: - return _bytes_to_string(self._decorated.mget(self._add_prefix(names))) + return [ + _bytes_to_string(item) + for item in self._decorated.mget(self._add_prefix(names)) + ] except RedisError as exc: raise_from(RedisAdapterException('Error executing mget operation'), exc) From add69ec5b3cdd5af19fffb7883928dd32b13d489 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 12:45:02 -0300 Subject: [PATCH 24/38] fix track issue in redis, fix tests --- splitio/storage/redis.py | 2 +- tests/storage/adapters/test_redis_adapter.py | 1 + tests/storage/test_redis.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/splitio/storage/redis.py b/splitio/storage/redis.py index 330df511..c0b6d85a 100644 --- a/splitio/storage/redis.py +++ b/splitio/storage/redis.py @@ -370,7 +370,7 @@ def put(self, events): key = self._KEY_TEMPLATE to_store = [ json.dumps({ - 'i': { + 'e': { 'key': event.key, 'trafficTypeName': event.traffic_type_name, 'eventTypeId': event.event_type_id, diff --git a/tests/storage/adapters/test_redis_adapter.py b/tests/storage/adapters/test_redis_adapter.py index 9d5253f0..7336ce99 100644 --- a/tests/storage/adapters/test_redis_adapter.py +++ b/tests/storage/adapters/test_redis_adapter.py @@ -29,6 +29,7 @@ def test_forwarding(self, mocker): adapter.delete('some_key') assert redis_mock.delete.mock_calls[0] == mocker.call('some_prefix.some_key') + redis_mock.mget.return_value = ['value1', 'value2', 'value3'] adapter.mget(['key1', 'key2', 'key3']) assert redis_mock.mget.mock_calls[0] == mocker.call(['some_prefix.key1', 'some_prefix.key2', 'some_prefix.key3']) diff --git a/tests/storage/test_redis.py b/tests/storage/test_redis.py index 3b11485d..d5bbb06d 100644 ---
a/tests/storage/test_redis.py +++ b/tests/storage/test_redis.py @@ -207,7 +207,7 @@ def test_add_events(self, mocker): 'n': metadata.instance_name, 'i': metadata.instance_ip, }, - 'i': { # IMPRESSION PORTION + 'e': { # EVENT PORTION 'key': event.key, 'trafficTypeName': event.traffic_type_name, 'eventTypeId': event.event_type_id, From 81ee307db3cdfba057f6ce9fc2a956ab83381753 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 13:43:33 -0300 Subject: [PATCH 25/38] look for sentinels as well when guessing operation mode --- splitio/client/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splitio/client/factory.py b/splitio/client/factory.py index fd4a4c60..d3198552 100644 --- a/splitio/client/factory.py +++ b/splitio/client/factory.py @@ -400,7 +400,7 @@ def get_factory(api_key, **kwargs): if api_key == 'localhost': return _build_localhost_factory(config) - if 'redisHost' in config: + if 'redisHost' in config or 'redisSentinels' in config: return _build_redis_factory(config) if 'uwsgiClient' in config: From 7b06c624cc4efb538c73a6e42f2b6f22fe3d9d4b Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 15:37:30 -0300 Subject: [PATCH 26/38] use a pool in uwsgi mode as well to increase segment fetching throughput --- splitio/tasks/uwsgi_wrappers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py index c385a580..9a2a11f0 100644 --- a/splitio/tasks/uwsgi_wrappers.py +++ b/splitio/tasks/uwsgi_wrappers.py @@ -2,6 +2,7 @@ import logging import time + from splitio.client.config import DEFAULT_CONFIG from splitio.client.util import get_metadata from splitio.storage.adapters.uwsgi_cache import get_uwsgi @@ -18,7 +19,7 @@ from splitio.tasks.impressions_sync import ImpressionsSyncTask from splitio.tasks.events_sync import EventsSyncTask from splitio.tasks.telemetry_sync import TelemetrySynchronizationTask - +from splitio.tasks.util import workerpool _LOGGER = logging.getLogger(__name__) @@ -84,11 +85,12 @@ def uwsgi_update_segments(user_config): None, # Period not needed, task executed manually None # Flag not needed, never consumed or set. 
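The worker-pool pattern introduced in this patch decouples discovering segment names from fetching them: the loop enqueues names while twenty worker threads drain the queue concurrently. A minimal usage sketch of the API as these patches exercise it (the handler below is a hypothetical stand-in for `segment_sync_task._update_segment`):

    from splitio.tasks.util import workerpool

    def refresh_segment(segment_name):
        # placeholder for the real segment fetch-and-store work
        print('refreshing %s' % segment_name)

    pool = workerpool.WorkerPool(20, refresh_segment)
    pool.start()  # without this the workers never run; patch 28 below adds the missing call
    for name in ('employees', 'human_beigns'):
        pool.submit_work(name)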
) + + pool = workerpool.WorkerPool(20, segment_sync_task._update_segment) #pylint: disable=protected-access split_storage = UWSGISplitStorage(get_uwsgi()) while True: - segment_names = split_storage.get_segment_names() - for segment_name in segment_names: - segment_sync_task._update_segment(segment_name) #pylint: disable=protected-access + for name in split_storage.get_split_names(): + pool.submit_work(name) time.sleep(seconds) except Exception: #pylint: disable=broad-except _LOGGER.error('Error updating segments') From 605908bad526ca24673860abede609ac1bdc8025 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 16:14:22 -0300 Subject: [PATCH 27/38] more uwsgi optimizations --- splitio/storage/uwsgi.py | 45 +++++++++++++++++---------------- splitio/tasks/uwsgi_wrappers.py | 4 +-- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/splitio/storage/uwsgi.py b/splitio/storage/uwsgi.py index 737049ea..d9adbb37 100644 --- a/splitio/storage/uwsgi.py +++ b/splitio/storage/uwsgi.py @@ -166,7 +166,6 @@ class UWSGISegmentStorage(SegmentStorage): _KEY_TEMPLATE = 'segments.{suffix}' _SEGMENT_DATA_KEY_TEMPLATE = 'segmentData.{segment_name}' _SEGMENT_CHANGE_NUMBER_KEY_TEMPLATE = 'segment.{segment_name}.till' - _SEGMENT_REGISTERED = _KEY_TEMPLATE.format(suffix='registered') def __init__(self, uwsgi_entrypoint): """ @@ -319,6 +318,7 @@ def put(self, impressions): :param impressions: List of one or more impressions to store. :type impressions: list """ + to_store = [i._asdict() for i in impressions] with UWSGILock(self._uwsgi, self._LOCK_IMPRESSION_KEY): try: current = json.loads(self._uwsgi.cache_get( @@ -329,7 +329,7 @@ def put(self, impressions): self._uwsgi.cache_update( self._IMPRESSIONS_KEY, - json.dumps(current + [i._asdict() for i in impressions]), + json.dumps(current + to_store), 0, _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE ) @@ -356,17 +356,17 @@ def pop_many(self, count): _SPLITIO_IMPRESSIONS_CACHE_NAMESPACE ) - return [ - Impression( - impression['matching_key'], - impression['feature_name'], - impression['treatment'], - impression['label'], - impression['change_number'], - impression['bucketing_key'], - impression['time'] - ) for impression in current[:count] - ] + return [ + Impression( + impression['matching_key'], + impression['feature_name'], + impression['treatment'], + impression['label'], + impression['change_number'], + impression['bucketing_key'], + impression['time'] + ) for impression in current[:count] + ] def request_flush(self): """Set a marker in the events cache to indicate that a flush has been requested.""" @@ -447,15 +447,16 @@ def pop_many(self, count): 0, _SPLITIO_EVENTS_CACHE_NAMESPACE ) - return [ - Event( - event['key'], - event['traffic_type_name'], - event['event_type_id'], - event['value'], - event['timestamp'] - ) for event in current[:count] - ] + + return [ + Event( + event['key'], + event['traffic_type_name'], + event['event_type_id'], + event['value'], + event['timestamp'] + ) for event in current[:count] + ] def request_flush(self): """Set a marker in the events cache to indicate that a flush has been requested.""" diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py index 9a2a11f0..49c8a6de 100644 --- a/splitio/tasks/uwsgi_wrappers.py +++ b/splitio/tasks/uwsgi_wrappers.py @@ -117,7 +117,7 @@ def uwsgi_report_impressions(user_config): ), storage, None, # Period not needed. Task is being triggered manually. 
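Because the uwsgi cache can only hold opaque blobs, the storage above keeps each queue as a single JSON-encoded list that is deserialized, mutated, and re-serialized wholesale under a lock on every put/pop. A sketch of the namedtuple round-trip it relies on (field values illustrative):

    import json
    from splitio.models.impressions import Impression

    imp = Impression('some_key', 'some_feature', 'on', 'label', 123, None, 456)
    encoded = json.dumps([imp._asdict()])                    # namedtuple -> dict -> JSON
    decoded = [Impression(**raw) for raw in json.loads(encoded)]
    assert decoded[0] == imp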
- config['impressionsRefreshRate'] + config['impressionsBulkSize'] ) while True: @@ -151,7 +151,7 @@ def uwsgi_report_events(user_config): ), storage, None, # Period not needed. Task is being triggered manually. - config['eventsPushRate'] + config['eventsBulkSize'] ) while True: task._send_events() #pylint: disable=protected-access From 4723b9f58515b563994e1a5cc1850149b06372b5 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 16:41:14 -0300 Subject: [PATCH 28/38] start the pool --- splitio/tasks/uwsgi_wrappers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py index 49c8a6de..6ab2ac84 100644 --- a/splitio/tasks/uwsgi_wrappers.py +++ b/splitio/tasks/uwsgi_wrappers.py @@ -87,6 +87,7 @@ def uwsgi_update_segments(user_config): ) pool = workerpool.WorkerPool(20, segment_sync_task._update_segment) #pylint: disable=protected-access + pool.start() split_storage = UWSGISplitStorage(get_uwsgi()) while True: for name in split_storage.get_split_names(): From 04b1995d551d201fc12199fefc1e397f456fcd41 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 17:08:57 -0300 Subject: [PATCH 29/38] fix segment fetching in uwsgi mode --- splitio/tasks/uwsgi_wrappers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py index 6ab2ac84..15c7c76a 100644 --- a/splitio/tasks/uwsgi_wrappers.py +++ b/splitio/tasks/uwsgi_wrappers.py @@ -90,8 +90,9 @@ def uwsgi_update_segments(user_config): pool.start() split_storage = UWSGISplitStorage(get_uwsgi()) while True: - for name in split_storage.get_split_names(): - pool.submit_work(name) + for split in split_storage.get_all_splits(): + for segment_name in split.get_segment_names(): + pool.submit_work(segment_name) time.sleep(seconds) except Exception: #pylint: disable=broad-except _LOGGER.error('Error updating segments') From 9d3157ce6fa74d6007511c3f88f784b25cc50169 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 23 Apr 2019 18:09:49 -0300 Subject: [PATCH 30/38] do not fail when removing nonexistent split in uwsgi. Keep tasks going after failures --- splitio/storage/uwsgi.py | 3 + splitio/tasks/uwsgi_wrappers.py | 186 ++++++++++++++++---------------- 2 files changed, 96 insertions(+), 93 deletions(-) diff --git a/splitio/storage/uwsgi.py b/splitio/storage/uwsgi.py index d9adbb37..e6835f54 100644 --- a/splitio/storage/uwsgi.py +++ b/splitio/storage/uwsgi.py @@ -105,6 +105,9 @@ def remove(self, split_name): except TypeError: # Split list not found, no need to delete anything pass + except KeyError: + # Split not found in list. Nothing to do. + pass result = self._uwsgi.cache_del( self._KEY_TEMPLATE.format(suffix=split_name), diff --git a/splitio/tasks/uwsgi_wrappers.py b/splitio/tasks/uwsgi_wrappers.py index 15c7c76a..54d56a88 100644 --- a/splitio/tasks/uwsgi_wrappers.py +++ b/splitio/tasks/uwsgi_wrappers.py @@ -46,24 +46,24 @@ def uwsgi_update_splits(user_config): :param user_config: User-provided configuration. :type user_config: dict """ - try: - config = _get_config(user_config) - seconds = config['featuresRefreshRate'] - split_sync_task = SplitSynchronizationTask( - SplitsAPI( - HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'] - ), - UWSGISplitStorage(get_uwsgi()), - None, # Time not needed since the task will be triggered manually. - None # Ready flag not needed since it will never be set and consumed.
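The restructuring in this patch boils down to one pattern change: setup happens once outside the loop, and the try/except moves inside it, so a transient failure logs and retries instead of killing the background task for good. Skeleton of the pattern (task and interval names are hypothetical):

    import logging
    import time

    _LOGGER = logging.getLogger(__name__)

    task = build_task()  # hypothetical one-time setup
    while True:
        try:
            task.run_once()       # hypothetical per-iteration work
            time.sleep(30)
        except Exception:         # broad catch is deliberate in a daemon loop
            _LOGGER.error('Error running task')
            _LOGGER.debug('Error: ', exc_info=True)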
- ) - - while True: + config = _get_config(user_config) + seconds = config['featuresRefreshRate'] + split_sync_task = SplitSynchronizationTask( + SplitsAPI( + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'] + ), + UWSGISplitStorage(get_uwsgi()), + None, # Time not needed since the task will be triggered manually. + None # Ready flag not needed since it will never be set and consumed. + ) + + while True: + try: split_sync_task._update_splits() #pylint: disable=protected-access time.sleep(seconds) - except Exception: #pylint: disable=broad-except - _LOGGER.error('Error updating splits') - _LOGGER.debug('Error: ', exc_info=True) + except Exception: #pylint: disable=broad-except + _LOGGER.error('Error updating splits') + _LOGGER.debug('Error: ', exc_info=True) def uwsgi_update_segments(user_config): @@ -73,30 +73,30 @@ :param user_config: User-provided configuration. :type user_config: dict """ - try: - config = _get_config(user_config) - seconds = config['segmentsRefreshRate'] - segment_sync_task = SegmentSynchronizationTask( - SegmentsAPI( - HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'] - ), - UWSGISegmentStorage(get_uwsgi()), - None, # Split storage not needed, segments provided manually, - None, # Period not needed, task executed manually - None # Flag not needed, never consumed or set. - ) - - pool = workerpool.WorkerPool(20, segment_sync_task._update_segment) #pylint: disable=protected-access - pool.start() - split_storage = UWSGISplitStorage(get_uwsgi()) - while True: + config = _get_config(user_config) + seconds = config['segmentsRefreshRate'] + segment_sync_task = SegmentSynchronizationTask( + SegmentsAPI( + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), config['apikey'] + ), + UWSGISegmentStorage(get_uwsgi()), + None, # Split storage not needed, segments provided manually, + None, # Period not needed, task executed manually + None # Flag not needed, never consumed or set. + ) + + pool = workerpool.WorkerPool(20, segment_sync_task._update_segment) #pylint: disable=protected-access + pool.start() + split_storage = UWSGISplitStorage(get_uwsgi()) + while True: + try: for split in split_storage.get_all_splits(): for segment_name in split.get_segment_names(): pool.submit_work(segment_name) time.sleep(seconds) - except Exception: #pylint: disable=broad-except - _LOGGER.error('Error updating segments') - _LOGGER.debug('Error: ', exc_info=True) + except Exception: #pylint: disable=broad-except + _LOGGER.error('Error updating segments') + _LOGGER.debug('Error: ', exc_info=True) def uwsgi_report_impressions(user_config): @@ -106,32 +106,32 @@ :param user_config: User-provided configuration. :type user_config: dict """ - try: - config = _get_config(user_config) - metadata = get_metadata(config) - seconds = config['impressionsRefreshRate'] - storage = UWSGIImpressionStorage(get_uwsgi()) - impressions_sync_task = ImpressionsSyncTask( - ImpressionsAPI( - HttpClient(1500, config.get('sdk_url'), config.get('events_url')), - config['apikey'], - metadata - ), - storage, - None, # Period not needed. Task is being triggered manually.
- config['impressionsBulkSize'] - ) - - while True: + config = _get_config(user_config) + metadata = get_metadata(config) + seconds = config['impressionsRefreshRate'] + storage = UWSGIImpressionStorage(get_uwsgi()) + impressions_sync_task = ImpressionsSyncTask( + ImpressionsAPI( + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), + config['apikey'], + metadata + ), + storage, + None, # Period not needed. Task is being triggered manually. + config['impressionsBulkSize'] + ) + + while True: + try: impressions_sync_task._send_impressions() #pylint: disable=protected-access for _ in xrange(0, seconds): if storage.should_flush(): storage.acknowledge_flush() break time.sleep(1) - except Exception: #pylint: disable=broad-except - _LOGGER.error('Error posting impressions') - _LOGGER.debug('Error: ', exc_info=True) + except Exception: #pylint: disable=broad-except + _LOGGER.error('Error posting impressions') + _LOGGER.debug('Error: ', exc_info=True) def uwsgi_report_events(user_config): """ @@ -140,31 +140,31 @@ def uwsgi_report_events(user_config): :param user_config: User-provided configuration. :type user_config: dict """ - try: - config = _get_config(user_config) - metadata = get_metadata(config) - seconds = config.get('eventsRefreshRate', 30) - storage = UWSGIEventStorage(get_uwsgi()) - task = EventsSyncTask( - EventsAPI( - HttpClient(1500, config.get('sdk_url'), config.get('events_url')), - config['apikey'], - metadata - ), - storage, - None, # Period not needed. Task is being triggered manually. - config['eventsBulkSize'] - ) - while True: + config = _get_config(user_config) + metadata = get_metadata(config) + seconds = config.get('eventsRefreshRate', 30) + storage = UWSGIEventStorage(get_uwsgi()) + task = EventsSyncTask( + EventsAPI( + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), + config['apikey'], + metadata + ), + storage, + None, # Period not needed. Task is being triggered manually. + config['eventsBulkSize'] + ) + while True: + try: task._send_events() #pylint: disable=protected-access for _ in xrange(0, seconds): if storage.should_flush(): storage.acknowledge_flush() break time.sleep(1) - except Exception: #pylint: disable=broad-except - _LOGGER.error('Error posting metrics') - _LOGGER.debug('Error: ', exc_info=True) + except Exception: #pylint: disable=broad-except + _LOGGER.error('Error posting metrics') + _LOGGER.debug('Error: ', exc_info=True) def uwsgi_report_telemetry(user_config): """ @@ -173,23 +173,23 @@ def uwsgi_report_telemetry(user_config): :param user_config: User-provided configuration. :type user_config: dict """ - try: - config = _get_config(user_config) - metadata = get_metadata(config) - seconds = config.get('metricsRefreshRate', 30) - storage = UWSGITelemetryStorage(get_uwsgi()) - task = TelemetrySynchronizationTask( - TelemetryAPI( - HttpClient(1500, config.get('sdk_url'), config.get('events_url')), - config['apikey'], - metadata - ), - storage, - None, # Period not needed. Task is being triggered manually. - ) - while True: + config = _get_config(user_config) + metadata = get_metadata(config) + seconds = config.get('metricsRefreshRate', 30) + storage = UWSGITelemetryStorage(get_uwsgi()) + task = TelemetrySynchronizationTask( + TelemetryAPI( + HttpClient(1500, config.get('sdk_url'), config.get('events_url')), + config['apikey'], + metadata + ), + storage, + None, # Period not needed. Task is being triggered manually. 
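The impressions and events loops above sleep in one-second slices so that an on-demand flush request can cut the wait short. Extracted as a helper for clarity (a sketch; the patches inline this logic, and they use the Python-2-only xrange where range would work on both interpreters):

    import time

    def sleep_or_flush(storage, seconds):
        """Wait up to `seconds` seconds, returning early if a flush was requested."""
        for _ in range(seconds):
            if storage.should_flush():
                storage.acknowledge_flush()
                return
            time.sleep(1)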
+ ) + while True: + try: task._flush_telemetry() #pylint: disable=protected-access time.sleep(seconds) - except Exception: #pylint: disable=broad-except - _LOGGER.error('Error posting metrics') - _LOGGER.debug('Error: ', exc_info=True) + except Exception: #pylint: disable=broad-except + _LOGGER.error('Error posting metrics') + _LOGGER.debug('Error: ', exc_info=True) From 18ec7ce2268d0a635cc066df78b02c5dd3231465 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Wed, 24 Apr 2019 11:07:11 -0300 Subject: [PATCH 31/38] update travis config --- .travis.yml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9a13f2be..aa179aa0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,10 +7,8 @@ services: - redis-server install: - - pip install --upgrade setuptools - - pip install redis>=2.6 - - pip install jsonpickle>=0.9.3 - - pip install uwsgi>=2.0.0 + - pip install -U setuptools pip + - python setup.py install[cpphash,redis,uwsgi] script: - - python setup.py nosetests + - python setup.py test From eeec9ae085dd9a486853dc03b965c51636608135 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Wed, 24 Apr 2019 11:09:52 -0300 Subject: [PATCH 32/38] fix travis.yml --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index aa179aa0..45fe0370 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ services: install: - pip install -U setuptools pip - - python setup.py install[cpphash,redis,uwsgi] + - pip install -e .[cpphash,redis,uwsgi] script: - python setup.py test From ddbb852a803613b4efbbf2c9d75025ea7ce703b2 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Wed, 24 Apr 2019 11:15:00 -0300 Subject: [PATCH 33/38] remove old legacy test --- tests/engine/test_hashfns.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/engine/test_hashfns.py b/tests/engine/test_hashfns.py index 1832d31e..3ab3953c 100644 --- a/tests/engine/test_hashfns.py +++ b/tests/engine/test_hashfns.py @@ -33,25 +33,6 @@ def test_legacy_hash_ascii_data(self): assert hashfns.legacy.legacy_hash(key, seed) == hashed assert splitter.get_bucket(key, seed, splits.HashAlgorithm.LEGACY) == bucket - @pytest.mark.skipif(six.PY3, reason='Should skip this on python3.') - def test_legacy_hash_non_ascii_data(self): - """Test legacy hash function against known results.""" - splitter = splitters.Splitter() - file_name = os.path.join( - os.path.dirname(__file__), - 'files', - 'sample-data-non-alpha-numeric.jsonl' - ) - with open(file_name, 'r') as flo: - lines = flo.read().split('\n') - - for line in lines: - if line is None or line == '': - continue - seed, key, hashed, bucket = json.loads(line) - assert hashfns.legacy.legacy_hash(key, seed) == hashed - assert splitter.get_bucket(key, seed, splits.HashAlgorithm.LEGACY) == bucket - def test_murmur_hash_ascii_data(self): """Test legacy hash function against known results.""" splitter = splitters.Splitter() From d6c87a0fe289167b3adc97cc40a655c8a3f28e51 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Wed, 24 Apr 2019 11:17:50 -0300 Subject: [PATCH 34/38] tests in verbose mode --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 77aa70d0..18ba96e2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,7 @@ test=pytest [tool:pytest] ignore_glob=./splitio/_OLD/* -addopts = --cov=splitio +addopts = --cov=splitio --verbose python_classes=*Tests [build_sphinx] From 7a7596ef8e49416a096e3ea01ea15291d634f8ec Mon Sep 
From: Martin Redolatti
Date: Wed, 24 Apr 2019 11:38:39 -0300
Subject: [PATCH 35/38] add a method to build metadata headers

---
 splitio/api/__init__.py    | 17 +++++++++++++++++
 splitio/api/events.py      |  8 ++------
 splitio/api/impressions.py |  8 ++------
 splitio/api/telemetry.py   |  9 ++-------
 splitio/client/listener.py |  2 --
 5 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/splitio/api/__init__.py b/splitio/api/__init__.py
index a9fa4f6f..bd917328 100644
--- a/splitio/api/__init__.py
+++ b/splitio/api/__init__.py
@@ -12,3 +12,20 @@ def __init__(self, custom_message, status_code=None):
     def status_code(self):
         """Return HTTP status code."""
         return self._status_code
+
+
+def headers_from_metadata(sdk_metadata):
+    """
+    Generate a dict with headers required by data-recording API endpoints.
+
+    :param sdk_metadata: SDK Metadata object, generated at sdk initialization time.
+    :type sdk_metadata: splitio.client.util.SdkMetadata
+
+    :return: A dictionary with headers.
+    :rtype: dict
+    """
+    return {
+        'SplitSDKVersion': sdk_metadata.sdk_version,
+        'SplitSDKMachineIP': sdk_metadata.instance_ip,
+        'SplitSDKMachineName': sdk_metadata.instance_name
+    }

diff --git a/splitio/api/events.py b/splitio/api/events.py
index 59817d31..b3ac08e2 100644
--- a/splitio/api/events.py
+++ b/splitio/api/events.py
@@ -3,7 +3,7 @@
 
 from future.utils import raise_from
 
-from splitio.api import APIException
+from splitio.api import APIException, headers_from_metadata
 from splitio.api.client import HttpClientException
 
 
@@ -24,11 +24,7 @@ def __init__(self, http_client, apikey, sdk_metadata):
         self._logger = logging.getLogger(self.__class__.__name__)
         self._client = http_client
         self._apikey = apikey
-        self._metadata = {
-            'SplitSDKVersion': sdk_metadata.sdk_version,
-            'SplitSDKMachineIP': sdk_metadata.instance_ip,
-            'SplitSDKMachineName': sdk_metadata.instance_name
-        }
+        self._metadata = headers_from_metadata(sdk_metadata)
 
     @staticmethod
     def _build_bulk(events):

diff --git a/splitio/api/impressions.py b/splitio/api/impressions.py
index 2c8ee54e..d1aa6507 100644
--- a/splitio/api/impressions.py
+++ b/splitio/api/impressions.py
@@ -5,7 +5,7 @@
 
 from future.utils import raise_from
 
-from splitio.api import APIException
+from splitio.api import APIException, headers_from_metadata
 from splitio.api.client import HttpClientException
 
 
@@ -24,11 +24,7 @@ def __init__(self, client, apikey, sdk_metadata):
         self._logger = logging.getLogger(self.__class__.__name__)
         self._client = client
         self._apikey = apikey
-        self._metadata = {
-            'SplitSDKVersion': sdk_metadata.sdk_version,
-            'SplitSDKMachineIP': sdk_metadata.instance_ip,
-            'SplitSDKMachineName': sdk_metadata.instance_name
-        }
+        self._metadata = headers_from_metadata(sdk_metadata)
 
     @staticmethod
     def _build_bulk(impressions):

diff --git a/splitio/api/telemetry.py b/splitio/api/telemetry.py
index d9fa107f..97d747c7 100644
--- a/splitio/api/telemetry.py
+++ b/splitio/api/telemetry.py
@@ -4,7 +4,7 @@
 import six
 from future.utils import raise_from
 
-from splitio.api import APIException
+from splitio.api import APIException, headers_from_metadata
 from splitio.api.client import HttpClientException
 
 
@@ -25,12 +25,7 @@ def __init__(self, client, apikey, sdk_metadata):
         self._logger = logging.getLogger(self.__class__.__name__)
         self._client = client
         self._apikey = apikey
-        self._metadata = {
-            'SplitSDKVersion': sdk_metadata.sdk_version,
-            'SplitSDKMachineIP': sdk_metadata.instance_ip,
-            'SplitSDKMachineName': sdk_metadata.instance_name
-        }
-
+        self._metadata = headers_from_metadata(sdk_metadata)
     @staticmethod
     def _build_latencies(latencies):
         """

diff --git a/splitio/client/listener.py b/splitio/client/listener.py
index abe78d1a..1ab61e30 100644
--- a/splitio/client/listener.py
+++ b/splitio/client/listener.py
@@ -51,8 +51,6 @@ def log_impression(self, impression, attributes=None):
         try:
             self.impression_listener.log_impression(data)
         except Exception as exc: #pylint: disable=broad-except
-            import traceback
-            traceback.print_exc()
             raise_from(
                 ImpressionListenerException('Error in log_impression user\'s method is throwing exceptions'),
                 exc

From 88e2e0012f1975445fca6517be462a8abba6d30b Mon Sep 17 00:00:00 2001
From: Cristian Daniel Fadón
Date: Wed, 24 Apr 2019 14:34:39 -0300
Subject: [PATCH 36/38] Added python 3.6

---
 .travis.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.travis.yml b/.travis.yml
index 45fe0370..3f9423e9 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,6 +2,7 @@ language: python
 
 python:
   - "2.7"
+  - "3.6"
 
 services:
   - redis-server

From 3f59ce137285e12c160a91555652bf24d2d9c074 Mon Sep 17 00:00:00 2001
From: Martin Redolatti
Date: Wed, 24 Apr 2019 18:02:03 -0300
Subject: [PATCH 37/38] bump version & update changes.txt

---
 CHANGES.txt        | 8 ++++++++
 splitio/version.py | 2 +-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index 7782f5a4..6c395e91 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,3 +1,11 @@
+8.0.0 (Apr 24, 2019)
+ - Full SDK refactor/rewrite.
+ - New block-until-ready behaviour.
+ - Support for async destroy.
+ - Dynamic configs.
+ - Fixed impressions not being flushed on destroy.
+ - Removed unnecessary dependencies.
+ - Test suite rewritten.
 7.0.1 (Mar 8, 2019)
  - Updated Splits refreshing rate.
  - Replaced exception log level to error level.

diff --git a/splitio/version.py b/splitio/version.py
index 775951e9..0daae8c2 100644
--- a/splitio/version.py
+++ b/splitio/version.py
@@ -1 +1 @@
-__version__ = '8.0.0-rc1'
+__version__ = '8.0.0'

From cd2f4b20e66c3c5db5650a23bff89e4c55b43bb2 Mon Sep 17 00:00:00 2001
From: Martin Redolatti
Date: Wed, 24 Apr 2019 18:09:23 -0300
Subject: [PATCH 38/38] wait for factory to be ready in localhost tests

---
 tests/integration/test_client_e2e.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/integration/test_client_e2e.py b/tests/integration/test_client_e2e.py
index eb8dc1c6..b3e6af53 100644
--- a/tests/integration/test_client_e2e.py
+++ b/tests/integration/test_client_e2e.py
@@ -542,6 +542,7 @@ def test_localhost_e2e(self):
         """Instantiate a client with a YAML file and issue get_treatment() calls."""
         filename = os.path.join(os.path.dirname(__file__), 'files', 'file2.yaml')
         factory = get_factory('localhost', config={'splitFile': filename})
+        factory.block_until_ready()
         client = factory.client()
         assert client.get_treatment_with_config('key', 'my_feature') == ('on', '{"desc" : "this applies only to ON treatment"}')
         assert client.get_treatment_with_config('only_key', 'my_feature') == (
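
For reference, a minimal self-contained sketch of the metadata-header helper centralized by PATCH 35/38. The function body is copied from the splitio/api/__init__.py diff above; the SdkMetadata stand-in and the sample values are assumptions for illustration, since the real splitio.client.util.SdkMetadata class is not shown in this series:

    from collections import namedtuple

    # Hypothetical stand-in for splitio.client.util.SdkMetadata; only the three
    # attributes read by the helper are assumed here.
    SdkMetadata = namedtuple('SdkMetadata', ['sdk_version', 'instance_ip', 'instance_name'])

    def headers_from_metadata(sdk_metadata):
        """Generate the headers required by data-recording API endpoints."""
        return {
            'SplitSDKVersion': sdk_metadata.sdk_version,
            'SplitSDKMachineIP': sdk_metadata.instance_ip,
            'SplitSDKMachineName': sdk_metadata.instance_name
        }

    # Sample values are illustrative only.
    meta = SdkMetadata('python-8.0.0', '10.0.0.1', 'some-machine')
    assert headers_from_metadata(meta)['SplitSDKVersion'] == 'python-8.0.0'

Each data-recording API (events, impressions, telemetry) now builds its headers through this single helper instead of duplicating the dict literal, which is why the three __init__ diffs above each collapse to one line.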
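
PATCH 38/38 adds a readiness wait before the localhost client is used, consistent with the new block-until-ready behaviour listed in the 8.0.0 changelog. A minimal usage sketch of that pattern, assuming get_factory is importable from the package root as the test above suggests; 'splits.yaml' is a hypothetical local file:

    from splitio import get_factory

    # 'splits.yaml' is a hypothetical split-definition file on local disk.
    factory = get_factory('localhost', config={'splitFile': 'splits.yaml'})
    factory.block_until_ready()  # wait for the localhost sync to finish
    client = factory.client()
    print(client.get_treatment('some_key', 'my_feature'))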