Skip to content

Commit

Permalink
Add new dependence function, powerdecrease3.
Browse files Browse the repository at this point in the history
  • Loading branch information
ahaselsteiner committed Jan 9, 2020
1 parent 445d789 commit 286cff7
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 6 deletions.
38 changes: 38 additions & 0 deletions tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,44 @@ def test_fit_lnsquare2(self):
self.assertLess(dist1.scale(0), 10)
self.assertEqual(dist1.scale.func_name, 'lnsquare2')

def test_fit_lnsquare2(self):
"""
Tests a 2D fit that includes an logarithm square dependence function.
"""

sample_hs, sample_tz, label_hs, label_tz = read_benchmark_dataset()


# Define the structure of the probabilistic model that will be fitted to the
# dataset.
dist_description_hs = {'name': 'Weibull_Exp',
'dependency': (None, None, None, None),
# Shape, Location, Scale, Shape2
'width_of_intervals': 0.5}
dist_description_tz = {'name': 'Lognormal_SigmaMu',
'dependency': (0, None, 0),
# Shape, Location, Scale
'functions': ('powerdecrease3', None, 'lnsquare2')
# Shape, Location, Scale
}

# Fit the model to the data.
fit = Fit((sample_hs, sample_tz),
(dist_description_hs, dist_description_tz))


# Check whether the logarithmic square fit worked correctly.
dist1 = fit.mul_var_dist.distributions[1]
self.assertGreater(dist1.shape.a, -0.1) # Should be about 0
self.assertLess(dist1.shape.a, 0.1) # Should be about 0
self.assertGreater(dist1.shape.b, 1.5) # Should be about 2-5
self.assertLess(dist1.shape.b, 6) # Should be about 2-10
self.assertGreater(dist1.shape.c, 0.8) # Should be about 1.1
self.assertLess(dist1.shape.c, 2) # Should be about 1.1
self.assertGreater(dist1.shape(0), 0.25) # Should be about 0.35
self.assertLess(dist1.shape(0), 0.4) # Should be about 0.35
self.assertEqual(dist1.shape.func_name, 'powerdecrease3')

def test_min_number_datapoints_for_fit(self):
"""
Tests if the minimum number of datapoints required for a fit works.
Expand Down
13 changes: 11 additions & 2 deletions tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,30 @@ def test_FunctionParam_power3(self):

def test_FunctionParam_exp3(self):
"""
tests if function exp3 calculates the correct value
tests if function exp3 calculates the correct value.
"""

test_func = FunctionParam(1, 1, 0, 'exp3')
self.assertEqual(test_func._value(0), 2)

def test_FunctionParam_lnsquare2(self):
"""
Tests if function lnsquare2 calculates the correct value
Tests if function lnsquare2 calculates the correct value.
"""

test_func = FunctionParam(1, 1, None, 'lnsquare2')
self.assertEqual(test_func._value(0), 0)
self.assertEqual(test_func.func_name, 'lnsquare2')

def test_FunctionParam_powerdecrease3(self):
"""
Tests if function powerdecrease3 calculates the correct value.
"""

test_func = FunctionParam(1, 2, 2, 'powerdecrease3')
self.assertEqual(test_func._value(0), 1.25)
self.assertEqual(test_func.func_name, 'powerdecrease3')


def test_FunctionParam_unknown(self):
"""
Expand Down
10 changes: 9 additions & 1 deletion viroconcom/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def _lnsquare2(x, a, b, c):
return np.log(a + b * np.sqrt(np.divide(x, 9.81)))


# Function that decreases with x to the power of c.
def _powerdecrease3(x, a, b, c):
return a + 1 / (x + b) ** c


# Bounds for function parameters
# 0 < a < inf
# 0 < b < inf
Expand Down Expand Up @@ -526,6 +531,7 @@ def __init__(self, samples, dist_descriptions, timeout=None):
- :power3: :math:`a + b * x^c`
- :exp3: :math:`a + b * e^{x * c}`
- :lnsquare2: :math:`ln[a + b * sqrt(x / 9.81)`
- :powerdecrease3: :math:`a + 1 / (x + b)^c`
- remark : in case of Lognormal_SigmaMu it is (sigma, None, mu)
and either number_of_intervals or width_of_intervals:
Expand Down Expand Up @@ -710,7 +716,7 @@ def _get_function(function_name):
Parameters
----------
function_name : str
Options are 'power3', 'exp3', 'lnsquare2'.
Options are 'power3', 'exp3', 'lnsquare2', 'powerdecrease3'.
Returns
-------
Expand All @@ -729,6 +735,8 @@ def _get_function(function_name):
return _exp3
elif function_name == 'lnsquare2':
return _lnsquare2
elif function_name == 'powerdecrease3':
return _powerdecrease3
elif function_name is None:
return None
else:
Expand Down
14 changes: 12 additions & 2 deletions viroconcom/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, a, b, c, func_type, wrapper=None):
:power3: :math:`a + b * x^c`
:exp3: :math:`a + b * e^{x * c}`
:lnsquare2: :math:`ln[a + b * sqrt(x / 9.81)`
:powerdecrease3: :math:`a + 1 / (x + b)^c`
wrapper : function or Wrapper
A function or a Wrapper object to wrap around the function.
The function has to be pickleable. (i.e. lambdas, clojures, etc. are not supported.)
Expand All @@ -105,6 +106,9 @@ def __init__(self, a, b, c, func_type, wrapper=None):
elif func_type == "lnsquare2":
self._func = self._lnsquare2
self.func_name = "lnsquare2"
elif func_type == "powerdecrease3":
self._func = self._powerdecrease3
self.func_name = "powerdecrease3"
else:
raise ValueError("{} is not a known kind of function.".format(func_type))

Expand Down Expand Up @@ -133,16 +137,22 @@ def _exp3(self, x):
def _lnsquare2(self, x):
return np.log(self.a + self.b * np.sqrt(np.divide(x, 9.81)))

# The 3-parameter decreasing power function (a dependence function).
def _powerdecrease3(self, x):
return self.a + 1.0 / (x + self.b) ** self.c

def _value(self, x):
return self._wrapper(self._func(x))

def __str__(self):
if self.func_name == "power3":
function_string = "" + str(self.a) + "+" + str(self.b) + "x" + "^{" + str(self.c) + "}"
function_string = "" + str(self.a) + " + " + str(self.b) + "x" + "^{" + str(self.c) + "}"
elif self.func_name == "exp3":
function_string = "" + str(self.a) + "+" + str(self.b) + "e^{" + str(self.c) + "x}"
function_string = "" + str(self.a) + " + " + str(self.b) + "e^{" + str(self.c) + "x}"
elif self.func_name == "lnsquare2":
function_string = "ln[" + str(self.a) + " + " + str(self.b) + "sqrt(x / 9.81]"
elif self.func_name == "powerdecrease3":
function_string = "" + str(self.a) + " + 1 / (x + " + str(self.b) + ")^" + str(self.c)
if isinstance(self._wrapper.func, np.ufunc):
function_string += " with _wrapper: " + str(self._wrapper)
return function_string
Expand Down
2 changes: 1 addition & 1 deletion viroconcom/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.2.10'
__version__ = '1.2.11'

0 comments on commit 286cff7

Please sign in to comment.