Skip to content

Commit

Permalink
Add provenance output for CLI
Browse files Browse the repository at this point in the history
Closes #730
  • Loading branch information
jeromekelleher authored and mergify[bot] committed Oct 13, 2022
1 parent baf5dd5 commit 008f811
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
49 changes: 49 additions & 0 deletions tests/test_cli.py
Expand Up @@ -20,6 +20,7 @@
Tests for the tsinfer CLI.
"""
import io
import json
import os.path
import pathlib
import sys
Expand Down Expand Up @@ -280,6 +281,54 @@ def test_augment_ancestors(self):
)


class TestProvenance(TestCli):
    """
    Checks that tree sequences written by the CLI carry a provenance record.
    """

    # setup_logging is patched out for every invocation; otherwise logging
    # gets spewed to the console in later tests.
    @mock.patch("tsinfer.cli.setup_logging")
    def run_command(self, command, mock_setup_logging):
        """
        Run the tsinfer CLI with the given argument list and require that
        nothing is written to either stderr or stdout.
        """
        captured_out, captured_err = capture_output(cli.tsinfer_main, command)
        assert captured_err == ""
        assert captured_out == ""

    def verify_ts_provenance(self, treefile):
        """
        Load ``treefile`` and check the shape of the parameters stored in
        its last provenance record.
        """
        last_provenance = tskit.load(treefile).provenance(-1)
        parameters = json.loads(last_provenance.record)["parameters"]
        # Only the types are checked: the concrete values come from the
        # pytest command line, so comparing them exactly is problematic.
        assert isinstance(parameters["command"], str)
        assert isinstance(parameters["args"], list)

    def test_infer(self):
        """
        The output of the ``infer`` subcommand has provenance.
        """
        output_trees = os.path.join(self.tempdir.name, "output.trees")
        infer_args = ["infer", self.sample_file, "-O", output_trees]
        self.run_command(infer_args)
        self.verify_ts_provenance(output_trees)

    @pytest.mark.skipif(
        sys.platform == "win32", reason="windows simultaneous file permissions issue"
    )
    def test_chain(self):
        """
        Both tree sequence files written by the generate-ancestors,
        match-ancestors, match-samples chain have provenance.
        """
        tempdir = self.tempdir.name
        output_trees = os.path.join(tempdir, "output.trees")
        ancestors_trees = os.path.join(tempdir, "ancestors.trees")

        self.run_command(["generate-ancestors", self.sample_file])

        match_ancestors_args = ["match-ancestors", self.sample_file]
        match_ancestors_args += ["-A", ancestors_trees]
        self.run_command(match_ancestors_args)
        self.verify_ts_provenance(ancestors_trees)

        match_samples_args = ["match-samples", self.sample_file]
        match_samples_args += ["-A", ancestors_trees, "-O", output_trees]
        self.run_command(match_samples_args)
        self.verify_ts_provenance(output_trees)


class TestMatchSamples(TestCli):
"""
Tests for the match samples options.
Expand Down
28 changes: 21 additions & 7 deletions tsinfer/cli.py
@@ -1,5 +1,5 @@
#
# Copyright (C) 2018 University of Oxford
# Copyright (C) 2018-2022 University of Oxford
#
# This file is part of tsinfer.
#
Expand All @@ -20,6 +20,7 @@
Command line interfaces to tsinfer.
"""
import argparse
import json
import logging
import math
import os.path
Expand All @@ -38,6 +39,7 @@

import tsinfer
import tsinfer.exceptions as exceptions
import tsinfer.provenance as provenance


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -120,6 +122,21 @@ def run_list(args):
summarise_tree_sequence(ts)


def write_ts(ts, path):
    """
    Write the tree sequence ``ts`` to ``path`` with an extra provenance
    record appended.

    The record is built by ``provenance.get_provenance_dict`` from
    ``sys.argv``: the first element is passed as the command and the
    remaining elements as its arguments.
    """
    logger.info(f"Writing output tree sequence to {path}")
    output_tables = ts.dump_tables()
    # Following guidance at
    # https://tskit.dev/tskit/docs/stable/provenance.html#cli-invocations
    cli_record = provenance.get_provenance_dict(
        command=sys.argv[0], args=sys.argv[1:]
    )
    encoded_record = json.dumps(cli_record)
    output_tables.provenances.add_row(encoded_record)
    # The tables are dumped directly so that no new TS object has to be
    # created; check that they are indexed before writing.
    assert output_tables.has_index()
    output_tables.dump(path)


def run_infer(args):
setup_logging(args)
try:
Expand All @@ -139,8 +156,7 @@ def run_infer(args):
sample_data, progress_monitor=args.progress, num_threads=args.num_threads
)
output_trees = get_output_trees_path(args.output_trees, args.samples)
logger.info(f"Writing output tree sequence to {output_trees}")
ts.dump(output_trees)
write_ts(ts, output_trees)
summarise_usage()


Expand Down Expand Up @@ -172,8 +188,7 @@ def run_match_ancestors(args):
progress_monitor=args.progress,
path_compression=not args.no_path_compression,
)
logger.info(f"Writing ancestors tree sequence to {ancestors_trees}")
ts.dump(ancestors_trees)
write_ts(ts, ancestors_trees)
summarise_usage()


Expand Down Expand Up @@ -221,8 +236,7 @@ def run_match_samples(args):
post_process=not args.no_post_process,
progress_monitor=args.progress,
)
logger.info(f"Writing output tree sequence to {output_trees}")
ts.dump(output_trees)
write_ts(ts, output_trees)
summarise_usage()


Expand Down

0 comments on commit 008f811

Please sign in to comment.