Skip to content

Commit

Permalink
Merge pull request #736 from awohns/msprime_kwargs
Browse files Browse the repository at this point in the history
msprime kwargs
  • Loading branch information
grahamgower committed Jan 25, 2021
2 parents 6774997 + e5a4ac4 commit c396794
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 11 deletions.
27 changes: 20 additions & 7 deletions stdpopsim/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,16 @@ class Engine:

def simulate(
self,
demographic_model=None,
contig=None,
samples=None,
demographic_model,
contig,
samples,
*,
seed=None,
dry_run=False,
):
"""
Simulates the model for the specified contig and samples.
Simulates the model for the specified contig and samples. ``demographic_model``,
``contig``, and ``samples`` must be specified.
:param demographic_model: The demographic model to simulate.
:type demographic_model: :class:`.DemographicModel`
Expand Down Expand Up @@ -126,13 +128,15 @@ class _MsprimeEngine(Engine):

def simulate(
self,
demographic_model=None,
contig=None,
samples=None,
demographic_model,
contig,
samples,
*,
seed=None,
msprime_model=None,
msprime_change_model=None,
dry_run=False,
**kwargs,
):
"""
Simulate the demographic model using msprime.
Expand All @@ -149,6 +153,7 @@ def simulate(
:param dry_run: If True, ``end_time=0`` is passed to :meth:`msprime.simulate()`
to initialise the simulation and then immediately return.
:type dry_run: bool
:param \\**kwargs: Further arguments passed to :meth:`msprime.simulate()`
"""
if msprime_model is None:
msprime_model = self.supported_models[0]
Expand All @@ -169,6 +174,13 @@ def simulate(
self.citations.extend(self.model_citations[model])
demographic_events.sort(key=lambda x: x.time)

if "random_seed" in kwargs.keys():
if seed is None:
seed = kwargs["random_seed"]
del kwargs["random_seed"]
else:
raise ValueError("Cannot set both seed and random_seed")

ts = msprime.simulate(
samples=samples,
recombination_map=contig.recombination_map,
Expand All @@ -179,6 +191,7 @@ def simulate(
random_seed=seed,
model=msprime_model,
end_time=0 if dry_run else None,
**kwargs,
)

if contig.inclusion_mask is not None:
Expand Down
7 changes: 4 additions & 3 deletions stdpopsim/slim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,9 +1013,10 @@ def get_version(self):

def simulate(
self,
demographic_model=None,
contig=None,
samples=None,
demographic_model,
contig,
samples,
*,
seed=None,
mutation_types=None,
extended_events=None,
Expand Down
35 changes: 34 additions & 1 deletion tests/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest

import stdpopsim
import msprime


class TestEngineAPI(unittest.TestCase):
Expand Down Expand Up @@ -49,7 +50,8 @@ def test_get_engine(self):

def test_abstract_base_class(self):
e = stdpopsim.Engine()
self.assertRaises(NotImplementedError, e.simulate)
with self.assertRaises(NotImplementedError):
e.simulate(None, None, None)
self.assertRaises(NotImplementedError, e.get_version)


Expand All @@ -68,3 +70,34 @@ def test_simulate_nonexistent_param(self):
engine.simulate(**good_kwargs)
with self.assertRaises(TypeError):
engine.simulate(**bad_kwargs)

def test_required_params(self):
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AshkSub_7G19")
contig = (species.get_contig("chr1"),)
for engine in stdpopsim.all_engines():
with self.assertRaises(TypeError):
engine.simulate(model, contig)

def test_msprime_kwargs(self):
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AshkSub_7G19")
contig = species.get_contig("chr22", length_multiplier=0.01)
samples = model.get_samples(10)
engine = stdpopsim.get_engine("msprime")
sim_arg = engine.simulate(
model, contig, samples, record_full_arg=True, random_seed=1
)
assert any(msprime.NODE_IS_RE_EVENT == sim_arg.tables.nodes.flags)

def test_msprime_seed(self):
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AshkSub_7G19")
contig = species.get_contig("chr22", length_multiplier=0.01)
samples = model.get_samples(10)
engine = stdpopsim.get_engine("msprime")
with self.assertRaises(ValueError):
engine.simulate(model, contig, samples, seed=1, random_seed=1)
sim_seed = engine.simulate(model, contig, samples, seed=1)
sim_random_seed = engine.simulate(model, contig, samples, random_seed=1)
self.assertEquals(sim_seed.tables.edges, sim_random_seed.tables.edges)

0 comments on commit c396794

Please sign in to comment.