Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
methods to return arrays of per-individual times and populations, respectively.
(:user:`petrelharp`, :issue:`1481`, :pr:`2298`).

- Add the ``sample_mask`` and ``site_mask`` arguments to ``write_vcf`` to allow
parts of an output VCF to be omitted or marked as missing data. Also add the
``as_vcf`` convenience method, which returns the VCF output as a string.
(:user:`jeromekelleher`, :pr:`2300`).

**Breaking Changes**

- The JSON metadata codec now interprets the empty string as an empty object. This means
Expand Down
191 changes: 191 additions & 0 deletions python/tests/test_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
import math
import os
import tempfile
import textwrap

import msprime
import numpy as np
import pytest
import vcf

import tests
import tests.test_wright_fisher as wf
import tskit
from tests import tsutil
Expand Down Expand Up @@ -638,3 +640,192 @@ def test_defaults(self):
assert ts.num_sites > 0
with ts_to_pyvcf(ts) as vcf_reader:
assert vcf_reader.samples == ["tsk_0", "tsk_1"]


def drop_header(s):
    """
    Return ``s`` with every line that starts with ``##`` removed.

    Lines starting with a single ``#`` (such as the ``#CHROM`` column header)
    are kept. The remaining lines are joined with newlines, so any trailing
    newline in ``s`` is not preserved.
    """
    kept_lines = []
    for text_line in s.splitlines():
        if text_line.startswith("##"):
            continue
        kept_lines.append(text_line)
    return "\n".join(kept_lines)


class TestMasking:
    """
    Tests for the ``site_mask`` and ``sample_mask`` arguments of ``as_vcf``.

    Every test uses the same small example: a balanced tree on three samples
    with sites added by ``tsutil.insert_branch_sites``. The expected VCF body
    (with the ``##`` header lines dropped) is written out in full in each test.
    """

    @tests.cached_example
    def ts(self):
        balanced_tree = tskit.Tree.generate_balanced(3, span=10)
        return tsutil.insert_branch_sites(balanced_tree.tree_sequence)

    @pytest.mark.parametrize("mask", [[True], np.zeros(5, dtype=bool), []])
    def test_site_mask_wrong_size(self, mask):
        ts = self.ts()
        with pytest.raises(ValueError, match="Site mask must be"):
            ts.as_vcf(site_mask=mask)

    @pytest.mark.parametrize("mask", [[[0, 1], [1, 0]], "abcd"])
    def test_site_mask_bad_type(self, mask):
        # Converting to a bool array is pretty lax in what's allowed.
        ts = self.ts()
        with pytest.raises(ValueError, match="Site mask must be"):
            ts.as_vcf(site_mask=mask)

    @pytest.mark.parametrize("mask", [[[0, 1], [1, 0]], "abcd"])
    def test_sample_mask_bad_type(self, mask):
        # Converting to a bool array is pretty lax in what's allowed.
        ts = self.ts()
        with pytest.raises(ValueError, match="Sample mask must be"):
            ts.as_vcf(sample_mask=mask)

    def test_no_masks(self):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
        )
        actual = self.ts().as_vcf()
        assert drop_header(actual) == expected

    def test_no_masks_triploid(self):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1|0|0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0|1|1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0|1|0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|0|1"""
        )
        actual = self.ts().as_vcf(ploidy=3)
        assert drop_header(actual) == expected

    def test_site_0_masked(self):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
        )
        actual = self.ts().as_vcf(site_mask=[True, False, False, False])
        assert drop_header(actual) == expected

    def test_site_0_masked_triploid(self):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0|1|1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0|1|0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|0|1"""
        )
        actual = self.ts().as_vcf(ploidy=3, site_mask=[True, False, False, False])
        assert drop_header(actual) == expected

    def test_site_1_masked(self):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
        )
        actual = self.ts().as_vcf(site_mask=[False, True, False, False])
        assert drop_header(actual) == expected

    def test_all_sites_masked(self):
        # With every site masked only the column header line remains.
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2"""
        )
        actual = self.ts().as_vcf(site_mask=[True, True, True, True])
        assert drop_header(actual) == expected

    def test_all_sites_not_masked(self):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
        )
        actual = self.ts().as_vcf(site_mask=[False, False, False, False])
        assert drop_header(actual) == expected

    @pytest.mark.parametrize(
        "mask",
        [[False, False, False], [0, 0, 0], lambda _: [False, False, False]],
    )
    def test_all_samples_not_masked(self, mask):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
        )
        actual = self.ts().as_vcf(sample_mask=mask)
        assert drop_header(actual) == expected

    @pytest.mark.parametrize(
        "mask", [[True, False, False], [1, 0, 0], lambda _: [True, False, False]]
    )
    def test_sample_0_masked(self, mask):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t0\t0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t.\t1\t1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t.\t1\t0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t0\t1"""
        )
        actual = self.ts().as_vcf(sample_mask=mask)
        assert drop_header(actual) == expected

    @pytest.mark.parametrize(
        "mask", [[False, True, False], [0, 1, 0], lambda _: [False, True, False]]
    )
    def test_sample_1_masked(self, mask):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t.\t0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t.\t0
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1"""
        )
        actual = self.ts().as_vcf(sample_mask=mask)
        assert drop_header(actual) == expected

    @pytest.mark.parametrize(
        "mask", [[True, True, True], [1, 1, 1], lambda _: [True, True, True]]
    )
    def test_all_samples_masked(self, mask):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t.\t."""
        )
        actual = self.ts().as_vcf(sample_mask=mask)
        assert drop_header(actual) == expected

    def test_all_functional_sample_mask(self):
        expected = textwrap.dedent(
            """\
            #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
            1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t0\t0
            1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1
            1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t.
            1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t0\t1"""
        )

        def mask(variant):
            # Mask exactly one sample per site, cycling through the three
            # samples by site ID.
            masked_index = variant.site.id % 3
            return [int(j == masked_index) for j in range(3)]

        actual = self.ts().as_vcf(sample_mask=mask)
        assert drop_header(actual) == expected

    @pytest.mark.skipif(not _pysam_imported, reason="pysam not available")
    def test_mask_ok_with_pysam(self):
        with ts_to_pysam(self.ts(), sample_mask=[0, 0, 1]) as records:
            variants = list(records)
            assert len(variants) == 4
            sample_names = ["tsk_0", "tsk_1", "tsk_2"]
            # The last sample is masked, so its genotype is None at every site.
            expected_gts = [
                [(1,), (0,), (None,)],
                [(0,), (1,), (None,)],
                [(0,), (1,), (None,)],
                [(0,), (0,), (None,)],
            ]
            for variant, expected in zip(variants, expected_gts):
                observed = [variant.samples[name]["GT"] for name in sample_names]
                assert observed == expected
