diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst
index 931dd584ab..e2b99dcd1a 100644
--- a/python/CHANGELOG.rst
+++ b/python/CHANGELOG.rst
@@ -34,6 +34,11 @@
   methods to return arrays of per-individual times and populations, respectively.
   (:user:`petrelharp`, :issue:`1481`, :pr:`2298`).
 
+- Add the ``sample_mask`` and ``site_mask`` arguments to ``write_vcf`` to
+  allow parts of an output VCF to be omitted or marked as missing data.
+  Also add the ``as_vcf`` convenience method, to return VCF as a string.
+  (:user:`jeromekelleher`, :pr:`2300`).
+
 **Breaking Changes**
 
 - The JSON metadata codec now interprets the empty string as an empty object. This means
diff --git a/python/tests/test_vcf.py b/python/tests/test_vcf.py
index 3d2dc2bc2f..655be6acd2 100644
--- a/python/tests/test_vcf.py
+++ b/python/tests/test_vcf.py
@@ -29,12 +29,14 @@
 import math
 import os
 import tempfile
+import textwrap
 
 import msprime
 import numpy as np
 import pytest
 import vcf
 
+import tests
 import tests.test_wright_fisher as wf
 import tskit
 from tests import tsutil
@@ -638,3 +640,192 @@ def test_defaults(self):
         assert ts.num_sites > 0
         with ts_to_pyvcf(ts) as vcf_reader:
             assert vcf_reader.samples == ["tsk_0", "tsk_1"]
