diff --git a/stdpopsim/slim_engine.py b/stdpopsim/slim_engine.py index 71a566d78..5086b00bc 100644 --- a/stdpopsim/slim_engine.py +++ b/stdpopsim/slim_engine.py @@ -40,6 +40,7 @@ import os import sys +import contextlib import copy import string import tempfile @@ -1474,48 +1475,52 @@ def simulate( tempdir = tempfile.TemporaryDirectory(prefix="stdpopsim_") ts_filename = os.path.join(tempdir.name, f"{os.urandom(3).hex()}.trees") - if slim_script: - script_filename = "stdout" - script_file = sys.stdout - else: - script_filename = os.path.join(tempdir.name, f"{os.urandom(3).hex()}.slim") - script_file = open(script_filename, "w") - - recap_epoch = slim_makescript( - script_file, - ts_filename, - demographic_model, - contig, - samples, - extended_events, - slim_scaling_factor, - slim_burn_in, - slim_rate_map, - logfile=logfile, - logfile_interval=logfile_interval, - ) - - script_file.flush() + @contextlib.contextmanager + def script_file_f(): + if run_slim: + fname = os.path.join(tempdir.name, f"{os.urandom(3).hex()}.slim") + f = open(fname, "w") + else: + fname = "stdout" + f = sys.stdout + yield f, fname + # Don't close sys.stdout. + if run_slim: + f.close() + + with script_file_f() as sf: + script_file, script_filename = sf + recap_epoch = slim_makescript( + script_file, + ts_filename, + demographic_model, + contig, + samples, + extended_events, + slim_scaling_factor, + slim_burn_in, + slim_rate_map, + logfile=logfile, + logfile_interval=logfile_interval, + ) - # Don't close sys.stdout. - if not slim_script: - script_file.close() + script_file.flush() - if not run_slim: - return None + if not run_slim: + return None - self._run_slim( - script_filename, - slim_path=slim_path, - seed=seed, - dry_run=dry_run, - verbosity=verbosity, - ) + self._run_slim( + script_filename, + slim_path=slim_path, + seed=seed, + dry_run=dry_run, + verbosity=verbosity, + ) - if dry_run: - return None + if dry_run: + return None - ts = tskit.load(ts_filename) + ts = tskit.load(ts_filename) ts = _add_dfes_to_metadata(ts, contig) if _recap_and_rescale: