diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index 8ec06c6589..f5e57f253d 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -234,6 +234,32 @@ def test_provenances_long_args(self): self.assertEqual(args.tree_sequence, tree_sequence) self.assertEqual(args.human, True) + def test_fasta_default_values(self): + parser = cli.get_tskit_parser() + cmd = "fasta" + tree_sequence = "test.trees" + args = parser.parse_args([cmd, tree_sequence]) + self.assertEqual(args.tree_sequence, tree_sequence) + self.assertEqual(args.wrap, 60) + + def test_fasta_short_args(self): + parser = cli.get_tskit_parser() + cmd = "fasta" + tree_sequence = "test.trees" + args = parser.parse_args([ + cmd, tree_sequence, "-w", "100"]) + self.assertEqual(args.tree_sequence, tree_sequence) + self.assertEqual(args.wrap, 100) + + def test_fasta_long_args(self): + parser = cli.get_tskit_parser() + cmd = "fasta" + tree_sequence = "test.trees" + args = parser.parse_args([ + cmd, tree_sequence, "--wrap", "50"]) + self.assertEqual(args.tree_sequence, tree_sequence) + self.assertEqual(args.wrap, 50) + def test_vcf_default_values(self): parser = cli.get_tskit_parser() cmd = "vcf" @@ -438,6 +464,19 @@ def test_provenances_human(self): # TODO Check the actual output here. self.assertGreater(len(output_provenances), 0) + def verify_fasta(self, output_fasta): + with tempfile.TemporaryFile("w+") as f: + self._tree_sequence.write_fasta(f) + f.seek(0) + fasta = f.read() + self.assertEqual(output_fasta, fasta) + + def test_fasta(self): + cmd = "fasta" + stdout, stderr = capture_output(cli.tskit_main, [cmd, self._tree_sequence_file]) + self.assertEqual(len(stderr), 0) + self.verify_fasta(stdout) + def verify_vcf(self, output_vcf): with tempfile.TemporaryFile("w+") as f: self._tree_sequence.write_vcf(f) @@ -507,6 +546,9 @@ def verify(self, command): def test_info(self): self.verify("info") + def test_fasta(self): + self.verify("fasta") + def test_vcf(self): self.verify("vcf") diff --git a/python/tests/test_fasta.py b/python/tests/test_fasta.py new file mode 100644 index 0000000000..05674ae8cf --- /dev/null +++ b/python/tests/test_fasta.py @@ -0,0 +1,203 @@ +# MIT License +# +# Copyright (c) 2018-2019 Tskit Developers +# Copyright (c) 2016 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for fasta output in tskit. +""" +import os +import tempfile +import unittest +import itertools +import io + +import msprime +from Bio import SeqIO + +from tests import tsutil + + +# setting up some basic haplotype data for tests +def create_data(length): + ts = msprime.simulate(sample_size=10, length=length, mutation_rate=1e-2, + random_seed=123) + ts = tsutil.jukes_cantor(ts, length, 1.0, seed=123) + assert ts.num_sites == length + return ts + + +class TestLineLength(unittest.TestCase): + """ + Tests if the fasta file produced has the correct line lengths for + default, custom, and no-wrapping options. + """ + def verify_line_length(self, length, wrap_width=60): + # set up data + length = length + ts = create_data(length) + output = io.StringIO() + ts.write_fasta(output, wrap_width=wrap_width) + output.seek(0) + + # check if length perfectly divisible by wrap_width or not and thus + # expected line lengths + no_hanging_line = True + if wrap_width == 0: + lines_expect = 1 + # for easier code in testing function, redefine wrap_width as + # full length, ok as called write already + wrap_width = length + elif length % wrap_width == 0: + lines_expect = length//wrap_width + else: + lines_expect = length//wrap_width + 1 + extra_line_length = length % wrap_width + no_hanging_line = False + + seq_line_counter = 0 + id_lines = 0 + for line in output: + # testing correct characters per sequence line + if line[0] != ">": + seq_line_counter += 1 + line_chars = len(line.strip('\n')) + # test full default width lines + if seq_line_counter < lines_expect: + self.assertEqual(wrap_width, line_chars) + elif no_hanging_line: + self.assertEqual(wrap_width, line_chars) + # test extra line if not perfectly divided by wrap_width + else: + self.assertEqual(extra_line_length, line_chars) + # testing correct number of lines per sequence and correct num sequences + else: + id_lines += 1 + if seq_line_counter > 0: + self.assertEqual(lines_expect, seq_line_counter) + seq_line_counter = 0 + self.assertEqual(id_lines, ts.num_samples) + + def test_wrap_length_default_easy(self): + # default wrap width (60) perfectly divides sequence length + self.verify_line_length(length=300) + + def test_wrap_length_default_harder(self): + # default wrap_width imperfectly divides sequence length + self.verify_line_length(length=280) + + def test_wrap_length_custom_easy(self): + # custom wrap_width, perfectly divides + self.verify_line_length(length=100, wrap_width=20) + + def test_wrap_length_custom_harder(self): + # custom wrap_width, imperfectly divides + self.verify_line_length(length=100, wrap_width=30) + + def test_wrap_length_no_wrap(self): + # no wrapping set by wrap_width = 0 + self.verify_line_length(length=100, wrap_width=0) + + def test_bad_wrap(self): + ts = create_data(100) + with self.assertRaises(ValueError): + ts.write_fasta(io.StringIO(), wrap_width=-1) + + +class TestSequenceIds(unittest.TestCase): + """ + Tests that sequence IDs are output correctly, whether default or custom + and that the length of IDs supplied must equal number of sequences + """ + def verify_ids(self, ts, seq_ids_in=None): + seq_ids_read = [] + with tempfile.TemporaryDirectory() as temp_dir: + fasta_path = os.path.join(temp_dir, "testing_def_fasta.txt") + with open(fasta_path, "w") as f: + ts.write_fasta(f, sequence_ids=seq_ids_in) + with open(fasta_path, "r") as handle: + for record in SeqIO.parse(handle, "fasta"): + seq_ids_read.append(record.id) + + # test default seq ids + if seq_ids_in in [None]: + for i, val in enumerate(seq_ids_read): + self.assertEqual("tsk_{}".format(i), val) + # test custom seq ids + else: + for i, j in itertools.zip_longest(seq_ids_in, seq_ids_read): + self.assertEqual(i, j) + + def test_default_ids(self): + # test that default sequence ids, immediately following '>', are as expected + ts = create_data(100) + self.verify_ids(ts) + + def test_custom_ids(self): + # test that custom sequence ids, immediately following '>', are as expected + ts = create_data(100) + seq_ids_in = ["x_{}".format(_) for _ in range(ts.num_samples)] + self.verify_ids(ts, seq_ids_in) + + def test_bad_length_ids(self): + ts = create_data(100) + with self.assertRaises(ValueError): + seq_ids_in = ["x_{}".format(_) for _ in range(ts.num_samples-1)] + ts.write_fasta(io.StringIO(), sequence_ids=seq_ids_in) + with self.assertRaises(ValueError): + seq_ids_in = ["x_{}".format(_) for _ in range(ts.num_samples+1)] + ts.write_fasta(io.StringIO(), sequence_ids=seq_ids_in) + with self.assertRaises(ValueError): + seq_ids_in = [] + ts.write_fasta(io.StringIO(), sequence_ids=seq_ids_in) + + +class TestRoundTrip(unittest.TestCase): + """ + Tests that output from our code is read in by available software packages + Here test for compatability with biopython processing - Bio.SeqIO + """ + def verify(self, ts, wrap_width=60): + biopython_fasta_read = [] + with tempfile.TemporaryDirectory() as temp_dir: + fasta_path = os.path.join(temp_dir, "testing_def_fasta.txt") + with open(fasta_path, "w") as f: + ts.write_fasta(f, wrap_width=wrap_width) + with open(fasta_path, "r") as handle: + for record in SeqIO.parse(handle, "fasta"): + biopython_fasta_read.append(record.seq) + + for i, j in itertools.zip_longest(biopython_fasta_read, ts.haplotypes()): + self.assertEqual(i, j) + + def test_equal_lines(self): + # sequence length perfectly divisible by wrap_width + ts = create_data(300) + self.verify(ts) + + def test_unequal_lines(self): + # sequence length not perfectly divisible by wrap_width + ts = create_data(280) + self.verify(ts) + + def test_unwrapped(self): + # sequences not wrapped + ts = create_data(300) + self.verify(ts, wrap_width=0) diff --git a/python/tskit/cli.py b/python/tskit/cli.py index ac53c0c19e..de30ce1a25 100644 --- a/python/tskit/cli.py +++ b/python/tskit/cli.py @@ -129,6 +129,11 @@ def run_provenances(args): tree_sequence.dump_text(provenances=sys.stdout) +def run_fasta(args): + tree_sequence = load_tree_sequence(args.tree_sequence) + tree_sequence.write_fasta(sys.stdout, wrap_width=args.wrap) + + def run_vcf(args): tree_sequence = load_tree_sequence(args.tree_sequence) tree_sequence.write_vcf(sys.stdout, ploidy=args.ploidy) @@ -182,6 +187,15 @@ def get_tskit_parser(): help="Remove any duplicated mutation positions in the source file. ") parser.set_defaults(runner=run_upgrade) + parser = subparsers.add_parser( + "fasta", + help="Convert the tree sequence haplotypes to fasta format") + add_tree_sequence_argument(parser) + parser.add_argument( + "--wrap", "-w", type=int, default=60, + help=("line-wrapping width for printed sequences")) + parser.set_defaults(runner=run_fasta) + parser = subparsers.add_parser( "vcf", help="Convert the tree sequence genotypes to VCF format.") diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 23ac4439d1..436f6391db 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -31,6 +31,7 @@ import warnings import functools import concurrent.futures +import textwrap import numpy as np @@ -2917,6 +2918,68 @@ def samples(self, population=None, population_id=None): samples = samples[sample_population == population] return samples + def write_fasta(self, output, sequence_ids=None, wrap_width=60): + """ + Writes haplotype data for samples in FASTA format to the + specified file-like object. + + Default `sequence_ids` (i.e. the text immediately following ">") are + "tsk_{sample_number}" e.g. "tsk_0", "tsk_1" etc. They can be set by providing + a list of strings to the `sequence_ids` argument, which must equal the length + of the number of samples. Please ensure that these are unique and compatible with + fasta standards, since we do not check this. + Default `wrap_width` for sequences is 60 characters in accordance with fasta + standard outputs, but this can be specified. In order to avoid any line-wrapping + of sequences, set `wrap_width = 0`. + + Example usage: + + .. code-block:: python + with open("output.fasta", "w") as fasta_file: + ts.write_fasta(fasta_file) + + This can also be achieved on the command line use the ``tskit fasta`` command, + e.g.: + + .. code-block:: bash + + $ tskit fasta example.trees > example.fasta + + :param File output: The file-like object to write the fasta output. + :param list(str) sequence_ids: A list of string names to uniquely identify + each of the sequences in the fasta file. If specified, this must be a + list of strings of length equal to the number of samples which are output. + Note that we do not check the form of these strings in any way, so that it + is possible to output bad fasta IDs (for example, by including spaces + before the unique identifying part of the string). + The default is to output ``tsk_j`` for the jth individual. + :param int wrap_width: This parameter specifies the number of sequence + characters to include on each line in the fasta file, before wrapping + to the next line for each sequence. Defaults to 60 characters in + accordance with fasta standard outputs. To avoid any line-wrapping of + sequences, set `wrap_width = 0`. Otherwise, supply any positive integer. + """ + # if not specified, IDs default to sample index + if sequence_ids is None: + sequence_ids = ["tsk_{}".format(j) for j in self.samples()] + if len(sequence_ids) != self.num_samples: + raise ValueError( + "sequence_ids must have length equal to the number of samples.") + + wrap_width = int(wrap_width) + if wrap_width < 0: + raise ValueError("wrap_width must be a non-negative integer. " + "You may specify `wrap_width=0` " + "if you do not want any wrapping.") + + for j, hap in enumerate(self.haplotypes()): + print(">", sequence_ids[j], sep="", file=output) + if wrap_width == 0: + print(hap, file=output) + else: + for hap_wrap in textwrap.wrap(hap, wrap_width): + print(hap_wrap, file=output) + def write_vcf( self, output, ploidy=None, contig_id="1", individuals=None, individual_names=None, position_transform=None):