In [19]:
import math
import os
import random
import sys
import time

import gzip

import numpy as np

import msprime
import tskit
import tsinfer
from tsinfer import make_ancestors_ts

import cyvcf2
import demes
import demesdraw

# Record the versions of the key libraries so results can be tied to them.
print("\n".join(
    f"{module.__name__} {module.__version__}"
    for module in (tskit, tsinfer, msprime, demes, cyvcf2)
))

tskit 0.5.0
tsinfer 0.2.3
msprime 1.1.1
demes 0.2.1
cyvcf2 0.30.14


### Simulation parameters

In [38]:
# Index of this replicate; it is embedded in the output VCF file names.
replicate_index: int = 0

# Sampling time of the query genomes (msprime time units, i.e. generations ago).
sampling_time_query: int = 100

# Fraction of query sites whose genotypes are masked before imputation.
proportion_missing_sites: float = 0.80

### Simulate genealogy and genetic variation

In [20]:
# Number of sampled individuals in the query and reference panels.
size_query, size_ref = 100, 1_000

# Population-genetic parameters; rates are per unit of sequence per generation.
eff_pop_size = 10_000
mutation_rate = 1e-8
recombination_rate = 1e-8

# Genome layout and output settings.
contig_id = '1'
ploidy_level = 1
sequence_length = 1_000_000

In [21]:
# Constant recombination rate along the entire simulated sequence.
rate_map = msprime.RateMap.uniform(sequence_length=sequence_length, rate=recombination_rate)

In [22]:
# Query samples are listed first (taken at sampling_time_query), then the
# present-day reference samples; later cells rely on this ordering of IDs.
sample_set = [
    msprime.SampleSet(num_samples=n_samples, time=when, ploidy=ploidy_level)
    for n_samples, when in (
        (size_query, sampling_time_query),
        (size_ref, 0),
    )
]

In [23]:
# Simulate the genealogy of all query + reference samples, then overlay
# mutations. Seeds are derived from replicate_index so every replicate is
# reproducible under Restart & Run All and distinct from the others.
# msprime requires seeds in [1, 2**32 - 1], hence the +1 / +2 offsets.
ancestry_seed = 2 * replicate_index + 1
mutation_seed = 2 * replicate_index + 2

ts_ancestry = msprime.sim_ancestry(
    samples=sample_set,
    population_size=eff_pop_size,
    model="hudson",
    recombination_rate=rate_map,
    discrete_genome=True,
    random_seed=ancestry_seed,
)
ts_full = msprime.sim_mutations(
    ts_ancestry,
    rate=mutation_rate,
    discrete_genome=True,
    random_seed=mutation_seed,
)

In [24]:
# The query SampleSet is listed first, so its individuals and sample nodes
# occupy the lowest IDs.
individuals_query = np.arange(0, size_query, dtype=int)
individual_names_query = [f"query_{ind}" for ind in individuals_query]
samples_query = np.arange(0, size_query * ploidy_level, dtype=int)

In [25]:
# Reference individuals and sample nodes come directly after the query ones.
individuals_ref = np.arange(size_query, size_query + size_ref, dtype=int)
individual_names_ref = [f"ref_{ind}" for ind in individuals_ref]
samples_ref = np.arange(
    size_query * ploidy_level,
    (size_query + size_ref) * ploidy_level,
    dtype=int,
)

### Match query genomes to an ancestor tree sequence

In [26]:
def create_sample_data_from_tree_sequence(ts, skip_multiallelic_sites=False):
    """
    Build a tsinfer SampleData object directly from a tree sequence,
    without writing and re-reading a VCF file.

    Each site yielded by ``ts.variants()`` is copied with its position,
    genotypes and alleles, unless it is skipped as multi-allelic.

    No individuals are added explicitly, so the SampleData's individuals are
    whatever tsinfer creates by default. NOTE(review): presumably one haploid
    individual per sample -- confirm before using with ploidy > 1, since later
    cells subset by individual index.

    :param ts: Source tree sequence.
    :param skip_multiallelic_sites: If True, omit sites with more than two
        alleles.
    :return: The new SampleData, with sequence length ``ts.sequence_length``.
    """
    with tsinfer.SampleData(ts.sequence_length) as sample_data:
        for variant in ts.variants():
            if skip_multiallelic_sites and len(variant.alleles) > 2:
                continue
            sample_data.add_site(
                position=variant.site.position,
                genotypes=variant.genotypes,
                alleles=variant.alleles,
            )
    return sample_data

In [27]:
# Convert the whole simulation to SampleData, keeping only bi-allelic sites.
sd_full = create_sample_data_from_tree_sequence(
    ts_full,
    skip_multiallelic_sites=True,
)

# Restrict to the query individuals; this is what gets matched below.
sd_query = sd_full.subset(individuals=individuals_query)

In [28]:
# Restrict the simulation to the reference panel (removes the query lineages).
# filter_sites=False keeps every site, even ones now monomorphic in the
# reference, so positions still line up with sd_query.
ts_ref = ts_full.simplify(samples_ref, filter_sites=False)

ts_anc_raw = make_ancestors_ts(ts=ts_ref, remove_leaves=True, samples=None)

# match_samples needs every site of the ancestors ts to also be a site of the
# query SampleData. sd_full skipped multi-allelic sites, so drop any ancestor
# site whose position is absent from sd_query. A vectorised np.isin replaces
# one ts.site(position=...) lookup per position. Positions are integer-valued
# (discrete genome), so exact equality is safe.
positions_anc = ts_anc_raw.tables.sites.position
positions_query = sd_query.sites_position[:]
site_ids_to_remove = np.flatnonzero(~np.isin(positions_anc, positions_query))

tables_anc = ts_anc_raw.dump_tables()
tables_anc.delete_sites(site_ids_to_remove)  # also drops mutations at those sites
# NOTE(review): presumably cleared so the matched output carries only the
# query individuals, and the schema reset avoids a clash with tsinfer's
# population metadata -- confirm.
tables_anc.individuals.clear()
tables_anc.populations.metadata_schema = tskit.MetadataSchema(schema=None)

ts_anc = tables_anc.tree_sequence()  # Used for matching
assert np.isin(ts_anc.tables.sites.position, positions_query).all()

In [29]:
# Match the unmasked query haplotypes against the ancestors; the result is
# written out later as the "true" VCF.
ts_matched = tsinfer.match_samples(
    sample_data=sd_query,
    ancestors_ts=ts_anc,
)

### Impute masked genotypes in the query genomes using the ancestor tree sequence

In [30]:
def mask_sites_in_sample_data(sd, sequence_length=None, sites_to_mask=None):
    """
    Copy a tsinfer SampleData, replacing every genotype at selected sites
    with ``tskit.MISSING_DATA``.

    Positions and alleles are copied unchanged for every site, so masked
    sites stay in the output (all genotypes missing) and can later be imputed.

    :param sd: Source SampleData, iterated with ``sd.variants()``.
    :param sequence_length: Sequence length of the new SampleData. If None,
        ``sd.sequence_length`` is used.
    :param sites_to_mask: Iterable of 0-based site indexes (in the order of
        ``sd.variants()``) to mask. None or empty masks nothing.
    :return: The new SampleData.
    """
    if sequence_length is None:
        sequence_length = sd.sequence_length
    # A set gives O(1) membership per site (a list made this O(n * m)); it also
    # avoids the shared mutable-default-argument pitfall.
    mask = set() if sites_to_mask is None else set(sites_to_mask)

    with tsinfer.SampleData(sequence_length) as sample_data:
        for site_index, variant in enumerate(sd.variants()):
            genotypes = variant.genotypes
            if site_index in mask:
                genotypes = np.full_like(genotypes, tskit.MISSING_DATA)
            sample_data.add_site(
                position=variant.site.position,
                genotypes=genotypes,
                alleles=variant.alleles,  # keep alleles so masked sites survive
            )
    return sample_data

In [31]:
# Choose floor(num_sites * proportion_missing_sites) distinct site indexes to
# mask. A dedicated RNG seeded by replicate_index makes the mask reproducible
# under Restart & Run All; sorting only aids inspection.
rng_mask = random.Random(replicate_index)
masked_sites = sorted(
    rng_mask.sample(
        range(sd_query.num_sites),
        math.floor(sd_query.num_sites * proportion_missing_sites),
    )
)

In [32]:
# Blank out the genotypes at the chosen sites of the query data.
sd_query_masked = mask_sites_in_sample_data(
    sd_query, sequence_length=sequence_length, sites_to_mask=masked_sites
)

In [33]:
# Matching the masked query data fills in the missing genotypes (imputation).
ts_imputed = tsinfer.match_samples(sample_data=sd_query_masked, ancestors_ts=ts_anc)

### Write results to VCF

In [42]:
# Output paths, e.g. "true.0.vcf.gz" / "imputed.0.vcf.gz" for replicate 0.
true_vcf_file = f"true.{replicate_index}.vcf.gz"
imputed_vcf_file = f"imputed.{replicate_index}.vcf.gz"

In [43]:
# Write the query genomes matched without masking; this is the truth set for
# evaluating imputation. contig_id is passed explicitly so the value set in
# the configuration cell actually takes effect.
# NOTE(review): individual_names_query is defined earlier but unused; consider
# individual_names=individual_names_query if ts_matched has one individual per
# query sample.
with gzip.open(true_vcf_file, "wt") as vcf_out:
    ts_matched.write_vcf(vcf_out, contig_id=contig_id)

In [44]:
# Write the imputed query genomes. contig_id is passed explicitly so the value
# set in the configuration cell actually takes effect.
with gzip.open(imputed_vcf_file, "wt") as vcf_out:
    ts_imputed.write_vcf(vcf_out, contig_id=contig_id)