Skip to content

Commit

Permalink
Merge b34c6c5 into 9117b1e
Browse files Browse the repository at this point in the history
  • Loading branch information
ahaselsteiner committed Jan 9, 2020
2 parents 9117b1e + b34c6c5 commit d64ce7e
Show file tree
Hide file tree
Showing 6 changed files with 8,731 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/fit_exponentiated_weibull_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from viroconcom.contours import HighestDensityContour


def read_dataset(path='examples/datasets/A.txt'):
def read_benchmark_dataset(path='examples/datasets/A.txt'):
"""
Reads a datasets provided for the environmental contour benchmark.
Parameters
Expand Down Expand Up @@ -45,7 +45,7 @@ def read_dataset(path='examples/datasets/A.txt'):
y = np.asarray(y)
return (x, y, x_label, y_label)

sample_hs, sample_tz, label_hs, label_tz = read_dataset()
sample_hs, sample_tz, label_hs, label_tz = read_benchmark_dataset()

# Define the structure of the probabilistic model that will be fitted to the
# dataset.
Expand Down
82 changes: 81 additions & 1 deletion tests/test_fitting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
import unittest

import csv
import numpy as np

from viroconcom.fitting import Fit


def read_benchmark_dataset(path='tests/testfiles/1year_dataset_A.txt'):
"""
Reads a datasets provided for the environmental contour benchmark.
Parameters
----------
path : string
Path to dataset including the file name, defaults to 'examples/datasets/A.txt'
Returns
-------
x : ndarray of doubles
Observations of the environmental variable 1.
y : ndarray of doubles
Observations of the environmental variable 2.
x_label : str
Label of the environmantal variable 1.
y_label : str
Label of the environmental variable 2.
"""

x = list()
y = list()
x_label = None
y_label = None
with open(path, newline='') as csv_file:
reader = csv.reader(csv_file, delimiter=';')
idx = 0
for row in reader:
if idx == 0:
x_label = row[1][
1:] # Ignore first char (is a white space).
y_label = row[2][
1:] # Ignore first char (is a white space).
if idx > 0: # Ignore the header
x.append(float(row[1]))
y.append(float(row[2]))
idx = idx + 1

x = np.asarray(x)
y = np.asarray(y)
return (x, y, x_label, y_label)


class FittingTest(unittest.TestCase):

def test_2d_fit(self):
Expand Down Expand Up @@ -101,6 +144,43 @@ def test_2d_exponentiated_wbl_fit(self):
self.assertGreater(dist0.shape2(0), 0.5) # Should be about 1.
self.assertLess(dist0.shape2(0), 2)


def test_fit_lnsquare2(self):
"""
Tests a 2D fit that includes an logarithm square dependence function.
"""

sample_hs, sample_tz, label_hs, label_tz = read_benchmark_dataset()


# Define the structure of the probabilistic model that will be fitted to the
# dataset.
dist_description_hs = {'name': 'Weibull_Exp',
'dependency': (None, None, None, None),
# Shape, Location, Scale, Shape2
'width_of_intervals': 0.5}
dist_description_tz = {'name': 'Lognormal_SigmaMu',
'dependency': (0, None, 0),
# Shape, Location, Scale
'functions': ('exp3', None, 'lnsquare2')
# Shape, Location, Scale
}

# Fit the model to the data.
fit = Fit((sample_hs, sample_tz),
(dist_description_hs, dist_description_tz))


# Check whether the logarithmic square fit worked correctly.
dist1 = fit.mul_var_dist.distributions[1]
self.assertGreater(dist1.scale.a, 1) # Should be about 1-5
self.assertLess(dist1.scale.a, 5) # Should be about 1-5
self.assertGreater(dist1.scale.b, 2) # Should be about 2-10
self.assertLess(dist1.scale.b, 10) # Should be about 2-10
self.assertGreater(dist1.scale(0), 0.1)
self.assertLess(dist1.scale(0), 10)


def test_multi_processing(selfs):
"""
2-d Fit with multiprocessing (specified by setting a value for timeout)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ def test_FunctionParam_exp3(self):
test_func = FunctionParam(1, 1, 0, 'exp3')
self.assertEqual(test_func._value(0), 2)

def test_FunctionParam_lnsquare2(self):
"""
Tests if function lnsquare2 calculates the correct value
"""

test_func = FunctionParam(1, 1, None, 'lnsquare2')
self.assertEqual(test_func._value(0), 0)


def test_FunctionParam_unknown(self):
"""
Expand Down

0 comments on commit d64ce7e

Please sign in to comment.