Skip to content

Commit

Permalink
Merge a6e2523 into 5fa3dd1
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron-mcdaid-zalando committed Sep 18, 2019
2 parents 5fa3dd1 + a6e2523 commit 96bdb00
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 17 deletions.
14 changes: 14 additions & 0 deletions expan/core/early_stopping.py
Expand Up @@ -20,6 +20,7 @@
cache_sampling_results = False
sampling_results = {} # memorized sampling results

OBRIEN_FLEMING_DIVISION_FACTOR = 100

def obrien_fleming(information_fraction, alpha=0.05):
""" Calculate an approximation of the O'Brien-Fleming alpha spending function.
Expand All @@ -33,6 +34,19 @@ def obrien_fleming(information_fraction, alpha=0.05):
:return: redistributed alpha value at the time point with the given information fraction
:rtype: float
"""

alpha = alpha/OBRIEN_FLEMING_DIVISION_FACTOR
"""
The following tests needed to be adjusted to take account of this correction:
- tests/tests_core/test_early_stopping.py::
GroupSequentialTestCases::
test_obrien_fleming
test_group_sequential
test_group_sequential_actual_size_larger_than_estimated
- tests_core/test_experiment.py::
StatisticalTestTestCases::
test_group_sequential
"""
return (1 - norm.cdf(norm.ppf(1 - alpha / 2) / np.sqrt(information_fraction))) * 2


Expand Down
7 changes: 5 additions & 2 deletions expan/core/util.py
Expand Up @@ -18,7 +18,7 @@ def __repr__(self):
return self.toJson()


def find_value_by_key_with_condition(items, condition_key, condition_value, lookup_key):
def find_value_by_key_with_condition(items, condition_key, condition_value, lookup_key, tol=None):
""" Find the value of lookup key where the dictionary contains condition key = condition value.
:param items: list of dictionaries
Expand All @@ -31,7 +31,10 @@ def find_value_by_key_with_condition(items, condition_key, condition_value, look
:return: lookup value or found value for the lookup key
"""
return [item[lookup_key] for item in items if item[condition_key] == condition_value][0]
if tol is None:
return [item[lookup_key] for item in items if item[condition_key] == condition_value][0]
else:
return [item[lookup_key] for item in items if abs(item[condition_key]-condition_value) < tol][0]


def is_nan(obj):
Expand Down
22 changes: 11 additions & 11 deletions tests/tests_core/test_early_stopping.py
Expand Up @@ -37,16 +37,16 @@ def test_obrien_fleming(self):
""" Check the O'Brien-Fleming spending function."""
# Check array as input
res_1 = es.obrien_fleming(np.linspace(0, 1, 5 + 1)[1:])
expected_res = [1.17264468e-05, 1.94191300e-03, 1.13964185e-02, 2.84296308e-02, 5.00000000e-02]
expected_res = [7.1054274e-15,3.7219966e-08,7.0016877e-06,9.9583700e-05,5.0000000e-04]
np.testing.assert_almost_equal(res_1, expected_res)

# Check float as input
res_2 = es.obrien_fleming(0.5)
self.assertAlmostEqual(res_2, 0.005574596680784305)
self.assertAlmostEqual(res_2, 8.5431190077756014e-07)

# Check int as input
res_3 = es.obrien_fleming(1)
self.assertAlmostEqual(res_3, 0.05)
self.assertAlmostEqual(res_3, 0.0005)

def test_group_sequential(self):
""" Check the group sequential function."""
Expand All @@ -60,10 +60,10 @@ def test_group_sequential(self):
self.assertAlmostEqual(res.control_statistics.variance, 0.9373337542827797)

self.assertAlmostEqual(res.delta, -0.15887364780635896)
value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5, 'value')
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 97.5, 'value')
np.testing.assert_almost_equal(value025, -0.24461812530841959, decimal=5)
np.testing.assert_almost_equal(value975, -0.07312917030429833, decimal=5)
value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
np.testing.assert_almost_equal(value025, -0.31130760395377599, decimal=5)
np.testing.assert_almost_equal(value975, -0.0064396916589367081, decimal=5)

self.assertAlmostEqual(res.p, 0.0002863669955157941)
self.assertAlmostEqual(res.statistical_power, 0.9529152504960496)
Expand All @@ -75,10 +75,10 @@ def test_group_sequential_actual_size_larger_than_estimated(self):
"""
res = es.group_sequential(self.rand_s1, self.rand_s2, estimated_sample_size=100)

value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5, 'value')
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 97.5, 'value')
np.testing.assert_almost_equal (value025, -0.24461812530841959, decimal=5)
np.testing.assert_almost_equal (value975, -0.07312917030429833, decimal=5)
value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', tol=1e-5)
value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', tol=1e-5)
np.testing.assert_almost_equal (value025, -0.31130760395377599, decimal=5)
np.testing.assert_almost_equal (value975, -0.00643969165893670, decimal=5)


class BayesFactorTestCases(EarlyStoppingTestCase):
Expand Down
9 changes: 5 additions & 4 deletions tests/tests_core/test_experiment.py
Expand Up @@ -8,6 +8,7 @@
from expan.core.results import CombinedTestStatistics
from expan.core.statistical_test import *
from expan.core.experiment import Experiment
import expan.core.early_stopping as es
from expan.core.util import generate_random_data, find_value_by_key_with_condition


Expand Down Expand Up @@ -129,10 +130,10 @@ def test_group_sequential(self):

self.assertAlmostEqual(res.result.delta, 0.033053, ndecimals)

lower_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 2.5, 'value')
upper_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 97.5, 'value')
self.assertAlmostEqual(lower_bound_ci, -0.007135, ndecimals)
self.assertAlmostEqual(upper_bound_ci, 0.073240, ndecimals)
lower_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
upper_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5)
self.assertAlmostEqual(lower_bound_ci, -0.0383319, ndecimals)
self.assertAlmostEqual(upper_bound_ci, 0.104437, ndecimals)

self.assertEqual(res.result.treatment_statistics.sample_size, 6108)
self.assertEqual(res.result.control_statistics.sample_size, 3892)
Expand Down

0 comments on commit 96bdb00

Please sign in to comment.