In [1]:
import math
import os
import random
import sys
import time

import gzip

import numpy as np

import msprime
import tskit
import tsinfer
from tsinfer import make_ancestors_ts

import cyvcf2
import demes
import demesdraw

print(f"tskit {tskit.__version__}")
print(f"tsinfer {tsinfer.__version__}")
print(f"msprime {msprime.__version__}")
print(f"demes {demes.__version__}")
print(f"cyvcf2 {cyvcf2.__version__}")

tskit 0.5.0
tsinfer 0.2.3
msprime 1.1.1
demes 0.2.1
cyvcf2 0.30.14


### Simulation parameters

In [2]:
replicate_index = 0

sampling_time_query = 100

proportion_missing_sites = 0.80

### Simulate genealogy and genetic variation

In [3]:
size_query = 100
size_ref   = 1_000

eff_pop_size = 10_000
mutation_rate = 1e-8
recombination_rate = 1e-8

contig_id = '1'
ploidy_level = 1 # Haploid genomes
sequence_length = 1_000_000 # 1 Mbp

In [4]:
# Uniform recombination rate
recomb_rate_map = msprime.RateMap.uniform(
    sequence_length=sequence_length,
    rate=recombination_rate,
)

# Uniform mutation rate
mut_rate_map = msprime.RateMap.uniform(
    sequence_length=sequence_length,
    rate=mutation_rate,
)

In [5]:
sample_set = [
    # Query genomes
    msprime.SampleSet(num_samples=size_query,
                      time=sampling_time_query,
                      ploidy=ploidy_level),
    # Reference panel
    msprime.SampleSet(num_samples=size_ref,
                      time=0,
                      ploidy=ploidy_level),
]

In [6]:
# A simulated tree sequence does not contain any monoallelic sites,
# but there may be multiallelic sites.
ts_full = msprime.sim_mutations(
    msprime.sim_ancestry(
        samples=sample_set,
        population_size=eff_pop_size,
        model="hudson",
        recombination_rate=recomb_rate_map,
        discrete_genome=True,
    ),
    rate=mut_rate_map,
    discrete_genome=True,
)

In [7]:
individuals_query = np.arange(size_query, dtype=int)
individual_names_query = ["query_" + str(i) for i in individuals_query]

samples_query = np.arange(ploidy_level * size_query, dtype=int)

In [8]:
individuals_ref = np.arange(size_query, size_query + size_ref, dtype=int)
individual_names_ref = ["ref_" + str(i) for i in individuals_ref]

samples_ref = np.arange(ploidy_level * size_query, ploidy_level * (size_query + size_ref), dtype=int)

### Match query genomes to an ancestor tree sequence

In [15]:
def create_sample_data_from_tree_sequence(ts, skip_multiallelic_sites=False):
    """
    Match samples in a SampleData object to an ancestors ts,
    without the need to read from a VCF file.
    
    Optionally, skip multi-allelic sites.
    """
    with tsinfer.SampleData(ts.sequence_length) as sample_data:
        for variant in ts.variants():
            if skip_multiallelic_sites and len(variant.alleles) > 2:
                continue
            sample_data.add_site(
                position=variant.site.position,
                genotypes=variant.genotypes,
                alleles=variant.alleles,
            )
    return(sample_data)

In [16]:
# This step excludes sites that are multiallelic across the whole data set,
# i.e. both the reference panel and query set combined.
# It may exclude some sites that are biallelic in either the reference panel or query set.
sd_full = create_sample_data_from_tree_sequence(ts_full, skip_multiallelic_sites=True)

In [17]:
# Multiallelic sites are excluded.
sd_query = sd_full.subset(individuals=individuals_query) # Used for matching

In [22]:
# Multiallelic sites are included.
ts_ref = ts_full.simplify(samples_ref, filter_sites=False) # Remove private branches

In [45]:
# Multiallelic sites are automatically removed when generating an ancestor tree sequence.
ts_anc = make_ancestors_ts(ts=ts_ref, remove_leaves=True, samples=None)

# Identify sites unique to ts_anc but not in sd_query
sites_ts_anc = set(ts_anc.tables.sites.position)
sites_sd_query = set(list(sd_query.sites_position))
site_positions_to_remove = list(sites_ts_anc.difference(sites_sd_query))
site_ids_to_remove = [ts_anc.site(position=p).id for p in site_positions_to_remove]

tmp_tables = ts_anc.dump_tables()
tmp_tables.delete_sites(site_ids_to_remove) # Remove sites unique to ts_anc
tmp_tables.individuals.clear()
tmp_tables.populations.metadata_schema = tskit.MetadataSchema(schema=None)

ts_anc = tmp_tables.tree_sequence() # Used for matching

In [46]:
ts_matched = tsinfer.match_samples(sample_data=sd_query, ancestors_ts=ts_anc)

