Skip to content

Commit

Permalink
Enable users to control the required minimum of datapoints for a fit.
Browse files Browse the repository at this point in the history
  • Loading branch information
ahaselsteiner committed Jan 9, 2020
1 parent b62b8b3 commit 445d789
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 8 deletions.
49 changes: 49 additions & 0 deletions tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,55 @@ def test_fit_lnsquare2(self):
self.assertLess(dist1.scale(0), 10)
self.assertEqual(dist1.scale.func_name, 'lnsquare2')

def test_min_number_datapoints_for_fit(self):
    """
    Tests if the minimum number of datapoints required for a fit works.

    Fits the same model twice, once requiring at least 10 datapoints per
    interval and once requiring at least 500, and checks that the two runs
    yield different dependence-function coefficients.
    """
    sample_hs, sample_tz, label_hs, label_tz = read_benchmark_dataset()

    # Structure of the probabilistic model that will be fitted to the
    # dataset. The Hs description is shared by both fits.
    dist_description_hs = {'name': 'Weibull_Exp',
                           'dependency': (None, None, None, None),
                           # Shape, Location, Scale, Shape2
                           'width_of_intervals': 0.5}

    fitted_a_values = []
    for min_points in (10, 500):
        dist_description_tz = {'name': 'Lognormal_SigmaMu',
                               'dependency': (0, None, 0),
                               # Shape, Location, Scale
                               'functions': ('exp3', None, 'lnsquare2'),
                               # Shape, Location, Scale
                               'min_datapoints_for_fit': min_points}

        # Fit the model to the data.
        fit = Fit((sample_hs, sample_tz),
                  (dist_description_hs, dist_description_tz))

        # Record the logarithmic-square coefficient of the Tz scale.
        dist_tz = fit.mul_var_dist.distributions[1]
        fitted_a_values.append(dist_tz.scale.a)

    # Requiring more datapoints per interval leaves fewer usable bins, so
    # the fitted dependence-function coefficients should differ.
    self.assertNotEqual(fitted_a_values[0], fitted_a_values[1])

def test_multi_processing(selfs):
"""
Expand Down
19 changes: 12 additions & 7 deletions viroconcom/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,8 @@ def _append_params(name, param_values, dependency, index, sample):

@staticmethod
def _get_fitting_values(sample, samples, name, dependency, index,
number_of_intervals=None, bin_width=None):
number_of_intervals=None, bin_width=None,
min_datapoints_for_fit=20):
"""
Returns values for fitting.
Expand All @@ -799,6 +800,8 @@ def _get_fitting_values(sample, samples, name, dependency, index,
Order : (shape, loc, scale) (i.e. 0 -> shape).
number_of_intervals : int
Number of distributions used to fit shape, loc, scale.
min_datapoints_for_fit : int
Minimum number of datapoints required to perform the fit.
Notes
-----
For that case that number_of_intervals and also bin_width is given the parameter
Expand All @@ -824,7 +827,6 @@ def _get_fitting_values(sample, samples, name, dependency, index,
RuntimeError
If there was not enough data and the number of intervals was less than three.
"""
MIN_DATA_POINTS_FOR_FIT = 10

# Compute intervals.
if number_of_intervals:
Expand Down Expand Up @@ -863,7 +865,7 @@ def _get_fitting_values(sample, samples, name, dependency, index,
mask = ((sorted_samples[:, 1] >= step - 0.5 * interval_width) &
(sorted_samples[:, 1] < step + 0.5 * interval_width))
samples_in_interval = sorted_samples[mask, 0]
if len(samples_in_interval) >= MIN_DATA_POINTS_FOR_FIT:
if len(samples_in_interval) >= min_datapoints_for_fit:
try:
# Fit distribution to selected data.
basic_fit = Fit._append_params(
Expand All @@ -882,12 +884,12 @@ def _get_fitting_values(sample, samples, name, dependency, index,
# the step is deleted.
deleted_centers.append(i) # Add index of unused center.
warnings.warn(
"'Due to the restriction of MIN_DATA_POINTS_FOR_FIT='{}' "
"'Due to the restriction of min_datapoints_for_fit='{}' "
"there is not enough data (n='{}') for the interval "
"centered at '{}' in dimension '{}'. No distribution will "
"be fitted to this interval. Consider adjusting your "
"intervals."
.format(MIN_DATA_POINTS_FOR_FIT,
.format(min_datapoints_for_fit,
len(samples_in_interval),
step,
dependency[index]),
Expand Down Expand Up @@ -949,6 +951,7 @@ def _get_distribution(self, dimension, samples, **kwargs):
functions = kwargs.get('functions', ('polynomial', )*len(dependency))
list_number_of_intervals = kwargs.get('list_number_of_intervals')
list_width_of_intervals = kwargs.get('list_width_of_intervals')
min_datapoints_for_fit = kwargs.get('min_datapoints_for_fit', 20)

# Fit inspection data for current dimension
fit_inspection_data = FitInspectionData()
Expand Down Expand Up @@ -1008,13 +1011,15 @@ def _get_distribution(self, dimension, samples, **kwargs):
interval_centers, dist_values, param_values, multiple_basic_fit = \
Fit._get_fitting_values(
sample, samples, name, dependency, index,
number_of_intervals=list_number_of_intervals[dependency[index]])
number_of_intervals=list_number_of_intervals[dependency[index]],
min_datapoints_for_fit=min_datapoints_for_fit)
# If a the (constant) width of the intervals is given.
elif list_width_of_intervals[dependency[index]]:
interval_centers, dist_values, param_values, multiple_basic_fit = \
Fit._get_fitting_values(
sample, samples, name, dependency, index,
bin_width=list_width_of_intervals[dependency[index]])
bin_width=list_width_of_intervals[dependency[index]],
min_datapoints_for_fit=min_datapoints_for_fit)

for i in range(index, len(functions)):
# Check if the other parameters have the same dependency
Expand Down
2 changes: 1 addition & 1 deletion viroconcom/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.2.9'
__version__ = '1.2.10'

0 comments on commit 445d789

Please sign in to comment.