Skip to content

Commit

Permalink
Merge pull request #61 from mcveanlab/fix-all-models
Browse files Browse the repository at this point in the history
Refine all_models to only return from within stdpopsim.
  • Loading branch information
jeromekelleher committed Apr 13, 2019
2 parents c5ad81a + b966eee commit e44b759
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
11 changes: 9 additions & 2 deletions stdpopsim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Common infrastructure for specifying demographic models.
"""
import sys
import inspect

import msprime
import numpy as np
Expand Down Expand Up @@ -172,6 +173,12 @@ def verify_equal(self, other, rtol=DEFAULT_RTOL, atol=DEFAULT_ATOL):

def all_models():
"""
Returns the list of all Model classes that have been defined.
Returns the list of all Model classes that are defined within the stdpopsim
module.
"""
return [cls() for cls in Model.__subclasses__()]
ret = []
for cls in Model.__subclasses__():
mod = inspect.getmodule(cls).__name__
if mod.startswith("stdpopsim"):
ret.append(cls())
return ret
12 changes: 12 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def f(time=1, population=1, proportion=1):
models.verify_demographic_events_equal([a], [b], 1)


class DummyModel(models.Model):
"""
Dummy subclass to make sure we're filtering models correctly.
"""


class TestAllModels(unittest.TestCase):
"""
Tests that we can get all known simulation models.
Expand All @@ -237,12 +243,18 @@ def test_all_instances(self):
for model in models.all_models():
self.assertIsInstance(model, models.Model)

def test_filtering_outside_classes(self):
for model in models.all_models():
self.assertNotIsInstance(model, DummyModel)


class TestModelsEqual(unittest.TestCase):
"""
Tests Model object equality comparison.
"""
def test_known_models(self):
# This assumes that every model should be equal to itself and should be
# different to every other model.
known_models = models.all_models()
n = len(known_models)
for j in range(n):
Expand Down

0 comments on commit e44b759

Please sign in to comment.