In [47]:
print(f"SD whole   : {sd_full.num_sites}")
print(f"SD query   : {sd_query.num_sites}")
print(f"TS ref     : {ts_ref.num_sites}")
print(f"TS anc     : {ts_anc.num_sites}")
print(f"TS matched : {ts_matched.num_sites}")

SD whole   : 3054
SD query   : 3054
TS ref     : 3059
TS anc     : 2549
TS matched : 3054


### Impute into query genomes from the ancestor tree sequence

In [85]:
def sample_sites(sample_data, proportion_missing_sites):
    """
    Draw N sites stored in a SampleData object at random
    by site id, where N is the number of sites to mask
    based on a specified proportion of missing sites.
    
    Return a boolean array of size sample_data.num_sites,
    where True indicates that the site is masked.
    
    TODO: Specify random seed.
    """
    assert proportion_missing_sites >= 0 and proportion_missing_sites <= 1, \
        "Proportion of missing sites must be between 0 and 1."
    
    num_masked_sites = int(np.floor(sample_data.num_sites * proportion_missing_sites))
    
    masked_site_ids = np.sort(
        np.random.default_rng().integers(
            low=0,
            high=sample_data.num_sites,
            size=num_masked_sites,
        )
    )
    
    masked_site_bool = np.zeros(sample_data.num_sites, dtype=bool)
    masked_site_bool[masked_site_ids] = True
    
    return(masked_site_bool)

In [86]:
def mask_sites_in_sample_data(sample_data, sites_to_mask=None):
    """
    Create a SampleData object from an existing SampleData object,
    while masking specified sites listed in sites_to_mask (site ids).
    """
    if sites_to_mask is None:
        # By default, no site is masked
        sites_to_mask = np.zeros(sample_data.num_sites, dtype=bool)
    
    num_variants = 0 # Check the number of sites and number of variants are identical
    with tsinfer.SampleData(sample_data.sequence_length) as sample_data_new:
        for i, variant in enumerate(sample_data.variants()):
            num_variants += 1
            if sites_to_mask[i]:
                genotypes = np.repeat(tskit.MISSING_DATA, len(variant.genotypes))
            else:
                genotypes = variant.genotypes
            sample_data_new.add_site(
                position=variant.site.position,
                genotypes=genotypes,
                alleles=variant.alleles, # Keep the same
            )
    
    assert sample_data.num_sites == num_variants,\
        "Number of sites is not equal to the number of variants."
    
    return(sample_data_new)

In [87]:
masked_site_bool = sample_sites(sd_query, proportion_missing_sites)
sd_query_masked = mask_sites_in_sample_data(sd_query, sites_to_mask=masked_site_bool)

In [52]:
ts_imputed = tsinfer.match_samples(
    sample_data=sd_query_masked,
    ancestors_ts=ts_anc,
)

In [53]:
print(f"SD whole   : {sd_full.num_sites}")
print(f"SD query   : {sd_query.num_sites}")
print(f"TS ref     : {ts_ref.num_sites}")
print(f"TS anc     : {ts_anc.num_sites}")
print(f"TS imputed : {ts_imputed.num_sites}")

SD whole   : 3054
SD query   : 3054
TS ref     : 3059
TS anc     : 2549
TS imputed : 3054


### Write results to VCF

In [54]:
true_vcf_file = ".".join(["true", str(replicate_index), "vcf", "gz"])
imputed_vcf_file = ".".join(["imputed", str(replicate_index), "vcf", "gz"])

In [55]:
with gzip.open(true_vcf_file, "wt") as f:
    ts_matched.write_vcf(f)

In [56]:
with gzip.open(imputed_vcf_file, "wt") as f:
    ts_imputed.write_vcf(f)

### Prepare VCF for BEAGLE and GLIMPSE

In [57]:
ts_ref = ts_full.simplify(samples_ref, filter_sites=False)
ts_query = ts_full.simplify(samples_query, filter_sites=False)

In [58]:
ref_vcf_file = ".".join(["ref", str(replicate_index), "vcf", "gz"])
query_true_vcf_file = ".".join(["query_true", str(replicate_index), "vcf", "gz"])
query_miss_vcf_file = ".".join(["query_miss", str(replicate_index), "vcf", "gz"])

In [59]:
with gzip.open(ref_vcf_file, "wt") as f:
    ts_ref.write_vcf(f)

In [60]:
with gzip.open(query_true_vcf_file, "wt") as f:
    ts_query.write_vcf(f)

In [61]:
with gzip.open(query_miss_vcf_file, "wt") as f:
    ts_query.write_vcf(
        f,
        sample_mask=masked_site_bool, # Mark sites as *
    )

ValueError: Sample mask must be a numpy array of size num_samples