diff --git a/expan/core/early_stopping.py b/expan/core/early_stopping.py index 61834e8..a3e695b 100644 --- a/expan/core/early_stopping.py +++ b/expan/core/early_stopping.py @@ -20,6 +20,7 @@ cache_sampling_results = False sampling_results = {} # memorized sampling results +OBRIEN_FLEMING_DIVISION_FACTOR = 100 def obrien_fleming(information_fraction, alpha=0.05): """ Calculate an approximation of the O'Brien-Fleming alpha spending function. @@ -33,6 +34,19 @@ def obrien_fleming(information_fraction, alpha=0.05): :return: redistributed alpha value at the time point with the given information fraction :rtype: float """ + + alpha = alpha/OBRIEN_FLEMING_DIVISION_FACTOR + """ + The following tests needed to be adjusted to take account of this correction: + - tests/tests_core/test_early_stopping.py:: + GroupSequentialTestCases:: + test_obrien_fleming + test_group_sequential + test_group_sequential_actual_size_larger_than_estimated + - tests/tests_core/test_experiment.py:: + StatisticalTestTestCases:: + test_group_sequential + """ return (1 - norm.cdf(norm.ppf(1 - alpha / 2) / np.sqrt(information_fraction))) * 2 diff --git a/expan/core/util.py b/expan/core/util.py index 2c9f003..bc04b4e 100644 --- a/expan/core/util.py +++ b/expan/core/util.py @@ -18,7 +18,7 @@ def __repr__(self): return self.toJson() -def find_value_by_key_with_condition(items, condition_key, condition_value, lookup_key): +def find_value_by_key_with_condition(items, condition_key, condition_value, lookup_key, tol=None): """ Find the value of lookup key where the dictionary contains condition key = condition value. 
:param items: list of dictionaries @@ -31,7 +31,10 @@ def find_value_by_key_with_condition(items, condition_key, condition_value, look :return: lookup value or found value for the lookup key """ - return [item[lookup_key] for item in items if item[condition_key] == condition_value][0] + if tol is None: + return [item[lookup_key] for item in items if item[condition_key] == condition_value][0] + else: + return [item[lookup_key] for item in items if abs(item[condition_key]-condition_value) < tol][0] def is_nan(obj): diff --git a/tests/tests_core/test_early_stopping.py b/tests/tests_core/test_early_stopping.py index 4211fe3..b0ea779 100644 --- a/tests/tests_core/test_early_stopping.py +++ b/tests/tests_core/test_early_stopping.py @@ -37,16 +37,16 @@ def test_obrien_fleming(self): """ Check the O'Brien-Fleming spending function.""" # Check array as input res_1 = es.obrien_fleming(np.linspace(0, 1, 5 + 1)[1:]) - expected_res = [1.17264468e-05, 1.94191300e-03, 1.13964185e-02, 2.84296308e-02, 5.00000000e-02] + expected_res = [7.1054274e-15,3.7219966e-08,7.0016877e-06,9.9583700e-05,5.0000000e-04] np.testing.assert_almost_equal(res_1, expected_res) # Check float as input res_2 = es.obrien_fleming(0.5) - self.assertAlmostEqual(res_2, 0.005574596680784305) + self.assertAlmostEqual(res_2, 8.5431190077756014e-07) # Check int as input res_3 = es.obrien_fleming(1) - self.assertAlmostEqual(res_3, 0.05) + self.assertAlmostEqual(res_3, 0.0005) def test_group_sequential(self): """ Check the group sequential function.""" @@ -60,10 +60,10 @@ def test_group_sequential(self): self.assertAlmostEqual(res.control_statistics.variance, 0.9373337542827797) self.assertAlmostEqual(res.delta, -0.15887364780635896) - value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5, 'value') - value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 97.5, 'value') - np.testing.assert_almost_equal(value025, -0.24461812530841959, decimal=5) - 
np.testing.assert_almost_equal(value975, -0.07312917030429833, decimal=5) + value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5) + value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5) + np.testing.assert_almost_equal(value025, -0.31130760395377599, decimal=5) + np.testing.assert_almost_equal(value975, -0.0064396916589367081, decimal=5) self.assertAlmostEqual(res.p, 0.0002863669955157941) self.assertAlmostEqual(res.statistical_power, 0.9529152504960496) @@ -75,10 +75,10 @@ def test_group_sequential_actual_size_larger_than_estimated(self): """ res = es.group_sequential(self.rand_s1, self.rand_s2, estimated_sample_size=100) - value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5, 'value') - value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 97.5, 'value') - np.testing.assert_almost_equal (value025, -0.24461812530841959, decimal=5) - np.testing.assert_almost_equal (value975, -0.07312917030429833, decimal=5) + value025 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', tol=1e-5) + value975 = find_value_by_key_with_condition(res.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', tol=1e-5) + np.testing.assert_almost_equal (value025, -0.31130760395377599, decimal=5) + np.testing.assert_almost_equal (value975, -0.00643969165893670, decimal=5) class BayesFactorTestCases(EarlyStoppingTestCase): diff --git a/tests/tests_core/test_experiment.py b/tests/tests_core/test_experiment.py index 65f6199..7fffaf6 100644 --- a/tests/tests_core/test_experiment.py +++ b/tests/tests_core/test_experiment.py @@ -8,6 +8,7 @@ from expan.core.results import CombinedTestStatistics from expan.core.statistical_test import * from expan.core.experiment import 
Experiment +import expan.core.early_stopping as es from expan.core.util import generate_random_data, find_value_by_key_with_condition @@ -129,10 +130,10 @@ def test_group_sequential(self): self.assertAlmostEqual(res.result.delta, 0.033053, ndecimals) - lower_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 2.5, 'value') - upper_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 97.5, 'value') - self.assertAlmostEqual(lower_bound_ci, -0.007135, ndecimals) - self.assertAlmostEqual(upper_bound_ci, 0.073240, ndecimals) + lower_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5) + upper_bound_ci = find_value_by_key_with_condition(res.result.confidence_interval, 'percentile', 100-2.5/es.OBRIEN_FLEMING_DIVISION_FACTOR, 'value', 1e-5) + self.assertAlmostEqual(lower_bound_ci, -0.0383319, ndecimals) + self.assertAlmostEqual(upper_bound_ci, 0.104437, ndecimals) self.assertEqual(res.result.treatment_statistics.sample_size, 6108) self.assertEqual(res.result.control_statistics.sample_size, 3892)