Skip to content

Commit

Permalink
Merge pull request #56 from mcveanlab/check_equal
Browse files Browse the repository at this point in the history
Method to raise an exception if models are different. We got two approvals, I'm merging the PR.
  • Loading branch information
andrewkern committed Apr 10, 2019
2 parents f4a4953 + 16ed51a commit c5ad81a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 30 deletions.
106 changes: 76 additions & 30 deletions stdpopsim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,41 @@
DEFAULT_RTOL = 1e-08


class UnequalModelsError(Exception):
"""
Exception raised models by verify_equal to indicate that models are
not sufficiently close.
"""


def population_configurations_equal(
pop_configs1, pop_configs2, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL):
"""
Returns True if the specified lists of msprime PopulationConfiguration
objects are equal to the specified tolerances.
See the :func:`.verify_population_configurations_equal` function for
details on the assumptions made about the objects.
"""
try:
verify_population_configurations_equal(
pop_configs1, pop_configs2, rtol=rtol, atol=atol)
return True
except UnequalModelsError:
return False


def verify_population_configurations_equal(
pop_configs1, pop_configs2, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL):
"""
Checks if the specified lists of msprime PopulationConfiguration
objects are equal to the specified tolerances and raises an UnequalModelsError
otherwise.
We make some assumptions here to ensure that the models we specify
are well-defined: (1) The sample size is not set for PopulationConfigurations
(2) the initial_size is defined.
(2) the initial_size is defined. If these assumptions are violated a
ValueError is raised.
"""
for pc1, pc2 in zip(pop_configs1, pop_configs2):
if pc1.sample_size is not None or pc2.sample_size is not None:
Expand All @@ -30,17 +56,16 @@ def population_configurations_equal(
if pc1.initial_size is None or pc2.initial_size is None:
raise ValueError(
"Models defined in stdpopsim must set the initial_size")
sample_size1 = np.array([pc.sample_size for pc in pop_configs1])
sample_size2 = np.array([pc.sample_size for pc in pop_configs2])
if len(pop_configs1) != len(pop_configs2):
raise UnequalModelsError("Different numbers of populations")
initial_size1 = np.array([pc.initial_size for pc in pop_configs1])
initial_size2 = np.array([pc.initial_size for pc in pop_configs2])
if not np.allclose(initial_size1, initial_size2, rtol=rtol, atol=atol):
raise UnequalModelsError("Initial sizes differ")
growth_rate1 = np.array([pc.growth_rate for pc in pop_configs1])
growth_rate2 = np.array([pc.growth_rate for pc in pop_configs2])
return (
len(pop_configs1) == len(pop_configs2)
and np.all(sample_size1 == sample_size2)
and np.allclose(initial_size1, initial_size2, rtol=rtol, atol=atol)
and np.allclose(growth_rate1, growth_rate2, rtol=rtol, atol=atol))
if not np.allclose(growth_rate1, growth_rate2, rtol=rtol, atol=atol):
raise UnequalModelsError("Growth rates differ")


def demographic_events_equal(
Expand All @@ -49,24 +74,39 @@ def demographic_events_equal(
Returns True if the specified list of msprime DemographicEvent objects are equal
to the specified tolerances.
"""
try:
verify_demographic_events_equal(
events1, events2, num_populations, rtol=rtol, atol=atol)
return True
except UnequalModelsError:
return False


def verify_demographic_events_equal(
events1, events2, num_populations, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL):
"""
Checks if the specified list of msprime DemographicEvent objects are equal
to the specified tolerances and raises a UnequalModelsError otherwise.
"""
# Get the low-level dictionary representations of the events.
dicts1 = [event.get_ll_representation(num_populations) for event in events1]
dicts2 = [event.get_ll_representation(num_populations) for event in events2]
if len(dicts1) != len(dicts2):
return False
raise UnequalModelsError("Different numbers of demographic events")
for d1, d2 in zip(dicts1, dicts2):
if set(d1.keys()) != set(d2.keys()):
return False
raise UnequalModelsError("Different types of demographic events")
for key in d1.keys():
value1 = d1[key]
value2 = d2[key]
if isinstance(value1, float):
if not np.isclose(value1, value2, rtol=rtol, atol=atol):
return False
raise UnequalModelsError("Event {} mismatch: {} != {}".format(
key, value1, value2))
else:
if value1 != value2:
return False
return True
raise UnequalModelsError("Event {} mismatch: {} != {}".format(
key, value1, value2))


class Model(object):
Expand Down Expand Up @@ -104,24 +144,30 @@ def equals(self, other, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL):
We use the 'equals' method here rather than the equality operator
because we need to be able to specifiy the numerical tolerances.
"""
ret = False
try:
mm1 = np.array(self.migration_matrix)
mm2 = np.array(other.migration_matrix)
ret = (
mm1.shape == mm2.shape
and np.allclose(mm1, mm2, rtol=rtol, atol=atol)
and population_configurations_equal(
self.population_configurations, other.population_configurations,
rtol=rtol, atol=atol)
and demographic_events_equal(
self.demographic_events, other.demographic_events,
len(self.population_configurations),
rtol=rtol, atol=atol))
except AttributeError:
# Anything that's not duck-typeable to a Model is considered not equal
pass
return ret
self.verify_equal(other, rtol=rtol, atol=atol)
return True
except (UnequalModelsError, AttributeError):
return False

def verify_equal(self, other, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL):
"""
Equivalent to the :func:`.equals` method, but raises a UnequalModelsError if the
models are not equal rather than returning False.
"""
mm1 = np.array(self.migration_matrix)
mm2 = np.array(other.migration_matrix)
if mm1.shape != mm2.shape:
raise UnequalModelsError("Migration matrices different shapes")
if not np.allclose(mm1, mm2, rtol=rtol, atol=atol):
raise UnequalModelsError("Migration matrices differ")
verify_population_configurations_equal(
self.population_configurations, other.population_configurations,
rtol=rtol, atol=atol)
verify_demographic_events_equal(
self.demographic_events, other.demographic_events,
len(self.population_configurations),
rtol=rtol, atol=atol)


def all_models():
Expand Down
32 changes: 32 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def test_different_lengths(self):
self.assertFalse(models.population_configurations_equal([pc, pc], [pc]))
self.assertFalse(models.population_configurations_equal([pc], [pc, pc]))
self.assertTrue(models.population_configurations_equal([pc], [pc]))
with self.assertRaises(models.UnequalModelsError):
models.verify_population_configurations_equal([pc], [pc, pc])

def test_initial_sizes(self):
test_sizes = [
Expand All @@ -62,6 +64,8 @@ def test_initial_sizes(self):
self.assertFalse(models.population_configurations_equal(pc_list2, pc_list1))
self.assertTrue(models.population_configurations_equal(pc_list1, pc_list1))
self.assertTrue(models.population_configurations_equal(pc_list2, pc_list2))
with self.assertRaises(models.UnequalModelsError):
models.verify_population_configurations_equal(pc_list2, pc_list1)

def test_growth_rates(self):
test_rates = [
Expand All @@ -81,6 +85,8 @@ def test_growth_rates(self):
self.assertFalse(models.population_configurations_equal(pc_list2, pc_list1))
self.assertTrue(models.population_configurations_equal(pc_list1, pc_list1))
self.assertTrue(models.population_configurations_equal(pc_list2, pc_list2))
with self.assertRaises(models.UnequalModelsError):
models.verify_population_configurations_equal(pc_list2, pc_list1)


class TestDemographicEventsEqual(unittest.TestCase):
Expand All @@ -96,6 +102,8 @@ def test_different_lengths(self):
self.assertFalse(models.demographic_events_equal([], events[:1], 1))
self.assertFalse(models.demographic_events_equal(events, [], 1))
self.assertFalse(models.demographic_events_equal([], events, 1))
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([], events, 1)

def test_different_times(self):
n = 10
Expand All @@ -108,6 +116,10 @@ def test_different_times(self):
for j in range(1, n):
self.assertFalse(models.demographic_events_equal(e1[:j], e2[:j], 1))
self.assertFalse(models.demographic_events_equal(e2[:j], e1[:j], 1))
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal(e1[:j], e2[:j], 1)
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal(e2[:j], e1[:j], 1)

def test_different_types(self):
events = [
Expand All @@ -120,6 +132,10 @@ def test_different_types(self):
self.assertFalse(models.demographic_events_equal([b], [a], 1))
self.assertTrue(models.demographic_events_equal([a], [a], 1))
self.assertTrue(models.demographic_events_equal([b], [b], 1))
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([b], [a], 1)
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([a], [b], 1)

def test_population_parameters_change(self):

Expand All @@ -139,6 +155,10 @@ def f(time=1, initial_size=1, growth_rate=None, population=None):
self.assertFalse(models.demographic_events_equal([b], [a], 1))
self.assertTrue(models.demographic_events_equal([a], [a], 1))
self.assertTrue(models.demographic_events_equal([b], [b], 1))
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([b], [a], 1)
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([a], [b], 1)

def test_migration_rate_change(self):

Expand All @@ -157,6 +177,10 @@ def f(time=1, rate=1, matrix_index=None):
self.assertFalse(models.demographic_events_equal([b], [a], 1))
self.assertTrue(models.demographic_events_equal([a], [a], 1))
self.assertTrue(models.demographic_events_equal([b], [b], 1))
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([b], [a], 1)
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([a], [b], 1)

def test_mass_migration(self):

Expand All @@ -175,6 +199,10 @@ def f(time=1, source=1, dest=1, proportion=1):
self.assertFalse(models.demographic_events_equal([b], [a], 1))
self.assertTrue(models.demographic_events_equal([a], [a], 1))
self.assertTrue(models.demographic_events_equal([b], [b], 1))
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([b], [a], 1)
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([a], [b], 1)

def test_simple_bottleneck(self):

Expand All @@ -192,6 +220,10 @@ def f(time=1, population=1, proportion=1):
self.assertFalse(models.demographic_events_equal([b], [a], 1))
self.assertTrue(models.demographic_events_equal([a], [a], 1))
self.assertTrue(models.demographic_events_equal([b], [b], 1))
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([b], [a], 1)
with self.assertRaises(models.UnequalModelsError):
models.verify_demographic_events_equal([a], [b], 1)


class TestAllModels(unittest.TestCase):
Expand Down

0 comments on commit c5ad81a

Please sign in to comment.