+
+
+def drop_header(s):
+    return "\n".join(line for line in s.splitlines() if not line.startswith("##"))
+
+
+class TestMasking:
+    @tests.cached_example
+    def ts(self):
+        ts = tskit.Tree.generate_balanced(3, span=10).tree_sequence
+        ts = tsutil.insert_branch_sites(ts)
+        return ts
+
+    @pytest.mark.parametrize("mask", [[True], np.zeros(5, dtype=bool), []])
+    def test_site_mask_wrong_size(self, mask):
+        with pytest.raises(ValueError, match="Site mask must be"):
+            self.ts().as_vcf(site_mask=mask)
+
+    @pytest.mark.parametrize("mask", [[[0, 1], [1, 0]], "abcd"])
+    def test_site_mask_bad_type(self, mask):
+        # converting to a bool array is pretty lax in what it allows.
+        with pytest.raises(ValueError, match="Site mask must be"):
+            self.ts().as_vcf(site_mask=mask)
+
+    @pytest.mark.parametrize("mask", [[[0, 1], [1, 0]], "abcd"])
+    def test_sample_mask_bad_type(self, mask):
+        # converting to a bool array is pretty lax in what it allows.
+        with pytest.raises(ValueError, match="Sample mask must be"):
+            self.ts().as_vcf(sample_mask=mask)
+
+    def test_no_masks(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
+        expected = textwrap.dedent(s)
+        assert drop_header(self.ts().as_vcf()) == expected
+
+    def test_no_masks_triploid(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1|0|0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0|1|1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0|1|0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|0|1"""
+        expected = textwrap.dedent(s)
+        assert drop_header(self.ts().as_vcf(ploidy=3)) == expected
+
+    def test_site_0_masked(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(site_mask=[True, False, False, False])
+        assert drop_header(actual) == expected
+
+    def test_site_0_masked_triploid(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0|1|1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0|1|0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|0|1"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(ploidy=3, site_mask=[True, False, False, False])
+        assert drop_header(actual) == expected
+
+    def test_site_1_masked(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(site_mask=[False, True, False, False])
+        assert drop_header(actual) == expected
+
+    def test_all_sites_masked(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(site_mask=[True, True, True, True])
+        assert drop_header(actual) == expected
+
+    def test_all_sites_not_masked(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(site_mask=[False, False, False, False])
+        assert drop_header(actual) == expected
+
+    @pytest.mark.parametrize(
+        "mask",
+        [[False, False, False], [0, 0, 0], lambda _: [False, False, False]],
+    )
+    def test_all_samples_not_masked(self, mask):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(sample_mask=mask)
+        assert drop_header(actual) == expected
+
+    @pytest.mark.parametrize(
+        "mask", [[True, False, False], [1, 0, 0], lambda _: [True, False, False]]
+    )
+    def test_sample_0_masked(self, mask):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t0\t0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t.\t1\t1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t.\t1\t0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t0\t1"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(sample_mask=mask)
+        assert drop_header(actual) == expected
+
+    @pytest.mark.parametrize(
+        "mask", [[False, True, False], [0, 1, 0], lambda _: [False, True, False]]
+    )
+    def test_sample_1_masked(self, mask):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t.\t0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t.\t0
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1"""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(sample_mask=mask)
+        assert drop_header(actual) == expected
+
+    @pytest.mark.parametrize(
+        "mask", [[True, True, True], [1, 1, 1], lambda _: [True, True, True]]
+    )
+    def test_all_samples_masked(self, mask):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t.\t."""
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(sample_mask=mask)
+        assert drop_header(actual) == expected
+
+    def test_all_functional_sample_mask(self):
+        s = """\
+        #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
+        1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t0\t0
+        1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1
+        1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t.
+        1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t0\t1"""
+
+        def mask(variant):
+            a = [0, 0, 0]
+            a[variant.site.id % 3] = 1
+            return a
+
+        expected = textwrap.dedent(s)
+        actual = self.ts().as_vcf(sample_mask=mask)
+        assert drop_header(actual) == expected
+
+    @pytest.mark.skipif(not _pysam_imported, reason="pysam not available")
+    def test_mask_ok_with_pysam(self):
+        with ts_to_pysam(self.ts(), sample_mask=[0, 0, 1]) as records:
+            variants = list(records)
+            assert len(variants) == 4
+            samples = ["tsk_0", "tsk_1", "tsk_2"]
+            gts = [variants[0].samples[key]["GT"] for key in samples]
+            assert gts == [(1,), (0,), (None,)]
+
+            gts = [variants[1].samples[key]["GT"] for key in samples]
+            assert gts == [(0,), (1,), (None,)]
+
+            gts = [variants[2].samples[key]["GT"] for key in samples]
+            assert gts == [(0,), (1,), (None,)]
+
+            gts = [variants[3].samples[key]["GT"] for key in samples]
+            assert gts == [(0,), (0,), (None,)]
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index 1e98884054..6da05b3a92 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -5325,6 +5325,18 @@ def samples(self, population=None, *, population_id=None, time=None):
                 )
         return samples[keep]
 
+    def as_vcf(self, **kwargs):
+        """
+        Return the result of :meth:`.write_vcf` as a string.
+        Keyword parameters are as defined in :meth:`.write_vcf`.
+
+        :return: A VCF encoding of the variants in this tree sequence as a string.
+        :rtype: str
+        """
+        buff = io.StringIO()
+        self.write_vcf(buff, **kwargs)
+        return buff.getvalue()
+
     def write_vcf(
         self,
         output,
@@ -5333,6 +5345,8 @@ def write_vcf(
         individuals=None,
         individual_names=None,
         position_transform=None,
+        site_mask=None,
+        sample_mask=None,
     ):
         """
         Writes a VCF formatted file to the specified file-like object.
@@ -5463,6 +5477,23 @@ def write_vcf(
 
             $ tskit vcf example.trees | bcftools view -O b > example.bcf
 
+
+        The ``sample_mask`` argument provides a general way to mask out
+        parts of the output, which can be helpful when simulating missing
+        data. In this (contrived) example, we create a sample mask function
+        that marks one genotype missing in each variant in a regular
+        pattern:
+
+        .. code-block:: python
+
+            def sample_mask(variant):
+                mask = np.zeros(ts.num_samples, dtype=bool)
+                mask[variant.site.id % ts.num_samples] = True
+                return mask
+
+
+            ts.write_vcf(sys.stdout, sample_mask=sample_mask)
+
         :param io.IOBase output: The file-like object to write the VCF output.
         :param int ploidy: The ploidy of the individuals to be written to
             VCF. This sample size must be evenly divisible by ploidy. Cannot be
@@ -5488,6 +5519,20 @@ def write_vcf(
             pre 0.2.0 legacy behaviour of rounding values to the nearest integer
             (starting from 1) and avoiding the output of identical positions
             by incrementing is used.
+        :param site_mask: A numpy boolean array (or something convertible to
+            a numpy boolean array) with num_sites elements, used to mask out
+            sites in the output. If ``site_mask[j]`` is True, then this
+            site (i.e., the line in the VCF file) will be omitted.
+        :param sample_mask: A numpy boolean array (or something convertible to
+            a numpy boolean array) with num_samples elements, or a callable
+            that returns such an array, such that if
+            ``sample_mask[j]`` is True, then the genotype for sample ``j``
+            will be marked as missing using a ".". If ``sample_mask`` is a
+            callable, it must take a single argument and return a boolean
+            numpy array. This function will be called for each (unmasked) site
+            with the corresponding :class:`.Variant` object, allowing
+            for dynamic masks to be generated. See above for example
+            usage.
         """
         writer = vcf.VcfWriter(
             self,
@@ -5496,6 +5541,8 @@ def write_vcf(
             individuals=individuals,
             individual_names=individual_names,
             position_transform=position_transform,
+            site_mask=site_mask,
+            sample_mask=sample_mask,
         )
         writer.write(output)
 
diff --git a/python/tskit/vcf.py b/python/tskit/vcf.py
index 290fac2458..969dfdcb4b 100644
--- a/python/tskit/vcf.py
+++ b/python/tskit/vcf.py
@@ -58,6 +58,8 @@ def __init__(
         individuals=None,
         individual_names=None,
         position_transform=None,
+        site_mask=None,
+        sample_mask=None,
     ):
         self.tree_sequence = tree_sequence
         self.contig_id = contig_id
@@ -98,6 +100,18 @@ def __init__(
             # from the legacy VCF output code.
             self.contig_length = max(self.transformed_positions[-1], self.contig_length)
 
+        if site_mask is None:
+            site_mask = np.zeros(tree_sequence.num_sites, dtype=bool)
+        self.site_mask = np.array(site_mask, dtype=bool)
+        if self.site_mask.shape != (tree_sequence.num_sites,):
+            raise ValueError("Site mask must be a 1D boolean array of length num_sites")
+
+        self.sample_mask = sample_mask
+        if sample_mask is not None:
+            if not callable(sample_mask):
+                sample_mask = np.array(sample_mask, dtype=bool)
+                self.sample_mask = lambda _: sample_mask
+
     def __make_sample_mapping(self, ploidy):
         """
         Compute the sample IDs for each VCF individual and the template for
@@ -176,6 +190,12 @@ def write(self, output):
         indexes = np.array(indexes, dtype=int)
 
         for variant in self.tree_sequence.variants(samples=self.samples):
+            site_id = variant.site.id
+            # We check the mask before we do any checks so we can use this as a
+            # way of skipping problematic sites.
+            if self.site_mask[site_id]:
+                continue
+
             if variant.num_alleles > 9:
                 raise ValueError(
                     "More than 9 alleles not currently supported. Please open an issue "
@@ -187,7 +207,6 @@ def write(self, output):
                     "on GitHub if this limitation affects you."
                 )
             pos = self.transformed_positions[variant.index]
-            site_id = variant.site.id
             ref = variant.alleles[0]
             alt = ",".join(variant.alleles[1:]) if len(variant.alleles) > 1 else "."
             print(
@@ -204,7 +223,23 @@ def write(self, output):
                 end="\t",
                 file=output,
             )
-            gt_array[indexes] = variant.genotypes + ord("0")
+            # NOTE: when we support missing data we should be able to
+            # simply add ``and not variant.has_missing_data`` here.
+            # Probably OK to take the perf hit in making the missing
+            # data case go in with the more general sample masking case.
+            if self.sample_mask is None:
+                gt_array[indexes] = variant.genotypes + ord("0")
+            else:
+                genotypes = variant.genotypes.copy()
+                sample_mask = np.array(self.sample_mask(variant), dtype=bool)
+                if sample_mask.shape != genotypes.shape:
+                    raise ValueError(
+                        "Sample mask must be a numpy array of size num_samples"
+                    )
+                gt_array[indexes] = genotypes + ord("0")
+                genotypes[sample_mask] = -1
+                missing = genotypes == -1
+                gt_array[indexes[missing]] = ord(".")
             g_bytes = memoryview(gt_array).tobytes()
             g_str = g_bytes.decode()
             print(g_str, end="", file=output)