47 changes: 47 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -5325,6 +5325,18 @@ def samples(self, population=None, *, population_id=None, time=None):
)
return samples[keep]

def as_vcf(self, **kwargs):
    """
    Return the VCF output of :meth:`.write_vcf` as a string, instead of
    writing it to a file-like object. All keyword arguments are forwarded
    unchanged to :meth:`.write_vcf`; see that method for their meaning.

    :return: A VCF encoding of the variants in this tree sequence as a string.
    :rtype: str
    """
    # Collect the output in an in-memory text buffer and hand back its contents.
    with io.StringIO() as vcf_text:
        self.write_vcf(vcf_text, **kwargs)
        return vcf_text.getvalue()

def write_vcf(
self,
output,
Expand All @@ -5333,6 +5345,8 @@ def write_vcf(
individuals=None,
individual_names=None,
position_transform=None,
site_mask=None,
sample_mask=None,
):
"""
Writes a VCF formatted file to the specified file-like object.
Expand Down Expand Up @@ -5463,6 +5477,23 @@ def write_vcf(

$ tskit vcf example.trees | bcftools view -O b > example.bcf


The ``sample_mask`` argument provides a general way to mask out
parts of the output, which can be helpful when simulating missing
data. In this (contrived) example, we create a sample mask function
that marks one genotype missing in each variant in a regular
pattern:

.. code-block:: python

def sample_mask(variant):
sample_mask = np.zeros(ts.num_samples, dtype=bool)
sample_mask[variant.site.id % ts.num_samples] = 1
return sample_mask


ts.write_vcf(sys.stdout, sample_mask=sample_mask)

:param io.IOBase output: The file-like object to write the VCF output.
:param int ploidy: The ploidy of the individuals to be written to
VCF. This sample size must be evenly divisible by ploidy. Cannot be
Expand All @@ -5488,6 +5519,20 @@ def write_vcf(
pre 0.2.0 legacy behaviour of rounding values to the nearest integer
(starting from 1) and avoiding the output of identical positions
by incrementing is used.
:param site_mask: A numpy boolean array (or something convertible to
a numpy boolean array) with num_sites elements, used to mask out
sites in the output. If ``site_mask[j]`` is True, then this
site (i.e., the line in the VCF file) will be omitted.
:param sample_mask: A numpy boolean array (or something convertible to
a numpy boolean array) with num_samples elements, or a callable
that returns such an array, such that if
``sample_mask[j]`` is True, then the genotype for sample ``j``
will be marked as missing using a ".". If ``sample_mask`` is a
callable, it must take a single argument and return a boolean
numpy array. This function will be called for each (unmasked) site
with the corresponding :class:`.Variant` object, allowing
for dynamic masks to be generated. See above for example
usage.
"""
writer = vcf.VcfWriter(
self,
Expand All @@ -5496,6 +5541,8 @@ def write_vcf(
individuals=individuals,
individual_names=individual_names,
position_transform=position_transform,
site_mask=site_mask,
sample_mask=sample_mask,
)
writer.write(output)

Expand Down
39 changes: 37 additions & 2 deletions python/tskit/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(
individuals=None,
individual_names=None,
position_transform=None,
site_mask=None,
sample_mask=None,
):
self.tree_sequence = tree_sequence
self.contig_id = contig_id
Expand Down Expand Up @@ -98,6 +100,18 @@ def __init__(
# from the legacy VCF output code.
self.contig_length = max(self.transformed_positions[-1], self.contig_length)

if site_mask is None:
site_mask = np.zeros(tree_sequence.num_sites, dtype=bool)
self.site_mask = np.array(site_mask, dtype=bool)
if self.site_mask.shape != (tree_sequence.num_sites,):
raise ValueError("Site mask must be 1D a boolean array of length num_sites")

self.sample_mask = sample_mask
if sample_mask is not None:
if not callable(sample_mask):
sample_mask = np.array(sample_mask, dtype=bool)
self.sample_mask = lambda _: sample_mask

def __make_sample_mapping(self, ploidy):
"""
Compute the sample IDs for each VCF individual and the template for
Expand Down Expand Up @@ -176,6 +190,12 @@ def write(self, output):
indexes = np.array(indexes, dtype=int)

for variant in self.tree_sequence.variants(samples=self.samples):
site_id = variant.site.id
# We check the mask before we do any checks so we can use this as a
# way of skipping problematic sites.
if self.site_mask[site_id]:
continue

if variant.num_alleles > 9:
raise ValueError(
"More than 9 alleles not currently supported. Please open an issue "
Expand All @@ -187,7 +207,6 @@ def write(self, output):
"on GitHub if this limitation affects you."
)
pos = self.transformed_positions[variant.index]
site_id = variant.site.id
ref = variant.alleles[0]
alt = ",".join(variant.alleles[1:]) if len(variant.alleles) > 1 else "."
print(
Expand All @@ -204,7 +223,23 @@ def write(self, output):
end="\t",
file=output,
)
gt_array[indexes] = variant.genotypes + ord("0")
# NOTE: when we support missing data we should be able to
# simply add ``and not variant.has_missing_data`` here.
# Probably OK to take the perf hit in making the missing
# data case go in with the more general sample masking case.
if self.sample_mask is None:
gt_array[indexes] = variant.genotypes + ord("0")
else:
genotypes = variant.genotypes.copy()
sample_mask = np.array(self.sample_mask(variant), dtype=bool)
if sample_mask.shape != genotypes.shape:
raise ValueError(
"Sample mask must be a numpy array of size num_samples"
)
gt_array[indexes] = genotypes + ord("0")
genotypes[sample_mask] = -1
missing = genotypes == -1
gt_array[indexes[missing]] = ord(".")
g_bytes = memoryview(gt_array).tobytes()
g_str = g_bytes.decode()
print(g_str, end="", file=output)